|
|
@@ -467,7 +467,10 @@ class FakeQuantPerLayer(PrimitiveWithInfer): |
|
|
return x_shape |
|
|
return x_shape |
|
|
|
|
|
|
|
|
def infer_dtype(self, x_type, min_type, max_type): |
|
|
def infer_dtype(self, x_type, min_type, max_type): |
|
|
valid_types = (mstype.float16, mstype.float32) |
|
|
|
|
|
|
|
|
if context.get_context('device_target') == "GPU": |
|
|
|
|
|
valid_types = (mstype.float32,) |
|
|
|
|
|
else: |
|
|
|
|
|
valid_types = (mstype.float16, mstype.float32) |
|
|
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) |
|
|
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) |
|
|
validator.check_tensor_type_same( |
|
|
validator.check_tensor_type_same( |
|
|
{"min": min_type}, valid_types, self.name) |
|
|
{"min": min_type}, valid_types, self.name) |
|
|
@@ -521,7 +524,10 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer): |
|
|
return dout_shape |
|
|
return dout_shape |
|
|
|
|
|
|
|
|
def infer_dtype(self, dout_type, x_type, min_type, max_type): |
|
|
def infer_dtype(self, dout_type, x_type, min_type, max_type): |
|
|
valid_types = (mstype.float16, mstype.float32) |
|
|
|
|
|
|
|
|
if context.get_context('device_target') == "GPU": |
|
|
|
|
|
valid_types = (mstype.float32,) |
|
|
|
|
|
else: |
|
|
|
|
|
valid_types = (mstype.float16, mstype.float32) |
|
|
validator.check_tensor_type_same( |
|
|
validator.check_tensor_type_same( |
|
|
{"dout": dout_type}, valid_types, self.name) |
|
|
{"dout": dout_type}, valid_types, self.name) |
|
|
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) |
|
|
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) |
|
|
@@ -616,7 +622,10 @@ class FakeQuantPerChannel(PrimitiveWithInfer): |
|
|
return x_shape |
|
|
return x_shape |
|
|
|
|
|
|
|
|
def infer_dtype(self, x_type, min_type, max_type): |
|
|
def infer_dtype(self, x_type, min_type, max_type): |
|
|
valid_types = (mstype.float16, mstype.float32) |
|
|
|
|
|
|
|
|
if context.get_context('device_target') == "GPU": |
|
|
|
|
|
valid_types = (mstype.float32,) |
|
|
|
|
|
else: |
|
|
|
|
|
valid_types = (mstype.float16, mstype.float32) |
|
|
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) |
|
|
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) |
|
|
validator.check_tensor_type_same( |
|
|
validator.check_tensor_type_same( |
|
|
{"min": min_type}, valid_types, self.name) |
|
|
{"min": min_type}, valid_types, self.name) |
|
|
@@ -670,7 +679,10 @@ class FakeQuantPerChannelGrad(PrimitiveWithInfer): |
|
|
return dout_shape |
|
|
return dout_shape |
|
|
|
|
|
|
|
|
def infer_dtype(self, dout_type, x_type, min_type, max_type): |
|
|
def infer_dtype(self, dout_type, x_type, min_type, max_type): |
|
|
valid_types = (mstype.float16, mstype.float32) |
|
|
|
|
|
|
|
|
if context.get_context('device_target') == "GPU": |
|
|
|
|
|
valid_types = (mstype.float32,) |
|
|
|
|
|
else: |
|
|
|
|
|
valid_types = (mstype.float16, mstype.float32) |
|
|
validator.check_tensor_type_same( |
|
|
validator.check_tensor_type_same( |
|
|
{"dout": dout_type}, valid_types, self.name) |
|
|
{"dout": dout_type}, valid_types, self.name) |
|
|
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) |
|
|
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) |
|
|
|