""" Model implemetnation for V2SNet"""
# pylint:disable=invalid-name, missing-function-docstring
# pylint:disable=missing-class-docstring, redefined-outer-name
# pylint:disable=unused-variable, super-with-arguments, line-too-long
# pylint:disable=attribute-defined-outside-init, too-many-instance-attributes
# pylint:disable=access-member-before-definition, abstract-method
# pylint:disable=consider-using-from-import,consider-using-f-string
import time
import torchvision
import numpy
import torch
import torch.nn as nn
class ConstPaddedConv(nn.Module):
    """3D convolution whose border is added explicitly before convolving.

    The wrapped ``nn.Conv3d`` is always created with ``padding=0``. The
    requested border is instead added in :meth:`forward` through
    ``nn.functional.pad`` in ``"replicate"`` mode, i.e. edge values are
    repeated outwards (despite the "Const" in the class name).
    """

    def __init__(
            self,
            cIn,
            cOut,
            kernel_size,
            stride=1,
            dilation=1,
            padding=0):
        """Build the un-padded convolution and remember the border width.

        Args:
            cIn: Number of input channels.
            cOut: Number of output channels.
            kernel_size: Kernel size forwarded to ``nn.Conv3d``.
            stride: Stride forwarded to ``nn.Conv3d``.
            dilation: Dilation forwarded to ``nn.Conv3d``.
            padding: Border width (compared against 0 in ``forward``, so a
                single number) applied to both sides of the last three axes.
        """
        super().__init__()
        # Remembered here, applied manually in forward().
        self.padding = padding
        self.conv = nn.Conv3d(
            cIn,
            cOut,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=0)

    def forward(self, data):
        """Replicate-pad ``data`` (when a border was requested) and convolve.

        Args:
            data: Input tensor for the wrapped ``nn.Conv3d``.

        Returns:
            The output of the convolution.
        """
        border = self.padding
        if border > 0:
            # Six entries: (left, right) for each of the last three axes.
            data = nn.functional.pad(data, [border] * 6, mode="replicate")
        return self.conv(data)
class MultiDilationUnit(nn.Module):
    """Parallel dilated 3D convolutions merged by a 1x1x1 convolution.

    The input first passes a 3x3x3 convolution. Its activated output feeds
    two convolutions of size ``baseKernelSize`` with dilation ``1`` and
    ``baseKernelSize - 2``. Their (un-activated) outputs are concatenated
    along the channel axis and reduced to ``channelsOut`` channels.
    """

    def __init__(
            self,
            channelsIn,
            channelsDilated,
            channelsOut,
            baseKernelSize=7):
        """Create the pre-convolution, the dilated branches and the merger.

        Args:
            channelsIn: Channels of the incoming tensor.
            channelsDilated: Channels produced by the pre-convolution and by
                each dilated branch.
            channelsOut: Channels of the returned tensor.
            baseKernelSize: Kernel size of every dilated branch; also
                determines the second dilation rate (``baseKernelSize - 2``).
        """
        super().__init__()
        self.preConv = ConstPaddedConv(
            channelsIn, channelsDilated, kernel_size=3, padding=1)

        self.dilationModules = torch.nn.ModuleList()
        for rate in (1, baseKernelSize - 2):
            # Half of the dilated kernel extent; for odd kernel sizes this
            # keeps the spatial size of the branch output unchanged.
            border = int(((baseKernelSize - 1) / 2) * rate)
            self.dilationModules.append(
                ConstPaddedConv(
                    channelsDilated,
                    channelsDilated,
                    kernel_size=baseKernelSize,
                    dilation=rate,
                    padding=border))

        # Every branch contributes channelsDilated channels to the concat.
        mergedChannels = channelsDilated * len(self.dilationModules)
        self.combineConv = ConstPaddedConv(
            mergedChannels, channelsOut, kernel_size=1, padding=0)
        self.nonLin = nn.Softsign()

    def forward(self, data):
        """Run all dilated branches on the pre-convolved input and merge them.

        Args:
            data: Input tensor with ``channelsIn`` channels.

        Returns:
            Softsign-activated tensor with ``channelsOut`` channels.
        """
        features = self.nonLin(self.preConv(data))
        branches = [branch(features) for branch in self.dilationModules]
        stacked = torch.cat(branches, dim=1)
        return self.nonLin(self.combineConv(stacked))
