This is a modified version of fastai's XResNet model, available on GitHub. Changes include:
# Smoke test: build an xresnet1d18plus with coord=True and run one forward
# pass on a random batch. The positional args are presumably (c_in, c_out),
# matching the arch(c_in, c_out, ...) calls elsewhere in this file — confirm.
net = xresnet1d18plus(3, 2, coord=True)
# Random input; assumes a (batch, channels, seq_len) layout, so 3 channels
# must agree with the first constructor argument — TODO confirm.
x = torch.rand(32, 3, 50)
# Result is not stored; in the source notebook this line just displays it.
net(x)
# Shared configuration for the architecture shape tests: a tiny random batch
# laid out as (bs, c_in, seq_len), following the variable names used here.
bs, c_in, seq_len = 2, 4, 32
c_out = 2
x = torch.rand(bs, c_in, seq_len)
# Architecture constructors to exercise; each one is called as
# arch(c_in, c_out, ...) by the test loop in this file.
archs = [
xresnet1d18plus, xresnet1d34plus, xresnet1d50plus,
xresnet1d18_deepplus, xresnet1d34_deepplus, xresnet1d50_deepplus, xresnet1d18_deeperplus,
xresnet1d34_deeperplus, xresnet1d50_deeperplus
# # Long test
# NOTE(review): the deeper variants below are left out of the default run,
# presumably because they are slow. Unlike the entries above, their names
# lack the `plus` suffix — confirm the correct names before re-enabling.
# xresnet1d101, xresnet1d152,
]
# Smoke-test every architecture in `archs`. Build each one with sa=True
# (presumably self-attention), Mish activations and coord=True, run a forward
# pass on `x`, and check that the output has shape (bs, c_out): one row of
# c_out values per sample in the batch.
for i, arch in enumerate(archs):
    # Progress marker, so a failing architecture is easy to identify.
    print(i, arch.__name__)
    test_eq(arch(c_in, c_out, sa=True, act=Mish, coord=True)(x).shape, (bs, c_out))
# Build an xresnet1d34plus with Mish activations and inspect its BatchNorm
# layers. Positional args are presumably (c_in=4, c_out=2) — confirm.
m = xresnet1d34plus(4, 2, act=Mish)
# Expect exactly 38 layers matched by the `is_bn` predicate.
test_eq(len(get_layers(m, is_bn)), 38)
# NOTE(review): check_weight is defined outside this view. Its first element
# presumably flags, per matched BN layer, some property of the weights
# (e.g. zero-initialised scale) — confirm. The test expects 22 such layers.
test_eq(check_weight(m, is_bn)[0].sum(), 22)