diff --git a/mindspore/ops/operations/_quant_ops.py b/mindspore/ops/operations/_quant_ops.py
index ad1c90d19a..f6df84f18e 100644
--- a/mindspore/ops/operations/_quant_ops.py
+++ b/mindspore/ops/operations/_quant_ops.py
@@ -467,7 +467,10 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_type, min_type, max_type):
-        valid_types = (mstype.float16, mstype.float32)
+        if context.get_context('device_target') == "GPU":
+            valid_types = (mstype.float32,)
+        else:
+            valid_types = (mstype.float16, mstype.float32)
         validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
         validator.check_tensor_type_same(
             {"min": min_type}, valid_types, self.name)
@@ -521,7 +524,10 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer):
         return dout_shape
 
     def infer_dtype(self, dout_type, x_type, min_type, max_type):
-        valid_types = (mstype.float16, mstype.float32)
+        if context.get_context('device_target') == "GPU":
+            valid_types = (mstype.float32,)
+        else:
+            valid_types = (mstype.float16, mstype.float32)
         validator.check_tensor_type_same(
             {"dout": dout_type}, valid_types, self.name)
         validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
@@ -616,7 +622,10 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_type, min_type, max_type):
-        valid_types = (mstype.float16, mstype.float32)
+        if context.get_context('device_target') == "GPU":
+            valid_types = (mstype.float32,)
+        else:
+            valid_types = (mstype.float16, mstype.float32)
         validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
         validator.check_tensor_type_same(
             {"min": min_type}, valid_types, self.name)
@@ -670,7 +679,10 @@ class FakeQuantPerChannelGrad(PrimitiveWithInfer):
         return dout_shape
 
     def infer_dtype(self, dout_type, x_type, min_type, max_type):
-        valid_types = (mstype.float16, mstype.float32)
+        if context.get_context('device_target') == "GPU":
+            valid_types = (mstype.float32,)
+        else:
+            valid_types = (mstype.float16, mstype.float32)
         validator.check_tensor_type_same(
             {"dout": dout_type}, valid_types, self.name)
         validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)