Browse Source

!8081 add float16 check for gpu fakequant op

Merge pull request !8081 from yuchaojie/quant2
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
fc5a3b7d97
1 changed file with 16 additions and 4 deletions
  1. +16
    -4
      mindspore/ops/operations/_quant_ops.py

+ 16
- 4
mindspore/ops/operations/_quant_ops.py View File

@@ -467,7 +467,10 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
return x_shape


def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
if context.get_context('device_target') == "GPU":
valid_types = (mstype.float32,)
else:
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
    {"min": min_type}, valid_types, self.name)
@@ -521,7 +524,10 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer):
return dout_shape


def infer_dtype(self, dout_type, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
if context.get_context('device_target') == "GPU":
valid_types = (mstype.float32,)
else:
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(
    {"dout": dout_type}, valid_types, self.name)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
@@ -616,7 +622,10 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
return x_shape


def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
if context.get_context('device_target') == "GPU":
valid_types = (mstype.float32,)
else:
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
    {"min": min_type}, valid_types, self.name)
@@ -670,7 +679,10 @@ class FakeQuantPerChannelGrad(PrimitiveWithInfer):
return dout_shape


def infer_dtype(self, dout_type, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
if context.get_context('device_target') == "GPU":
valid_types = (mstype.float32,)
else:
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(
    {"dout": dout_type}, valid_types, self.name)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)


Loading…
Cancel
Save