|
|
|
@@ -44,7 +44,12 @@ __all__ = ["MinMaxUpdatePerLayer", |
|
|
|
"BatchNormFold2_D", |
|
|
|
"BatchNormFold2GradD", |
|
|
|
"BatchNormFold2GradReduce", |
|
|
|
"IFMR" |
|
|
|
"IFMR", |
|
|
|
"ActsULQ", |
|
|
|
"ActsULQInputGrad", |
|
|
|
"ActULQClampMinGrad", |
|
|
|
"ActULQClampMaxGrad", |
|
|
|
"WtsARQ" |
|
|
|
] |
|
|
|
|
|
|
|
|
|
|
|
@@ -1300,7 +1305,7 @@ class ActULQClampMinGrad(PrimitiveWithInfer): |
|
|
|
return tuple(output_shape) |
|
|
|
|
|
|
|
def infer_dtype(self, input_x, input_y, input_z): |
|
|
|
return input_x |
|
|
|
return mstype.float32 |
|
|
|
|
|
|
|
|
|
|
|
class ActULQClampMaxGrad(PrimitiveWithInfer): |
|
|
|
@@ -1337,7 +1342,7 @@ class ActULQClampMaxGrad(PrimitiveWithInfer): |
|
|
|
return tuple(output_shape) |
|
|
|
|
|
|
|
def infer_dtype(self, input_x, input_y, input_z): |
|
|
|
return input_x |
|
|
|
return mstype.float32 |
|
|
|
|
|
|
|
|
|
|
|
class WtsARQ(PrimitiveWithInfer): |
|
|
|
|