| @@ -250,6 +250,10 @@ def _is_equal_one(x): | |||||
| return False | return False | ||||
| return bool(x.asnumpy().mean() == 1.0) | return bool(x.asnumpy().mean() == 1.0) | ||||
@constexpr
def _dtype_check(x_dtype):
    """Validate that the input dtype is float16 or float32.

    Compile-time check (runs under @constexpr) used by ClipByNorm to reject
    unsupported input types early.

    Args:
        x_dtype: a mindspore dtype object for the input tensor.

    Raises:
        TypeError: if x_dtype is neither mstype.float32 nor mstype.float16.
            The message includes the received dtype to ease debugging.
    """
    if x_dtype not in (mstype.float32, mstype.float16):
        raise TypeError(
            "The input type must be float32 or float16, but got {}.".format(x_dtype))
| class ClipByNorm(Cell): | class ClipByNorm(Cell): | ||||
| r""" | r""" | ||||
| @@ -264,12 +268,11 @@ class ClipByNorm(Cell): | |||||
| where :math:`L_2(X)` is the :math:`L_2`-norm of :math:`X`. | where :math:`L_2(X)` is the :math:`L_2`-norm of :math:`X`. | ||||
| Inputs: | Inputs: | ||||
| - **input** (Tensor) - Tensor of shape N-D. | |||||
| - **clip_norm** (Tensor) - A scalar Tensor of shape :math:`()` or :math:`(1)` and of | |||||
| the same type as the input Tensor. | |||||
| - **input** (Tensor) - Tensor of shape N-D. The type should be float32 or float16. | |||||
| - **clip_norm** (Tensor) - A scalar Tensor of shape :math:`()` or :math:`(1)`. | |||||
| Outputs: | Outputs: | ||||
| Tensor, clipped tensor with the same shape as the input. | |||||
| Tensor, clipped tensor with the same shape as the input, whose type is float32. | |||||
| Examples: | Examples: | ||||
| >>> net = nn.ClipByNorm() | >>> net = nn.ClipByNorm() | ||||
| @@ -300,10 +303,10 @@ class ClipByNorm(Cell): | |||||
| l2sum = self.cast(self.reduce_sum(mul_x), mstype.float32) | l2sum = self.cast(self.reduce_sum(mul_x), mstype.float32) | ||||
| cond = self.greater_(l2sum, 0) | cond = self.greater_(l2sum, 0) | ||||
| ones_ = self.fill(self.dtype(cond), self.shape(cond), 1.0) | ones_ = self.fill(self.dtype(cond), self.shape(cond), 1.0) | ||||
| l2sum_safe = self.select_(cond, l2sum, self.cast(ones_, self.dtype(l2sum))) | l2sum_safe = self.select_(cond, l2sum, self.cast(ones_, self.dtype(l2sum))) | ||||
| l2norm = self.select_(cond, self.sqrt(l2sum_safe), l2sum) | l2norm = self.select_(cond, self.sqrt(l2sum_safe), l2sum) | ||||
| _dtype_check(self.dtype(x)) | |||||
| if _is_equal_one(clip_norm): | if _is_equal_one(clip_norm): | ||||
| intermediate = x | intermediate = x | ||||
| else: | else: | ||||
| @@ -827,13 +827,3 @@ def get_bprop_unique(self): | |||||
| dx = op(dout, out) | dx = op(dout, out) | ||||
| return (dx,) | return (dx,) | ||||
| return bprop | return bprop | ||||
@bprop_getters.register(P.UnsortedSegmentSum)
def get_bprop_unsorted_segment_sum(self):
    """Generate bprop for UnsortedSegmentSum"""
    grad_op = G.UnsortedSegmentSumGrad()

    def bprop(x, segment_ids, num_segments, out, dout):
        # dx gathers the output gradient back per segment id; the two
        # integer inputs (segment_ids, num_segments) get zero gradients.
        dx = grad_op(dout, segment_ids)
        return dx, zeros_like(segment_ids), zeros_like(num_segments)

    return bprop
| @@ -502,20 +502,6 @@ class UniqueGrad(Primitive): | |||||
| raise NotImplementedError | raise NotImplementedError | ||||
class UnsortedSegmentSumGrad(PrimitiveWithInfer):
    """Gradients of UnsortedSegmentSum operation."""

    @prim_attr_register
    def __init__(self):
        # Declare the kernel's input/output tensor names.
        self.init_prim_io_names(inputs=['grads', 'ids'], outputs=['y'])

    def infer_shape(self, grads, ids):
        # Output shape: the leading len(ids) dims of grads are replaced by
        # the shape of ids; the trailing dims of grads are kept as-is.
        trailing = grads[len(ids):]
        return ids + trailing

    def infer_dtype(self, grads, ids):
        # The gradient w.r.t. x has the same dtype as the incoming grads.
        return grads
| class BNTrainingReduceGrad(PrimitiveWithInfer): | class BNTrainingReduceGrad(PrimitiveWithInfer): | ||||
| """Gradients of FusedBatchNorm operation.""" | """Gradients of FusedBatchNorm operation.""" | ||||
| @@ -93,8 +93,12 @@ class BoundingBoxEncode(PrimitiveWithInfer): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)): | def __init__(self, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)): | ||||
| validator.check_value_type('means', means, [tuple], self.name) | |||||
| validator.check_value_type('stds', stds, [tuple], self.name) | |||||
| validator.check_value_type('means', means, [tuple, list], self.name) | |||||
| validator.check_value_type('stds', stds, [tuple, list], self.name) | |||||
| for i, value in enumerate(means): | |||||
| validator.check_value_type("means[%d]" % i, value, [float], self.name) | |||||
| for i, value in enumerate(stds): | |||||
| validator.check_value_type("stds[%d]" % i, value, [float], self.name) | |||||
| validator.check_integer("means len", len(means), 4, Rel.EQ, self.name) | validator.check_integer("means len", len(means), 4, Rel.EQ, self.name) | ||||
| validator.check_integer("stds len", len(stds), 4, Rel.EQ, self.name) | validator.check_integer("stds len", len(stds), 4, Rel.EQ, self.name) | ||||
| @@ -143,8 +147,12 @@ class BoundingBoxDecode(PrimitiveWithInfer): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, max_shape, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0), wh_ratio_clip=0.016): | def __init__(self, max_shape, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0), wh_ratio_clip=0.016): | ||||
| validator.check_value_type('means', means, [tuple], self.name) | |||||
| validator.check_value_type('stds', stds, [tuple], self.name) | |||||
| validator.check_value_type('means', means, [tuple, list], self.name) | |||||
| validator.check_value_type('stds', stds, [tuple, list], self.name) | |||||
| for i, value in enumerate(means): | |||||
| validator.check_value_type("means[%d]" % i, value, [float], self.name) | |||||
| for i, value in enumerate(stds): | |||||
| validator.check_value_type("stds[%d]" % i, value, [float], self.name) | |||||
| validator.check_value_type('wh_ratio_clip', wh_ratio_clip, [float], self.name) | validator.check_value_type('wh_ratio_clip', wh_ratio_clip, [float], self.name) | ||||
| validator.check_integer("means len", len(means), 4, Rel.EQ, self.name) | validator.check_integer("means len", len(means), 4, Rel.EQ, self.name) | ||||
| validator.check_integer("stds len", len(stds), 4, Rel.EQ, self.name) | validator.check_integer("stds len", len(stds), 4, Rel.EQ, self.name) | ||||