class Model(nn.Module):
    """Four-level encoder/decoder 3D CNN with skip connections.

    The two inputs are concatenated along the channel axis (two channels in
    total), encoded through three stride-2 convolutions and decoded again by
    nearest-neighbour upsampling plus concatenation with the encoder
    features of the same level. A 3-channel output is produced at every
    level. Attribute names carry the resolutions seen for 64^3 inputs
    (64, 32, 16, 8), but the upsampling targets are read from the encoder
    tensors, so other input sizes work as well.
    """

    def __init__(self, mask=True):
        """Create all layers.

        Args:
            mask: When true, ``forward`` zeroes every output voxel whose
                ``preoperative`` value is not negative.
        """
        super().__init__()
        self.mask = mask

        def same(cIn, cOut):
            # 3x3x3, stride 1, border 1: spatial size is preserved.
            return ConstPaddedConv(
                cIn, cOut, kernel_size=3, stride=1, padding=1)

        def halve(cIn, cOut):
            # 4x4x4, stride 2, border 1: spatial size is halved.
            return ConstPaddedConv(
                cIn, cOut, kernel_size=4, stride=2, padding=1)

        # NOTE: creation order is kept as-is so that seeded default
        # initialisation yields the same weights as before.
        # Encoder, finest level.
        self.conv64_0 = same(2, 16)
        self.conv64_1 = same(16, 24)
        self.conv64_2 = same(24, 24)
        self.down64to32 = halve(24, 32)
        self.conv32_0 = same(32, 32)
        self.conv32_1 = same(32, 32)
        self.conv32_2 = same(32, 32)
        self.down32to16 = halve(32, 64)
        self.conv16_0 = same(64, 64)
        self.conv16_1 = same(64, 64)
        self.conv16_2 = same(64, 64)
        self.down16to8 = halve(64, 96)
        # Bottleneck.
        self.conv8_0 = same(96, 96)
        self.conv8_1 = same(96, 96)
        self.conv8_2 = same(96, 96)
        self.conv8_3 = same(96, 96)
        self.conv8_4 = same(96, 96)
        self.conv8_5 = same(96, 96)
        # Decoder: input channels = upsampled features + skip features.
        self.conv16_combine_0 = same(96 + 64, 64)
        self.conv16_combine_1 = same(64, 64)
        self.conv32_combine_0 = same(64 + 32, 32)
        self.conv32_combine_1 = same(32, 32)
        self.conv64_combine_0 = same(32 + 24, 24)
        self.conv64_combine_1 = same(24, 16)
        self.conv64_combine_2 = same(16, 3)
        # 1x1x1 heads for the lower-resolution outputs.
        self.conv32toOutput = ConstPaddedConv(32, 3, kernel_size=1)
        self.conv16toOutput = ConstPaddedConv(64, 3, kernel_size=1)
        self.conv8toOutput = ConstPaddedConv(96, 3, kernel_size=1)
        self.nonLin = nn.Softsign()
        # Set ``timing`` to True to print per-stage durations in forward().
        self.timing = False
        # perf_counter value of the previous time() call (None = not started).
        self.dt = None

    def time(self, label=None):
        """Print the time since the previous call and restart the stopwatch.

        Does nothing unless ``self.timing`` is true. Nothing is printed when
        ``label`` is None or when the stopwatch has not been started yet.

        Args:
            label: Value printed in front of the elapsed seconds.
        """
        if not self.timing:
            return
        # Wait for queued GPU work so the measurement covers it; skipped on
        # machines without CUDA, where synchronize() would raise.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        started = getattr(self, "dt", None)
        if label is not None and started is not None:
            print(label, time.perf_counter() - started)
        self.dt = time.perf_counter()

    def _chain(self, data, *convs):
        """Apply each convolution followed by the non-linearity, in order."""
        for conv in convs:
            data = self.nonLin(conv(data))
        return data

    def forward(self, preoperative, intraoperative):
        """Compute the 3-channel outputs at all four resolutions.

        Args:
            preoperative: Single-channel 5D tensor. The masking code treats
                it as a signed distance function (negative = inside).
            intraoperative: Tensor concatenated to ``preoperative`` along
                dim 1; together they must provide two channels.

        Returns:
            Tuple ``(res64out, res32out, res16out, res8out)``: outputs at
            full, 1/2, 1/4 and 1/8 of the input resolution.
        """
        interpolate = torch.nn.functional.interpolate

        config = torch.cat((preoperative, intraoperative), 1)
        self.time()

        # Encoder.
        res64 = self._chain(
            config, self.conv64_0, self.conv64_1, self.conv64_2)
        self.time(0)
        res32 = self.nonLin(self.down64to32(res64))
        self.time(1)
        res32 = self._chain(
            res32, self.conv32_0, self.conv32_1, self.conv32_2)
        self.time(2)
        res16 = self.nonLin(self.down32to16(res32))
        self.time(3)
        res16 = self._chain(
            res16, self.conv16_0, self.conv16_1, self.conv16_2)
        self.time(4)
        res8 = self.nonLin(self.down16to8(res16))
        self.time(5)
        res8 = self._chain(
            res8,
            self.conv8_0, self.conv8_1, self.conv8_2,
            self.conv8_3, self.conv8_4, self.conv8_5)
        self.time(6)

        # Decoder. The upsampling target is the spatial size of the skip
        # tensor instead of a hard-coded 16/32/64, so the concatenation
        # matches for any input size.
        up16 = interpolate(res8, size=res16.shape[2:])
        self.time(7)
        res16 = torch.cat((up16, res16), dim=1)
        res16 = self._chain(
            res16, self.conv16_combine_0, self.conv16_combine_1)
        self.time(8)
        up32 = interpolate(res16, size=res32.shape[2:])
        self.time(9)
        res32 = torch.cat((up32, res32), dim=1)
        res32 = self._chain(
            res32, self.conv32_combine_0, self.conv32_combine_1)
        self.time(10)
        up64 = interpolate(res32, size=res64.shape[2:])
        self.time(11)
        res64 = torch.cat((up64, res64), dim=1)
        res64 = self._chain(
            res64,
            self.conv64_combine_0,
            self.conv64_combine_1,
            self.conv64_combine_2)
        self.time(12)
        res64out = res64
        self.time(13)

        # Lower resolution outputs for additional error terms:
        res32out = self.nonLin(self.conv32toOutput(res32))
        res16out = self.nonLin(self.conv16toOutput(res16))
        res8out = self.nonLin(self.conv8toOutput(res8))
        self.time(14)

        if self.mask:
            # Generate mask from signed distance function: 1 where the
            # preoperative value is negative, 0 elsewhere, repeated for the
            # three output channels. Only built when actually needed.
            mask = preoperative.lt(0).expand(-1, 3, -1, -1, -1).float()
            res64out = res64out * mask
            res32out = res32out * interpolate(mask, size=res32out.shape[2:])
            res16out = res16out * interpolate(mask, size=res16out.shape[2:])
            res8out = res8out * interpolate(mask, size=res8out.shape[2:])
        self.time(15)
        return res64out, res32out, res16out, res8out

    def init(self):
        """Re-initialise every submodule with ``initWeightsAverage``."""
        self.apply(initWeightsAverage)
