"""
Definition of the HSMNet model structure, and various helper functions.
"""
# pylint:disable=invalid-name, line-too-long, missing-docstring, too-many-locals
# pylint:disable=too-many-instance-attributes, no-else-return, no-self-use
# pylint:disable=super-with-arguments, abstract-method,
# pylint:disable=consider-using-from-import
from __future__ import print_function
import math
import torch
import torch.nn as nn
import torch.utils.data
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
class HSMNet_model(nn.Module):
    """
    Hierarchical stereo matching network.

    A shared ``unet`` extracts four feature maps per view. For each of them an
    absolute-difference cost volume is built (``feature_vol``) and the volumes
    are decoded from coarsest (``decoder6``) to finest (``decoder3``), each
    decoder receiving the upsampled features of the previous one. The last
    cost volume used is converted to a disparity map by a
    ``disparityregression`` module.

    :param maxdisp: maximum disparity; ``forward`` builds volumes with
        ``maxdisp // 64``, ``// 32``, ``// 16`` and ``// 8`` disparity bins.
    :param clean: entropy threshold applied outside training mode: pixels
        whose entropy exceeds it are set to ``inf``. ``-1`` skips the entropy
        computation altogether.
    :param device: stored as ``self.device``; kept for backward compatibility
        (cost volumes are allocated on the device of the input features).
    :param level: selects where decoding stops: ``level > 2`` uses the output
        of ``decoder5``, ``level > 1`` the output of ``decoder4``, otherwise
        ``decoder3``.
    """

    def __init__(self, maxdisp, clean, device, level=1):
        super(HSMNet_model, self).__init__()
        self.maxdisp = maxdisp
        self.clean = clean
        self.feature_extraction = unet()
        self.level = level
        self.device = device

        # Decoders, coarsest first. A decoder only gets an upsampling head
        # when a finer decoder consumes its features at the chosen level.
        self.decoder6 = decoderBlock(6, 32, 32, up=True, pool=True)
        self.decoder5 = decoderBlock(
            6, 32, 32, up=not self.level > 2, pool=True)
        self.decoder4 = decoderBlock(6, 32, 32, up=not self.level > 1)
        self.decoder3 = decoderBlock(
            5, 32, 32, stride=(2, 1, 1), up=False, nstride=1)

        # Disparity regression heads. Note that disp_reg8 and disp_reg16 are
        # both built with divisor 16.
        self.disp_reg8 = disparityregression(self.maxdisp, 16)
        self.disp_reg16 = disparityregression(self.maxdisp, 16)
        self.disp_reg32 = disparityregression(self.maxdisp, 32)
        self.disp_reg64 = disparityregression(self.maxdisp, 64)

    def feature_vol(self, refimg_fea, targetimg_fea, maxdisp, leftview=True):
        """
        Build an absolute-difference feature volume.

        :param refimg_fea: 4D reference features (batch, channel, height,
            width).
        :param targetimg_fea: target features of the same shape.
        :param maxdisp: number of disparity slices in the volume.
        :param leftview: if True, slice ``i`` is written to columns ``i:``;
            otherwise to columns ``:width - i``.
        :returns: contiguous float32 tensor of shape
            (batch, channel, maxdisp, height, width), allocated on the device
            of ``refimg_fea``. Positions without an overlapping column pair
            are left at zero.
        """
        batch, channels, height, width = refimg_fea.size()
        cost = torch.zeros(
            batch, channels, maxdisp, height, width,
            dtype=torch.float32, device=refimg_fea.device)

        # Shifts larger than the feature width have no overlapping columns;
        # stopping at `width` leaves those slices zero instead of failing on
        # mismatched slice shapes.
        for i in range(min(maxdisp, width)):
            feata = refimg_fea[:, :, :, i:]
            featb = targetimg_fea[:, :, :, :width - i]
            if leftview:
                cost[:, :, i, :, i:] = torch.abs(feata - featb)
            else:
                cost[:, :, i, :, :width - i] = torch.abs(featb - feata)
        return cost.contiguous()

    def forward(self, left, right):
        """
        Predict disparity for a stereo pair.

        :param left: left image batch; dimensions 2 and 3 define the output
            resolution.
        :param right: right image batch; concatenated with ``left`` along
            dimension 0 for a single feature-extraction pass.
        :returns: in training mode ``([pred3, pred4, pred5, pred6], entropy)``
            where ``entropy`` is ``pred3`` itself; otherwise
            ``(pred3, torch.squeeze(entropy))``.
        """
        nsample = left.shape[0]
        out_size = [left.size()[2], left.size()[3]]
        # Target size for trilinear resizing of a cost volume:
        # (disparity bins of disp_reg8, height, width).
        vol_size = [self.disp_reg8.disp.shape[1]] + out_size

        # Both views go through the extractor in one batch, then are split.
        conv4, conv3, conv2, conv1 = self.feature_extraction(
            torch.cat([left, right], 0))
        conv40, conv30, conv20, conv10 = (
            conv4[:nsample], conv3[:nsample], conv2[:nsample], conv1[:nsample])
        conv41, conv31, conv21, conv11 = (
            conv4[nsample:], conv3[nsample:], conv2[nsample:], conv1[nsample:])

        feat6 = self.feature_vol(conv40, conv41, self.maxdisp // 64)
        feat5 = self.feature_vol(conv30, conv31, self.maxdisp // 32)
        feat4 = self.feature_vol(conv20, conv21, self.maxdisp // 16)
        feat3 = self.feature_vol(conv10, conv11, self.maxdisp // 8)

        feat6_2x, cost6 = self.decoder6(feat6)
        feat5 = torch.cat((feat6_2x, feat5), dim=1)
        feat5_2x, cost5 = self.decoder5(feat5)

        if self.level > 2:
            cost3 = F.interpolate(cost5, out_size, mode='bilinear')
        else:
            feat4 = torch.cat((feat5_2x, feat4), dim=1)
            feat4_2x, cost4 = self.decoder4(feat4)
            if self.level > 1:
                cost3 = F.interpolate(
                    cost4.unsqueeze(1), vol_size, mode='trilinear').squeeze(1)
            else:
                feat3 = torch.cat((feat4_2x, feat3), dim=1)
                _, cost3 = self.decoder3(feat3)
                cost3 = F.interpolate(cost3, out_size, mode='bilinear')

        if self.level > 2:
            final_reg = self.disp_reg32
        else:
            final_reg = self.disp_reg8

        if self.training or self.clean == -1:
            pred3 = final_reg(F.softmax(cost3, 1))
            entropy = pred3  # alias instead of computing entropy, to save memory
        else:
            pred3, entropy = final_reg(F.softmax(cost3, 1), ifent=True)
            # Invalidate predictions whose entropy exceeds the threshold.
            pred3[entropy > self.clean] = np.inf

        if self.training:
            # NOTE(review): cost4 is only assigned when level <= 2, so
            # training with level > 2 fails here with an unbound local.
            cost6 = F.interpolate(
                cost6.unsqueeze(1), vol_size, mode='trilinear').squeeze(1)
            cost5 = F.interpolate(
                cost5.unsqueeze(1), vol_size, mode='trilinear').squeeze(1)
            cost4 = F.interpolate(cost4, out_size, mode='bilinear')
            pred6 = self.disp_reg16(F.softmax(cost6, 1))
            pred5 = self.disp_reg16(F.softmax(cost5, 1))
            pred4 = self.disp_reg16(F.softmax(cost4, 1))
            stacked = [pred3, pred4, pred5, pred6]
            return stacked, entropy
        else:
            return pred3, torch.squeeze(entropy)
class sepConv3dBlock(nn.Module):
    """
    Residual block made of two ``sepConv3d`` layers and a shortcut.

    The shortcut is the identity when ``in_planes == out_planes`` and
    ``stride == (1, 1, 1)``; otherwise the input is passed through a
    ``projfeat3d`` projection before being added.

    :param in_planes: number of input channels.
    :param out_planes: number of output channels.
    :param stride: stride of the first convolution (and of the projection).
    """

    def __init__(self, in_planes, out_planes, stride=(1, 1, 1)):
        super(sepConv3dBlock, self).__init__()
        if in_planes == out_planes and stride == (1, 1, 1):
            self.downsample = None
        else:
            self.downsample = projfeat3d(in_planes, out_planes, stride)
        self.conv1 = sepConv3d(in_planes, out_planes, 3, stride, 1)
        self.conv2 = sepConv3d(out_planes, out_planes, 3, (1, 1, 1), 1)

    def forward(self, x):
        """Return ``relu(shortcut(x) + conv2(relu(conv1(x))))``."""
        out = F.relu(self.conv1(x), inplace=True)
        # Explicit None check: do not rely on the truthiness of a module.
        if self.downsample is not None:
            x = self.downsample(x)
        return F.relu(x + self.conv2(out), inplace=True)
class projfeat3d(nn.Module):
    """
    Channel projection of a 5D volume implemented with a 1x1 2D convolution.

    The input (b, c, d, h, w) is viewed as (b, c, d, h * w), convolved with a
    1x1 ``Conv2d`` using ``stride[:2]``, batch-normalised and reshaped back
    to five dimensions.

    :param in_planes: number of input channels.
    :param out_planes: number of output channels.
    :param stride: 3-tuple; only the first two entries reach the convolution.
        NOTE(review): the second entry strides over the flattened ``h * w``
        axis, so the final reshape presumably only makes sense when it is 1.
    """

    def __init__(self, in_planes, out_planes, stride):
        super(projfeat3d, self).__init__()
        self.stride = stride
        self.conv1 = nn.Conv2d(
            in_planes, out_planes, (1, 1),
            padding=(0, 0), stride=stride[:2], bias=False)
        self.bn = nn.BatchNorm2d(out_planes)

    def forward(self, x):
        """Project ``x`` of shape (b, c, d, h, w) to ``out_planes`` channels."""
        b, c, d, h, w = x.size()
        x = self.conv1(x.view(b, c, d, h * w))
        x = self.bn(x)
        # Use the depth actually produced by the strided convolution. With a
        # 1x1 kernel it is ceil(d / stride[0]); the previous d // stride[0]
        # was wrong for odd d and let the inferred channel dimension absorb
        # the difference.
        return x.view(b, -1, x.size(2), h, w)
# Plain 3D convolution helper (Conv3d, followed by BatchNorm3d when bias is disabled)
def sepConv3d(in_planes, out_planes, kernel_size, stride, pad, bias=False):
    """
    Build a 3D convolution layer wrapped in ``nn.Sequential``.

    Despite the name, the layer is a regular ``nn.Conv3d``. When ``bias`` is
    False a ``nn.BatchNorm3d`` follows the convolution; when ``bias`` is True
    the convolution is used on its own.

    :param in_planes: number of input channels.
    :param out_planes: number of output channels.
    :param kernel_size: convolution kernel size.
    :param stride: convolution stride.
    :param pad: convolution padding.
    :param bias: whether the convolution has a bias term (and hence whether
        batch normalisation is omitted).
    :returns: ``nn.Sequential`` holding the convolution at index 0 and, if
        present, the batch norm at index 1.
    """
    layers = [
        nn.Conv3d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            padding=pad,
            stride=stride,
            bias=bias)]
    if not bias:
        layers.append(nn.BatchNorm3d(out_planes))
    return nn.Sequential(*layers)
class disparityregression(nn.Module):
    """
    Convert a per-pixel distribution over disparity bins into a disparity.

    :param maxdisp: maximum disparity.
    :param divisor: bin width; the module holds ``int(maxdisp / divisor)``
        bins and multiplies the expected bin index by ``divisor``.
    """

    def __init__(self, maxdisp, divisor):
        super(disparityregression, self).__init__()
        maxdisp = int(maxdisp / divisor)
        # Bin indices 0 .. maxdisp - 1, shaped (1, maxdisp, 1, 1) so they
        # broadcast against a (batch, maxdisp, height, width) input.
        self.register_buffer(
            'disp',
            torch.arange(maxdisp, dtype=torch.float32).view(1, maxdisp, 1, 1))
        self.divisor = divisor

    def forward(self, x, ifent=False):
        """
        :param x: weights over the bins along dimension 1 (callers in this
            file pass a softmax output).
        :param ifent: if True, also return the entropy of ``x`` over
            dimension 1.
        :returns: ``sum(x * bin_index, dim=1) * divisor``, or a tuple of that
            value and the entropy when ``ifent`` is True.
        """
        # Broadcasting gives the same product as tiling the buffer to the
        # input size, without materialising the tiled copy.
        out = torch.sum(x * self.disp, 1) * self.divisor
        if ifent:
            # Small offset keeps log() finite where a weight is exactly zero.
            x = x + 1e-12
            ent = (-x * x.log()).sum(dim=1)
            return out, ent
        return out
class decoderBlock(nn.Module):
    """
    Decoder for one cost volume.

    Applies a stack of ``sepConv3dBlock`` layers, optionally fuses four
    average-pooled versions of the result, and produces a single-channel
    cost through ``classify``. Optionally an upsampling head doubles the
    volume in every dimension and halves the channel count.

    :param nconvs: number of ``sepConv3dBlock`` layers.
    :param inchannelF: channels of the input volume.
    :param channelF: channels used inside the block.
    :param stride: stride of the first ``nstride`` layers; remaining layers
        use ``(1, 1, 1)``.
    :param up: whether to create the upsampling head.
    :param nstride: how many leading layers use ``stride``.
    :param pool: whether to apply the pooling fusion in ``forward``.
    """

    def __init__(
            self,
            nconvs,
            inchannelF,
            channelF,
            stride=(1, 1, 1),
            up=False,
            nstride=1,
            pool=False):
        super(decoderBlock, self).__init__()
        self.pool = pool

        strides = [stride] * nstride + [(1, 1, 1)] * (nconvs - nstride)
        blocks = [sepConv3dBlock(inchannelF, channelF, stride=strides[0])]
        for i in range(1, nconvs):
            blocks.append(
                sepConv3dBlock(channelF, channelF, stride=strides[i]))
        self.convs = nn.Sequential(*blocks)

        self.classify = nn.Sequential(
            sepConv3d(channelF, channelF, 3, (1, 1, 1), 1),
            nn.ReLU(inplace=True),
            sepConv3d(channelF, 1, 3, (1, 1, 1), 1, bias=True))

        # self.up is False when there is no upsampling head, otherwise the
        # head itself; forward() tests it for truthiness.
        self.up = False
        if up:
            self.up = nn.Sequential(
                nn.Upsample(scale_factor=(2, 2, 2), mode='trilinear'),
                sepConv3d(channelF, channelF // 2, 3, (1, 1, 1), 1, bias=False),
                nn.ReLU(inplace=True))

        if pool:
            self.pool_convs = torch.nn.ModuleList(
                [sepConv3d(channelF, channelF, 1, (1, 1, 1), 0)
                 for _ in range(4)])

        # Re-initialise every Conv3d weight from N(0, sqrt(2 / n)) with
        # n = kernel volume * out_channels, and zero any bias.
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                n = m.kernel_size[0] * m.kernel_size[1] * \
                    m.kernel_size[2] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, fvl):
        """
        :param fvl: 5D feature volume.
        :returns: ``(features, cost)``. In training mode ``cost`` comes from
            ``classify`` on the pre-upsampling features. In eval mode with an
            upsampling head the classifier is skipped and ``cost`` is the
            upsampled feature tensor; without a head ``cost`` comes from
            ``classify``. ``cost`` has dimension 1 squeezed.
        """
        fvl = self.convs(fvl)

        if self.pool:
            fvl_out = fvl
            _, _, d, h, w = fvl.shape
            # Clamp the largest pool size to >= 1 so tiny volumes
            # (min(d, h, w) < 2) do not produce a zero pool size and a
            # division by zero below.
            largest = max(min(d, h, w) // 2, 1)
            for i, pool_size in enumerate(
                    np.linspace(1, largest, 4, dtype=int)):
                kernel_size = (int(d / pool_size),
                               int(h / pool_size), int(w / pool_size))
                out = F.avg_pool3d(fvl, kernel_size, stride=kernel_size)
                out = self.pool_convs[i](out)
                out = F.interpolate(out, size=(d, h, w), mode='trilinear')
                fvl_out = fvl_out + 0.25 * out
            fvl = F.relu(fvl_out / 2., inplace=True)

        if self.training:
            costl = self.classify(fvl)
            if self.up:
                fvl = self.up(fvl)
        else:
            if self.up:
                fvl = self.up(fvl)
                costl = fvl
            else:
                costl = self.classify(fvl)
        return fvl, costl.squeeze(1)
class unet(nn.Module):
    """
    Encoder-decoder feature extractor.

    The encoder is a three-convolution stem, a max-pool and four stride-2
    residual stages, followed by pyramid pooling. The decoder upsamples by 2
    three times, concatenating the matching encoder output each time.
    ``forward`` returns four projected feature maps with 32, 16, 16 and 16
    channels, coarsest first.
    """

    def __init__(self):
        super(unet, self).__init__()
        # Running channel count consumed and updated by _make_layer.
        self.inplanes = 32

        # Encoder stem
        self.convbnrelu1_1 = conv2DBatchNormRelu(
            in_channels=3, k_size=3, n_filters=16, padding=1, stride=2, bias=False)
        self.convbnrelu1_2 = conv2DBatchNormRelu(
            in_channels=16, k_size=3, n_filters=16, padding=1, stride=1, bias=False)
        self.convbnrelu1_3 = conv2DBatchNormRelu(
            in_channels=16, k_size=3, n_filters=32, padding=1, stride=1, bias=False)

        # Residual stages, each halving the resolution
        self.res_block3 = self._make_layer(residualBlock, 64, 1, stride=2)
        self.res_block5 = self._make_layer(residualBlock, 128, 1, stride=2)
        self.res_block6 = self._make_layer(residualBlock, 128, 1, stride=2)
        self.res_block7 = self._make_layer(residualBlock, 128, 1, stride=2)
        self.pyramid_pooling = pyramidPooling(
            128, None, fusion_mode='sum', model_name='icnet')

        # Decoder: upconvN upsamples and reduces to 64 channels, iconvN fuses
        # the result with the encoder feature it is concatenated with.
        self.upconv6 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            conv2DBatchNormRelu(
                in_channels=128, k_size=3, n_filters=64, padding=1, stride=1, bias=False))
        self.iconv5 = conv2DBatchNormRelu(
            in_channels=192, k_size=3, n_filters=128, padding=1, stride=1, bias=False)
        self.upconv5 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            conv2DBatchNormRelu(
                in_channels=128, k_size=3, n_filters=64, padding=1, stride=1, bias=False))
        self.iconv4 = conv2DBatchNormRelu(
            in_channels=192, k_size=3, n_filters=128, padding=1, stride=1, bias=False)
        self.upconv4 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            conv2DBatchNormRelu(
                in_channels=128, k_size=3, n_filters=64, padding=1, stride=1, bias=False))
        self.iconv3 = conv2DBatchNormRelu(
            in_channels=128, k_size=3, n_filters=64, padding=1, stride=1, bias=False)

        # 1x1 projections producing the returned feature maps
        self.proj6 = conv2DBatchNormRelu(
            in_channels=128, k_size=1, n_filters=32, padding=0, stride=1, bias=False)
        self.proj5 = conv2DBatchNormRelu(
            in_channels=128, k_size=1, n_filters=16, padding=0, stride=1, bias=False)
        self.proj4 = conv2DBatchNormRelu(
            in_channels=128, k_size=1, n_filters=16, padding=0, stride=1, bias=False)
        self.proj3 = conv2DBatchNormRelu(
            in_channels=64, k_size=1, n_filters=16, padding=0, stride=1, bias=False)

        # Re-initialise every Conv2d weight from N(0, sqrt(2 / n)) with
        # n = kernel area * out_channels, and zero any bias.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """
        Stack ``blocks`` instances of ``block`` and update ``self.inplanes``.

        A 1x1 convolution plus batch norm is created as the shortcut of the
        first block whenever the stride or the channel count changes.
        """
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    out_planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                nn.BatchNorm2d(out_planes),
            )
        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        """
        :param x: image batch with 3 channels. The decoder concatenations
            need the upsampled and skip features to have equal size, which
            holds e.g. when height and width are multiples of 64.
        :returns: ``(proj6, proj5, proj4, proj3)`` at nominal scales 1/64,
            1/32, 1/16 and 1/8 of the input.
        """
        # H, W -> H/2, W/2
        conv1 = self.convbnrelu1_1(x)
        conv1 = self.convbnrelu1_2(conv1)
        conv1 = self.convbnrelu1_3(conv1)
        # H/2, W/2 -> H/4, W/4
        pool1 = F.max_pool2d(conv1, 3, 2, 1)
        # H/4 -> H/8 -> H/16 -> H/32 -> H/64 (one stride-2 stage each)
        conv3 = self.res_block3(pool1)
        conv4 = self.res_block5(conv3)
        conv5 = self.res_block6(conv4)
        conv6 = self.res_block7(conv5)
        conv6 = self.pyramid_pooling(conv6)

        # Decoder with skip connections, coarse to fine
        concat5 = torch.cat((conv5, self.upconv6(conv6)), dim=1)
        conv5 = self.iconv5(concat5)
        concat4 = torch.cat((conv4, self.upconv5(conv5)), dim=1)
        conv4 = self.iconv4(concat4)
        concat3 = torch.cat((conv3, self.upconv4(conv4)), dim=1)
        conv3 = self.iconv3(concat3)

        proj6 = self.proj6(conv6)
        proj5 = self.proj5(conv5)
        proj4 = self.proj4(conv4)
        proj3 = self.proj3(conv3)
        return proj6, proj5, proj4, proj3
class conv2DBatchNorm(nn.Module):
    """
    2D convolution optionally followed by batch normalisation.

    :param in_channels: input channels (converted with ``int``).
    :param n_filters: output channels (converted with ``int``).
    :param k_size: kernel size.
    :param stride: convolution stride.
    :param padding: convolution padding.
    :param bias: whether the convolution has a bias term.
    :param dilation: dilation; values not greater than 1 are treated as 1.
    :param with_bn: whether to append ``nn.BatchNorm2d``.
    """

    def __init__(
            self,
            in_channels,
            n_filters,
            k_size,
            stride,
            padding,
            bias=True,
            dilation=1,
            with_bn=True):
        super(conv2DBatchNorm, self).__init__()
        conv_mod = nn.Conv2d(
            int(in_channels),
            int(n_filters),
            kernel_size=k_size,
            padding=padding,
            stride=stride,
            bias=bias,
            dilation=dilation if dilation > 1 else 1)
        layers = [conv_mod]
        if with_bn:
            layers.append(nn.BatchNorm2d(int(n_filters)))
        self.cb_unit = nn.Sequential(*layers)

    def forward(self, inputs):
        """Apply the convolution (and batch norm, if configured)."""
        return self.cb_unit(inputs)
class conv2DBatchNormRelu(nn.Module):
    """
    2D convolution, optional batch normalisation, then ``LeakyReLU(0.1)``.

    :param in_channels: input channels (converted with ``int``).
    :param n_filters: output channels (converted with ``int``).
    :param k_size: kernel size.
    :param stride: convolution stride.
    :param padding: convolution padding.
    :param bias: whether the convolution has a bias term.
    :param dilation: dilation; values not greater than 1 are treated as 1.
    :param with_bn: whether to insert ``nn.BatchNorm2d`` before the
        activation.
    """

    def __init__(
            self,
            in_channels,
            n_filters,
            k_size,
            stride,
            padding,
            bias=True,
            dilation=1,
            with_bn=True):
        super(conv2DBatchNormRelu, self).__init__()
        conv_mod = nn.Conv2d(
            int(in_channels),
            int(n_filters),
            kernel_size=k_size,
            padding=padding,
            stride=stride,
            bias=bias,
            dilation=dilation if dilation > 1 else 1)
        layers = [conv_mod]
        if with_bn:
            layers.append(nn.BatchNorm2d(int(n_filters)))
        layers.append(nn.LeakyReLU(0.1, inplace=True))
        self.cbr_unit = nn.Sequential(*layers)

    def forward(self, inputs):
        """Apply convolution, optional batch norm and the leaky ReLU."""
        return self.cbr_unit(inputs)
class residualBlock(nn.Module):
    """
    Two-convolution residual block without an activation after the addition.

    :param in_channels: input channels.
    :param n_filters: output channels of both convolutions.
    :param stride: stride of the first convolution.
    :param downsample: optional module applied to the input to form the
        shortcut; the input itself is used when None.
    :param dilation: dilation of the first convolution; when greater than 1
        it is also used as that convolution's padding.
    """
    # Channel multiplier read by unet._make_layer.
    expansion = 1

    def __init__(
            self,
            in_channels,
            n_filters,
            stride=1,
            downsample=None,
            dilation=1):
        super(residualBlock, self).__init__()
        padding = dilation if dilation > 1 else 1
        self.convbnrelu1 = conv2DBatchNormRelu(
            in_channels,
            n_filters,
            3,
            stride,
            padding,
            bias=False,
            dilation=dilation)
        self.convbn2 = conv2DBatchNorm(
            n_filters, n_filters, 3, 1, 1, bias=False)
        self.downsample = downsample
        self.stride = stride
        # Kept as an attribute for compatibility; forward() does not use it.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``convbn2(convbnrelu1(x)) + shortcut(x)``."""
        residual = x
        out = self.convbnrelu1(x)
        out = self.convbn2(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        # No activation is applied after the sum.
        return out
class pyramidPooling(nn.Module):
    """
    Pyramid pooling over a 2D feature map.

    The input is average-pooled at several scales; each pooled map is
    resized back to the input size and fused with the input either by
    concatenation (``fusion_mode='cat'``) or by a weighted sum (any other
    value).

    :param in_channels: channels of the input feature map.
    :param pool_sizes: sequence of pooling kernel sizes, or None to derive
        four kernel sizes from the input size at run time.
    :param model_name: in ``'cat'`` mode the per-path convolution is skipped
        when this is ``'icnet'``.
    :param fusion_mode: ``'cat'`` for concatenation, otherwise summation.
    :param with_bn: whether the path convolutions use batch norm (their bias
        is enabled exactly when this is False).
    """

    def __init__(
            self,
            in_channels,
            pool_sizes,
            model_name='pspnet',
            fusion_mode='cat',
            with_bn=True):
        super(pyramidPooling, self).__init__()
        bias = not with_bn
        if pool_sizes is None:
            n_paths = 4
            path_channels = in_channels
        else:
            n_paths = len(pool_sizes)
            path_channels = int(in_channels / len(pool_sizes))
        self.paths = [
            conv2DBatchNormRelu(
                in_channels, path_channels, 1, 1, 0, bias=bias, with_bn=with_bn)
            for _ in range(n_paths)]
        self.path_module_list = nn.ModuleList(self.paths)
        self.pool_sizes = pool_sizes
        self.model_name = model_name
        self.fusion_mode = fusion_mode

    def forward(self, x):
        """
        :param x: 4D feature map; dimensions 2 and 3 are height and width.
        :returns: in ``'cat'`` mode, ``x`` concatenated along dimension 1
            with every resized pooled map; otherwise
            ``relu((x + 0.25 * sum(paths)) / 2)``.
        """
        h, w = x.shape[2:]

        if self.pool_sizes is None:
            # Four pool factors from 1 to min(h, w) // 2. The upper bound is
            # clamped to >= 1 so inputs smaller than 2 pixels do not yield a
            # zero factor. Kernels are ordered from smallest to largest.
            largest = max(min(h, w) // 2, 1)
            k_sizes = [
                (int(h / pool_size), int(w / pool_size))
                for pool_size in np.linspace(1, largest, 4, dtype=int)][::-1]
        else:
            # One square kernel per configured pool size (not limited to 4).
            k_sizes = [(size, size) for size in self.pool_sizes]
        strides = k_sizes

        if self.fusion_mode == 'cat':  # pspnet: concat (including x)
            output_slices = [x]
            # Iterate over the path modules directly so this also works when
            # pool_sizes is None.
            for i, module in enumerate(self.path_module_list):
                out = F.avg_pool2d(x, k_sizes[i], stride=strides[i], padding=0)
                if self.model_name != 'icnet':
                    out = module(out)
                out = F.interpolate(out, size=(h, w), mode='bilinear')
                output_slices.append(out)
            return torch.cat(output_slices, dim=1)

        # icnet: element-wise sum (including x)
        pp_sum = x
        for i, module in enumerate(self.path_module_list):
            out = F.avg_pool2d(x, k_sizes[i], stride=strides[i], padding=0)
            out = module(out)
            out = F.interpolate(out, size=(h, w), mode='bilinear')
            pp_sum = pp_sum + 0.25 * out
        pp_sum = F.relu(pp_sum / 2., inplace=True)
        return pp_sum