Browse Source

modify act ulq min/max grad type

tags/v1.1.0
y00369862 5 years ago
parent
commit
e4e5925dae
1 changed files with 8 additions and 3 deletions
  1. +8
    -3
      mindspore/ops/operations/_quant_ops.py

+ 8
- 3
mindspore/ops/operations/_quant_ops.py View File

@@ -44,7 +44,12 @@ __all__ = ["MinMaxUpdatePerLayer",
"BatchNormFold2_D", "BatchNormFold2_D",
"BatchNormFold2GradD", "BatchNormFold2GradD",
"BatchNormFold2GradReduce", "BatchNormFold2GradReduce",
"IFMR"
"IFMR",
"ActsULQ",
"ActsULQInputGrad",
"ActULQClampMinGrad",
"ActULQClampMaxGrad",
"WtsARQ"
] ]




@@ -1300,7 +1305,7 @@ class ActULQClampMinGrad(PrimitiveWithInfer):
return tuple(output_shape) return tuple(output_shape)


def infer_dtype(self, input_x, input_y, input_z): def infer_dtype(self, input_x, input_y, input_z):
return input_x
return mstype.float32




class ActULQClampMaxGrad(PrimitiveWithInfer): class ActULQClampMaxGrad(PrimitiveWithInfer):
@@ -1337,7 +1342,7 @@ class ActULQClampMaxGrad(PrimitiveWithInfer):
return tuple(output_shape) return tuple(output_shape)


def infer_dtype(self, input_x, input_y, input_z): def infer_dtype(self, input_x, input_y, input_z):
return input_x
return mstype.float32




class WtsARQ(PrimitiveWithInfer): class WtsARQ(PrimitiveWithInfer):


Loading…
Cancel
Save