Browse Source

!8412 modify ActULQClampMinGrad/ActULQClampMaxGrad return type

From: @yinding
Reviewed-by: @liangchenghui
Signed-off-by: @liangchenghui
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
5b84ac6847
1 changed file with 15 additions and 10 deletions
  1. +15
    -10
      mindspore/ops/operations/_quant_ops.py

+ 15
- 10
mindspore/ops/operations/_quant_ops.py View File

@@ -44,7 +44,12 @@ __all__ = ["MinMaxUpdatePerLayer",
"BatchNormFold2_D",
"BatchNormFold2GradD",
"BatchNormFold2GradReduce",
"IFMR"
"IFMR",
"ActsULQ",
"ActsULQInputGrad",
"ActULQClampMinGrad",
"ActULQClampMaxGrad",
"WtsARQ"
]


@@ -1236,9 +1241,9 @@ class ActsULQ(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, clamp_min_dtype, clamp_max_dtype):
    """Infer output dtypes of ActsULQ.

    Validates that all three tensor inputs are float16/float32, then
    returns (x dtype, bool mask, bool mask, x dtype) per the visible
    return statement.
    """
    valid_types = [mstype.float32, mstype.float16]
    # Post-change validator API only; the captured diff also showed the
    # removed check_tensor_type_same calls, which must not both run.
    validator.check_tensor_dtype_valid("x", x_dtype, valid_types, self.name)
    validator.check_tensor_dtype_valid("clamp_min", clamp_min_dtype, valid_types, self.name)
    validator.check_tensor_dtype_valid("clamp_max", clamp_max_dtype, valid_types, self.name)

    return x_dtype, mstype.bool_, mstype.bool_, x_dtype

@@ -1262,7 +1267,7 @@ class ActsULQInputGrad(PrimitiveWithInfer):

def infer_dtype(self, y_grad_type, clamp_min_mask_type, clamp_max_mask_type):
    """Infer output dtype of ActsULQInputGrad.

    Validates that y_grad is float16/float32 and returns its dtype
    unchanged; the mask dtypes are not checked here.
    """
    valid_types = [mstype.float32, mstype.float16]
    # Post-change validator API only (check_tensor_type_same was removed
    # by this commit; the capture showed both lines).
    validator.check_tensor_dtype_valid("y_grad", y_grad_type, valid_types, self.name)
    return y_grad_type


@@ -1300,7 +1305,7 @@ class ActULQClampMinGrad(PrimitiveWithInfer):
return tuple(output_shape)

def infer_dtype(self, input_x, input_y, input_z):
    """Infer output dtype of ActULQClampMinGrad.

    The clamp-min gradient is always float32, per this commit
    ("modify ActULQClampMinGrad/ActULQClampMaxGrad return type");
    input dtypes do not influence the result.
    """
    # The captured diff showed both the removed `return input_x` and the
    # added `return mstype.float32`; only the new return is kept.
    return mstype.float32


class ActULQClampMaxGrad(PrimitiveWithInfer):
@@ -1337,7 +1342,7 @@ class ActULQClampMaxGrad(PrimitiveWithInfer):
return tuple(output_shape)

def infer_dtype(self, input_x, input_y, input_z):
    """Infer output dtype of ActULQClampMaxGrad.

    The clamp-max gradient is always float32, per this commit
    ("modify ActULQClampMinGrad/ActULQClampMaxGrad return type");
    input dtypes do not influence the result.
    """
    # The captured diff showed both the removed `return input_x` and the
    # added `return mstype.float32`; only the new return is kept.
    return mstype.float32


class WtsARQ(PrimitiveWithInfer):
@@ -1381,9 +1386,9 @@ class WtsARQ(PrimitiveWithInfer):

def infer_dtype(self, w_dtype, w_min_dtype, w_max_dtype):
    """Infer output dtype of WtsARQ.

    Validates that w, w_min and w_max are float16/float32 and returns
    the dtype of w unchanged.
    """
    valid_types = [mstype.float32, mstype.float16]
    # Post-change validator API only; the captured diff also retained
    # the three removed check_tensor_type_same calls.
    validator.check_tensor_dtype_valid("w", w_dtype, valid_types, self.name)
    validator.check_tensor_dtype_valid("w_min", w_min_dtype, valid_types, self.name)
    validator.check_tensor_dtype_valid("w_max", w_max_dtype, valid_types, self.name)
    return w_dtype




Loading…
Cancel
Save