class TestModel(nn.Module):
    """Deep single-channel stand-in with the same call signature as Model.

    A chain of 102 3x3x3 convolutions (2 -> 1 -> ... -> 1 -> 3 channels)
    that keeps the spatial size. The lower-resolution outputs are plain
    nearest-neighbour resamplings of the full-resolution result to the
    fixed sizes 32, 16 and 8.
    """

    def __init__(self, mask=True):
        """Create all layers.

        Args:
            mask: Stored on the instance; not used by ``forward``.
        """
        super().__init__()
        self.mask = mask
        self.conv0 = ConstPaddedConv(2, 1, kernel_size=3, stride=1, padding=1)
        self.convs = nn.ModuleList([
            ConstPaddedConv(1, 1, kernel_size=3, stride=1, padding=1)
            for _ in range(100)])
        self.convX = ConstPaddedConv(1, 3, kernel_size=3, stride=1, padding=1)
        self.nonLin = nn.Softsign()
        # Set ``timing`` to True to print the forward duration.
        self.timing = False
        # perf_counter value of the previous time() call (None = not started).
        self.dt = None

    def time(self, label=None):
        """Print the time since the previous call and restart the stopwatch.

        Does nothing unless ``self.timing`` is true. Nothing is printed when
        ``label`` is None or when the stopwatch has not been started yet.

        Args:
            label: Value printed in front of the elapsed seconds.
        """
        if not self.timing:
            return
        # Wait for queued GPU work so the measurement covers it; skipped on
        # machines without CUDA, where synchronize() would raise.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        started = getattr(self, "dt", None)
        if label is not None and started is not None:
            print(label, time.perf_counter() - started)
        self.dt = time.perf_counter()

    def forward(self, preoperative, intraoperative):
        """Run the convolution chain and resample the result.

        Args:
            preoperative: First input, concatenated along dim 1.
            intraoperative: Second input; together with ``preoperative`` it
                must provide two channels.

        Returns:
            Tuple ``(res64out, res32out, res16out, res8out)`` where the last
            three are ``res64out`` resampled to sizes 32, 16 and 8.
        """
        config = torch.cat((preoperative, intraoperative), 1)
        # Start the stopwatch; previously time(15) below read a timestamp
        # that had never been set and failed when timing was enabled.
        self.time()
        res64 = self.nonLin(self.conv0(config))
        for conv in self.convs:
            res64 = self.nonLin(conv(res64))
        res64out = self.nonLin(self.convX(res64))
        res32out = torch.nn.functional.interpolate(res64out, size=32)
        res16out = torch.nn.functional.interpolate(res64out, size=16)
        res8out = torch.nn.functional.interpolate(res64out, size=8)
        self.time(15)
        return res64out, res32out, res16out, res8out
