|
|
|
@@ -174,9 +174,9 @@ class Conv2DBackpropFilter(PrimitiveWithInfer): |
|
|
|
pad_mode = pad_mode.upper() |
|
|
|
self.add_prim_attr('pad_mode', pad_mode) |
|
|
|
self.pad = pad |
|
|
|
if isinstance(stride, tuple) and len(stride) == 4: |
|
|
|
self.stride = (stride[2], stride[3]) |
|
|
|
self.add_prim_attr('stride', self.stride) |
|
|
|
if isinstance(stride, tuple) and len(stride) == 2: |
|
|
|
self.stride = stride |
|
|
|
self.add_prim_attr('stride', (1, 1, self.stride[0], self.stride[1])) |
|
|
|
self.dilation = dilation |
|
|
|
self.group = group |
|
|
|
self.add_prim_attr('data_format', "NCHW") |
|
|
|
|