|
|
|
@@ -14,7 +14,10 @@ class Model(nn.Module): |
|
|
|
self.linear_1 = nn.Linear(in_features=16, out_features=13, bias=True) |
|
|
|
|
|
|
|
self.linear_2 = nn.Linear(in_features=13, out_features=17, bias=True) |
|
|
|
if version.parse(torch.__version__) < version.parse('2.1'): |
|
|
|
if version.parse(torch.__version__) < version.parse('1.9'): |
|
|
|
# weight_norm on torch 1.8 produces wrong output shape, skip it |
|
|
|
pass |
|
|
|
elif version.parse(torch.__version__) < version.parse('2.1'): |
|
|
|
self.linear_2 = torch.nn.utils.weight_norm(self.linear_2) |
|
|
|
else: |
|
|
|
self.linear_2 = torch.nn.utils.parametrizations.weight_norm(self.linear_2) |
|
|
|
|