|
|
|
@@ -17,7 +17,7 @@ |
|
|
|
import math |
|
|
|
from functools import partial |
|
|
|
from mindspore._checkparam import _check_3d_int_or_tuple |
|
|
|
from .nn_ops import _check_positive_int_or_tuple, _update_attr_by_format |
|
|
|
from .nn_ops import _check_positive_int_or_tuple |
|
|
|
from .. import signature as sig |
|
|
|
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register |
|
|
|
from ..._checkparam import Validator as validator, Rel |
|
|
|
@@ -423,18 +423,21 @@ class Conv2DBackpropFilter(Primitive): |
|
|
|
validator.check_equal_int(len(pad), 4, 'pad size', self.name) |
|
|
|
self.add_prim_attr("pad", pad) |
|
|
|
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) |
|
|
|
if context.get_context("device_target") != "GPU" and self.format == "NHWC": |
|
|
|
raise ValueError("NHWC format only support in GPU target.") |
|
|
|
self.add_prim_attr('data_format', self.format) |
|
|
|
self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True) |
|
|
|
self.stride = _update_attr_by_format(self.stride, self.format) |
|
|
|
self.add_prim_attr('stride', self.stride) |
|
|
|
self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True) |
|
|
|
self.dilation = _update_attr_by_format(self.dilation, self.format) |
|
|
|
self.add_prim_attr('dilation', self.dilation) |
|
|
|
self.dilation = dilation |
|
|
|
if context.get_context("device_target") != "GPU": |
|
|
|
if self.format == "NHWC": |
|
|
|
raise ValueError("NHWC format only support in GPU target.") |
|
|
|
self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True) |
|
|
|
self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, |
|
|
|
ret_four=True) |
|
|
|
else: |
|
|
|
if isinstance(stride, tuple) and len(stride) == 4: |
|
|
|
self.stride = (stride[2], stride[3]) |
|
|
|
self.dilation = dilation |
|
|
|
self.group = group |
|
|
|
self.add_prim_attr('groups', group) |
|
|
|
self.add_prim_attr('stride', self.stride) |
|
|
|
self.add_prim_attr('dilation', self.dilation) |
|
|
|
self.add_prim_attr('data_format', self.format) |
|
|
|
|
|
|
|
|
|
|
|
class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer): |
|
|
|
|