# Initialize to do an "average" operation only:
def initWeightsAverage(m):
    """Initialise a ``nn.Conv3d`` so that it roughly averages its inputs.

    Every weight is set to ``1 / numInputs`` plus Gaussian noise scaled by
    the same factor, where ``numInputs`` is the number of values feeding one
    output element (input channels of the weight times kernel volume). The
    bias, if the layer has one, is set to zero. Modules of any other type
    are left untouched, which makes the function usable with
    ``nn.Module.apply``.

    Args:
        m: Module to (possibly) initialise.
    """
    if not isinstance(m, nn.Conv3d):
        return
    # Layers created with bias=False have ``bias is None``.
    if m.bias is not None:
        m.bias.data = torch.zeros_like(m.bias)
    # Calculate number of input channels times kernel size:
    s = m.weight.shape
    numInputs = s[1] * s[2] * s[3] * s[4]
    m.weight.data = torch.ones_like(m.weight) / numInputs
    m.weight.data += torch.randn_like(m.weight) / numInputs
def initWeights(m):
    """Xavier-normal initialise the weights of a ``nn.Conv3d``.

    The bias, if the layer has one, is set to zero. Modules of any other
    type are left untouched, which makes the function usable with
    ``nn.Module.apply``.

    Args:
        m: Module to (possibly) initialise.
    """
    if not isinstance(m, nn.Conv3d):
        return
    torch.nn.init.xavier_normal_(m.weight.data)
    # Layers created with bias=False have ``bias is None``.
    if m.bias is not None:
        m.bias.data.zero_()
# Perform an occlusion experiment to see which inputs influence a final
# output pixel:
def _rescaleToUnitRange(values):
    """Shift and scale ``values`` to span [0, 1]; constant input gives zeros."""
    low = torch.min(values)
    span = torch.max(values) - low
    if span.item() == 0:
        # Avoid 0 / 0 = NaN when every entry is identical.
        return torch.zeros_like(values)
    return (values - low) / span


