Browse Source

make conv2d bp filter stride attr len 4

tags/v0.3.0-alpha
zhaozhenlong 5 years ago
parent
commit
4c37420890
1 changed files with 3 additions and 3 deletions
  1. +3
    -3
      mindspore/ops/operations/_grad_ops.py

+ 3
- 3
mindspore/ops/operations/_grad_ops.py View File

@@ -174,9 +174,9 @@ class Conv2DBackpropFilter(PrimitiveWithInfer):
pad_mode = pad_mode.upper()
self.add_prim_attr('pad_mode', pad_mode)
self.pad = pad
if isinstance(stride, tuple) and len(stride) == 4:
self.stride = (stride[2], stride[3])
self.add_prim_attr('stride', self.stride)
if isinstance(stride, tuple) and len(stride) == 2:
self.stride = stride
self.add_prim_attr('stride', (1, 1, self.stride[0], self.stride[1]))
self.dilation = dilation
self.group = group
self.add_prim_attr('data_format', "NCHW")


Loading…
Cancel
Save