def occlusionExperiment(
        model,
        preoperative,
        intraoperative,
        occ_size=4,
        occ_stride=4,
        occ_pixel=1):
    """Measure how occluding patches of the input changes one output voxel.

    The model parameters are overwritten (see below), then a patch of
    ``preoperative`` covering ``occ_size`` entries along each of the last
    two axes (and everything along the other axes) is set to ``occ_pixel``.
    For every patch position the norm over the channels of output voxel
    ``[0, :, 0, 0, 0]`` is stored in a heatmap. The heatmap and its
    logarithm are rescaled to [0, 1] and written to
    ``occlusionExperiment.png`` and ``occlusionExperiment_log.png`` (paths
    are relative, so they land in the current working directory).

    Side effects: every parameter of ``model`` is overwritten and gets
    ``requires_grad = True``; ``model.init()`` is called; progress and
    debugging values are printed.

    Args:
        model: Callable as ``model(preoperative, intraoperative)`` returning
            four outputs; must offer ``named_parameters()`` and ``init()``.
        preoperative: Input tensor that gets occluded (5D).
        intraoperative: Second input, passed through unchanged.
        occ_size: Edge length of the occluded patch.
        occ_stride: Distance between neighbouring patch positions.
        occ_pixel: Value written into the occluded patch.
    """
    # The two trailing axes are the ones the occlusion window moves over.
    width, height = preoperative.shape[-2], preoperative.shape[-1]
    # Number of window positions that end strictly before the border.
    output_height = int(numpy.ceil((height - occ_size) / occ_stride))
    output_width = int(numpy.ceil((width - occ_size) / occ_stride))
    heatmap = torch.zeros((output_height, output_width))

    # Initialize to do an "average" operation only. Parameters with "bias"
    # in their name become zero, all others 1 / (inputs per output value).
    for name, p in model.named_parameters():
        if "bias" in name:
            p.data = torch.zeros_like(p)
        else:
            p.data = torch.ones_like(
                p) / (p.shape[1] * p.shape[-3] * p.shape[-2] * p.shape[-1])
        p.requires_grad = True
    model.init()

    with torch.no_grad():
        # Reference run on the unmodified input; only used for the printout,
        # so it does not need an autograd graph either.
        baseOutput, _, _, _ = model(preoperative, intraoperative)
        print(
            "baseOutput:",
            torch.min(
                torch.abs(baseOutput)).item(),
            torch.max(baseOutput).item())

        # Only positions whose window ends before the border are evaluated;
        # these are exactly the first output_height x output_width ones, so
        # there is no need to step through (and skip) all other indices.
        for h in range(output_height):
            print(str(h) + "/" + str(output_height))
            h_start = h * occ_stride
            h_end = h_start + occ_size
            for w in range(output_width):
                w_start = w * occ_stride
                w_end = w_start + occ_size
                preop = preoperative.clone().detach()
                # Overwrite the window with occ_pixel. Note the axis order:
                # axis -2 is indexed by w, axis -1 by h.
                preop[:, :, :, w_start:w_end, h_start:h_end] = occ_pixel
                # run inference on modified input
                output, _, _, _ = model(preop, intraoperative)
                val = output[0, :, 0, 0, 0].norm().item()
                heatmap[h, w] = val
                print(val)
                del preop, output

    print("heatmap", torch.min(heatmap), torch.max(heatmap))
    heatmap = heatmap - torch.min(heatmap)
    print("heatmap", torch.min(heatmap), torch.max(heatmap))
    # Tiny offset so the logarithm below stays finite at the minimum.
    heatmap = heatmap + 1e-32
    print("heatmap", torch.min(heatmap), torch.max(heatmap))
    heatmap_log = torch.log(heatmap)
    print("heatmap_log", torch.min(heatmap_log), torch.max(heatmap_log))
    heatmap = _rescaleToUnitRange(heatmap)
    print("heatmap", torch.min(heatmap), torch.max(heatmap))
    heatmap_log = _rescaleToUnitRange(heatmap_log)
    torchvision.utils.save_image(
        heatmap,
        "occlusionExperiment.png",
        normalize=True)
    torchvision.utils.save_image(
        heatmap_log,
        "occlusionExperiment_log.png",
        normalize=True)
def timingExperiment(model, numRuns=250):
    """Time repeated forward passes on random 64^3 inputs and print a report.

    Each run creates two random ``(1, 1, 64, 64, 64)`` tensors, moves them
    to the device of the model, enables gradients on them and performs one
    forward pass. The measured time therefore includes input creation and
    transfer. The inputs are placed on the device of the model's first
    parameter, so CPU models can be timed too; a model without parameters
    falls back to CUDA when available.

    Args:
        model: Callable as ``model(preoperative, intraoperative)`` returning
            four outputs.
        numRuns: Number of forward passes to time.

    Raises:
        ValueError: If ``numRuns`` is smaller than 1.
    """
    if numRuns < 1:
        raise ValueError("numRuns must be at least 1")

    params = model.parameters() if hasattr(model, "parameters") else iter(())
    firstParam = next(params, None)
    if firstParam is not None:
        device = firstParam.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    onGpu = device.type == "cuda"

    print("Timing test. Running {} samples...".format(numRuns))
    # perf_counter is monotonic and has a higher resolution than time.time.
    st = time.perf_counter()
    for _ in range(numRuns):
        preoperative = torch.randn((1, 1, 64, 64, 64)).to(device)
        intraoperative = torch.randn((1, 1, 64, 64, 64)).to(device)
        preoperative.requires_grad = True
        intraoperative.requires_grad = True
        modelOutput, _, _, _ = model(preoperative, intraoperative)
        if onGpu:
            # Kernels run asynchronously; wait so the run is fully counted.
            torch.cuda.synchronize()
        modelOutput.data.zero_()
    dt = time.perf_counter() - st

    print("\tTime:", dt)
    # Guard against a zero reading on very coarse clocks.
    fps = numRuns / dt if dt > 0 else float("inf")
    print(
        "\tTime per sample: {}, ({:2.2f} FPS)".format(
            dt / numRuns,
            fps))