|
|
|
@@ -15,8 +15,8 @@
 """Operators for quantization."""
 
-from ..._checkparam import ParamValidator as validator
-from ..._checkparam import Rel, check_bool, check_int_positive, check_int
+from ..._checkparam import Validator as validator
+from ..._checkparam import Rel
 from ..primitive import PrimitiveWithInfer, prim_attr_register
 from ...common import dtype as mstype
@@ -69,36 +69,31 @@ class FakeQuantWithMinMax(PrimitiveWithInfer):
                  training=True):
         """init FakeQuantWithMinMax OP"""
         if num_bits not in self.support_quant_bit:
-            raise ValueError("Attr \'num_bits\' is not support.")
+            raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.")
         if ema and not ema_decay:
-            raise ValueError(
-                "Attr \'ema\' and \'ema_decay\' should set together.")
-
-        self.ema = check_bool(ema)
-        self.symmetric = check_bool(symmetric)
-        self.narrow_range = check_bool(narrow_range)
-        self.training = check_bool(training)
-        self.ema_decay = validator.check_number_range(
-            'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH)
-        self.num_bits = check_int_positive(num_bits)
-        self.quant_delay = check_int(quant_delay)
+            raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
+
+        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
+        self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
+        self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
+        self.training = validator.check_value_type('training', training, (bool,), self.name)
+        self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
+        self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
+        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
         self.init_prim_io_names(inputs=['x', 'min', 'max'],
                                 outputs=['out'])
 
     def infer_shape(self, x_shape, min_shape, max_shape):
-        validator.check_integer("x shape", len(x_shape), 1, Rel.GT)
-        validator.check("min shape", min_shape, "max shape", max_shape)
-        validator.check_integer("min shape", len(min_shape), 1, Rel.EQ)
-        validator.check_integer("max shape", len(min_shape), 1, Rel.EQ)
+        validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
+        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
+        validator.check_integer("min rank", len(min_shape), 1, Rel.EQ, self.name)
         return x_shape
 
     def infer_dtype(self, x_type, min_type, max_type):
-        validator.check_typename(
-            "x type", x_type, (mstype.float16, mstype.float32))
-        validator.check_typename("min type", min_type,
-                                 (mstype.float16, mstype.float32))
-        validator.check_typename("max type", max_type,
-                                 (mstype.float16, mstype.float32))
+        valid_types = (mstype.float16, mstype.float32)
+        validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
+        validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
+        validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
         return x_type
@@ -109,29 +104,24 @@ class FakeQuantWithMinMaxGrad(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, num_bits=8, quant_delay=0):
         if num_bits not in self.support_quant_bit:
-            raise ValueError("Attr \'num_bits\' is not support.")
+            raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.")
 
-        self.quant_delay = check_int(quant_delay)
-        self.num_bits = check_int_positive(num_bits)
-        self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'],
-                                outputs=['dx'])
+        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
+        self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
+        self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])
 
     def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
-        validator.check("dout shape", dout_shape, "x shape", x_shape)
-        validator.check("min shape", min_shape, "max shape", max_shape)
-        validator.check_integer("min shape", len(min_shape), 1, Rel.EQ)
-        validator.check_integer("max shape", len(min_shape), 1, Rel.EQ)
+        validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
+        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
+        validator.check_integer("min rank", len(min_shape), 1, Rel.EQ, self.name)
         return dout_shape
 
     def infer_dtype(self, dout_type, x_type, min_type, max_type):
-        validator.check_typename(
-            "dout type", dout_type, (mstype.float16, mstype.float32))
-        validator.check_typename(
-            "x type", x_type, (mstype.float16, mstype.float32))
-        validator.check_typename("min type", min_type,
-                                 (mstype.float16, mstype.float32))
-        validator.check_typename("max type", max_type,
-                                 (mstype.float16, mstype.float32))
+        valid_types = (mstype.float16, mstype.float32)
+        validator.check_tensor_type_same({"dout": dout_type}, valid_types, self.name)
+        validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
+        validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
+        validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
         return dout_type
@@ -172,37 +162,30 @@ class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer):
                  training=True):
         """init FakeQuantWithMinMaxPerChannel OP"""
         if num_bits not in self.support_quant_bit:
-            raise ValueError("Attr \'num_bits\' is not support.")
+            raise ValueError(f"For '{self.name}' Attr \'num_bits\' is not support.")
         if ema and not ema_decay:
-            raise ValueError(
-                "Attr \'ema\' and \'ema_decay\' should set together.")
-
-        self.ema = check_bool(ema)
-        self.symmetric = check_bool(symmetric)
-        self.narrow_range = check_bool(narrow_range)
-        self.training = check_bool(training)
-        self.ema_decay = validator.check_number_range(
-            'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH)
-        self.num_bits = check_int_positive(num_bits)
-        self.quant_delay = check_int(quant_delay)
-        self.init_prim_io_names(inputs=['x', 'min', 'max'],
-                                outputs=['out'])
+            raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
+
+        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
+        self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
+        self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
+        self.training = validator.check_value_type('training', training, (bool,), self.name)
+        self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
+        self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
+        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
+        self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])
 
     def infer_shape(self, x_shape, min_shape, max_shape):
-        validator.check_integer("x shape", len(x_shape), 1, Rel.GT)
-        validator.check_integer(
-            "min len", min_shape[0], x_shape[self.channel_idx], Rel.EQ)
-        validator.check_integer(
-            "max len", max_shape[0], x_shape[self.channel_idx], Rel.EQ)
+        validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
+        validator.check_integer("min shape[0]", min_shape[0], x_shape[self.channel_idx], Rel.EQ, self.name)
+        validator.check_integer("max shape[0]", max_shape[0], x_shape[self.channel_idx], Rel.EQ, self.name)
         return x_shape
 
     def infer_dtype(self, x_type, min_type, max_type):
-        validator.check_typename(
-            "x type", x_type, (mstype.float16, mstype.float32))
-        validator.check_typename("min type", min_type,
-                                 (mstype.float16, mstype.float32))
-        validator.check_typename("max type", max_type,
-                                 (mstype.float16, mstype.float32))
+        valid_types = (mstype.float16, mstype.float32)
+        validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
+        validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
+        validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
         return x_type
@@ -214,12 +197,11 @@ class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer):
     def __init__(self, num_bits=8, quant_delay=0):
        """init FakeQuantWithMinMaxPerChannel Fill"""
         if num_bits not in self.support_quant_bit:
-            raise ValueError("Attr \'num_bits\' is not support.")
+            raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.")
 
-        self.quant_delay = check_int(quant_delay)
-        self.num_bits = check_int_positive(num_bits)
-        self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'],
-                                outputs=['dx'])
+        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
+        self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
+        self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])
 
     def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
         validator.check("dout shape", dout_shape, "x shape", x_shape)
@@ -227,13 +209,11 @@ class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer):
         return dout_shape
 
     def infer_dtype(self, dout_type, x_type, min_type, max_type):
-        validator.check_typename(
-            "dout", dout_type, (mstype.float16, mstype.float32))
-        validator.check_typename("x", x_type, (mstype.float16, mstype.float32))
-        validator.check_typename(
-            "min", min_type, (mstype.float16, mstype.float32))
-        validator.check_typename(
-            "max", max_type, (mstype.float16, mstype.float32))
+        valid_types = (mstype.float16, mstype.float32)
+        validator.check_tensor_type_same({"dout": dout_type}, valid_types, self.name)
+        validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
+        validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
+        validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
         return dout_type
@@ -269,31 +249,26 @@ class BatchNormFold(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, momentum=0.1, epsilon=1e-12, is_training=True, freeze_bn=0):
         """init batch norm fold layer"""
-        self.momentum = validator.check_number_range(
-            'momentum', momentum, 0, 1, Rel.INC_BOTH)
-        self.epsilon = validator.check_float_positive('epsilon', epsilon)
-        self.is_training = check_bool(is_training)
-        self.freeze_bn = check_int(freeze_bn)
+        self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
+        self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
+        self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
+        self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
 
         self.init_prim_io_names(inputs=['x', 'mean', 'variance', 'global_step'],
                                 outputs=['batch_mean', 'batch_std', 'running_mean', 'running_std'])
 
     def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape):
-        validator.check("mean shape", mean_shape,
-                        "gamma_shape", variance_shape)
-        validator.check("mean_shape size",
-                        mean_shape[0], "input channel", x_shape[self.channel])
-        validator.check_integer("global_step shape",
-                                len(global_step_shape), 1, Rel.EQ)
+        validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name)
+        validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[self.channel], Rel.EQ, self.name)
+        validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
         return mean_shape, mean_shape, mean_shape, mean_shape
 
     def infer_dtype(self, x_type, mean_type, variance_type, global_step_type):
-        validator.check("input type", x_type, "mean type", mean_type)
-        validator.check("input type", x_type, "variance type", variance_type)
-        validator.check_typename("input type", x_type,
-                                 (mstype.float16, mstype.float32))
-        validator.check_typename(
-            "global_step type", global_step_type, (mstype.int32,))
+        args = {"x": x_type, "mean": mean_type, "variance": variance_type}
+        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
+        validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
         return x_type, x_type, x_type, x_type
@@ -304,39 +279,31 @@ class BatchNormFoldGrad(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, epsilon=1e-12, is_training=True, freeze_bn=0):
         """init BatchNormGrad layer"""
-        self.is_training = check_bool(is_training)
-        self.freeze_bn = check_int(freeze_bn)
-        self.epsilon = validator.check_float_positive('epsilon', epsilon)
+        self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
+        self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
+        self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
         self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'global_step'],
                                 outputs=['dx'])
 
     def infer_shape(self, d_batch_mean_shape, d_batch_std_shape, x_shape, batch_mean_shape, batch_std_shape,
                     global_step_shape):
         validator.check("d_batch_mean shape", d_batch_mean_shape,
-                        "d_batch_std shape", d_batch_std_shape)
+                        "d_batch_std shape", d_batch_std_shape, Rel.EQ, self.name)
         validator.check("d_batch_mean shape", d_batch_mean_shape,
-                        "batch_mean shape", batch_mean_shape)
+                        "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
         validator.check("d_batch_mean shape", d_batch_mean_shape,
-                        "batch_std shape", batch_std_shape)
-        validator.check(
-            "x_shape shape", d_batch_mean_shape[0], "input channel", x_shape[self.channel])
-        validator.check_integer("global_step shape",
-                                len(global_step_shape), 1, Rel.EQ)
+                        "batch_std shape", batch_std_shape, Rel.EQ, self.name)
+        validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0], "input channel", x_shape[self.channel], Rel.EQ,
+                        self.name)
+        validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
         return x_shape
 
     def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type,
                     global_step_type):
-        validator.check("input type", x_type,
-                        "d_batch_mean type", d_batch_mean_type)
-        validator.check("input type", x_type,
-                        "d_batch_std type", d_batch_std_type)
-        validator.check("input type", x_type,
-                        "batch_mean type", batch_mean_type)
-        validator.check("input type", x_type, "batch_std type", batch_std_type)
-        validator.check_typename("input type", x_type,
-                                 (mstype.float16, mstype.float32))
-        validator.check_typename(
-            "global_step type", global_step_type, (mstype.int32,))
+        args = {"input": x_type, "d_batch_mean": d_batch_mean_type, "d_batch_std": d_batch_std_type,
+                "batch_mean": batch_mean_type, "batch_std": batch_std_type}
+        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
+        validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
        return x_type
@@ -364,18 +331,14 @@ class CorrectionMul(PrimitiveWithInfer):
                                 outputs=['out'])
 
     def infer_shape(self, x_shape, batch_std_shape, running_std_shape):
-        validator.check("batch_std shape", batch_std_shape,
-                        "running_std shape", running_std_shape)
-        validator.check(
-            "batch_std size", batch_std_shape[0], "x_shape channel size", x_shape[self.channel])
+        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
+        validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel],
+                        Rel.EQ, self.name)
         return x_shape
 
     def infer_dtype(self, x_type, batch_std_type, running_std_type):
-        validator.check("batch_std type", batch_std_type,
-                        "running_std type", running_std_type)
-        validator.check("batch_std_type", batch_std_type, "x_type", x_type)
-        validator.check_typename(
-            "batch_std type", batch_std_type, (mstype.float16, mstype.float32))
+        args = {"x": x_type, "batch_std": batch_std_type, "running_std": running_std_type}
+        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
         return x_type
@@ -390,20 +353,16 @@ class CorrectionMulGrad(PrimitiveWithInfer):
                                 outputs=['dx', 'd_gamma'])
 
     def infer_shape(self, dout_shape, x_shape, gamma_shape, running_std_shape):
-        validator.check("dout shape", dout_shape, "x_shape x", x_shape)
-        validator.check(
-            "gamma size", gamma_shape[0], "dout channel size", dout_shape[self.channel])
-        validator.check(
-            "running_std size", running_std_shape[0], "dout channel size", dout_shape[self.channel])
+        validator.check("dout shape", dout_shape, "x_shape x", x_shape, Rel.EQ, self.name)
+        validator.check("gamma_shape[0]", gamma_shape[0], "dout channel size", dout_shape[self.channel],
+                        Rel.EQ, self.name)
+        validator.check("running_std_shape[0]", running_std_shape[0], "dout channel size", dout_shape[self.channel],
+                        Rel.EQ, self.name)
         return x_shape, gamma_shape
 
     def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type):
-        validator.check("x type", x_type, "dout type", dout_type)
-        validator.check("gamma type", gamma_type, "dout type", dout_type)
-        validator.check("running_std type", running_std_type,
-                        "dout type", dout_type)
-        validator.check_typename(
-            "dout type", dout_type, (mstype.float16, mstype.float32))
+        args = {"dout": dout_type, "x": x_type, "gamma": gamma_type, "running_std": running_std_type}
+        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
         return x_type, x_type
@@ -432,46 +391,29 @@ class BatchNormFold2(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, freeze_bn=0):
         """init conv2d fold layer"""
-        self.freeze_bn = check_int(freeze_bn)
+        self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
         self.init_prim_io_names(inputs=['x', 'beta', 'gamma', 'batch_std', 'batch_mean',
                                         'running_std', 'running_mean', 'global_step'],
                                 outputs=['y'])
 
     def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape,
                     running_mean_shape, global_step_shape):
-        validator.check("batch_std shape", batch_std_shape,
-                        "running_std shape", running_std_shape)
-        validator.check("batch_std shape", batch_std_shape,
-                        "batch_mean shape", batch_mean_shape)
-        validator.check("batch_std shape", batch_std_shape,
-                        "beta shape", beta_shape)
-        validator.check("batch_std shape", batch_std_shape,
-                        "running_mean shape", running_mean_shape)
-        validator.check("batch_std shape", batch_std_shape,
-                        "batch_mean shape", gamma_shape)
-        validator.check(
-            "batch_std size", batch_std_shape[0], "x_shape channel size", x_shape[self.channel])
-        validator.check_integer("global_step shape",
-                                len(global_step_shape), 1, Rel.EQ)
+        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
+        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
+        validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, Rel.EQ, self.name)
+        validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape, Rel.EQ, self.name)
+        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name)
+        validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel],
+                        Rel.EQ, self.name)
+        validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
         return x_shape
 
     def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type,
                     running_mean_type, global_step_type):
-        validator.check("batch_std type", batch_std_type,
-                        "running_std type", running_std_type)
-        validator.check("batch_std type", batch_std_type,
-                        "batch_mean type", batch_mean_type)
-        validator.check("batch_std type", batch_std_type,
-                        "beta type", beta_type)
-        validator.check("batch_std type", batch_std_type,
-                        "running_mean type", running_mean_type)
-        validator.check("batch_std type", batch_std_type,
-                        "gamma type", gamma_type)
-        validator.check("x_type", x_type, "batch_std type", batch_std_type)
-        validator.check_typename(
-            "batch_std type", batch_std_type, (mstype.float16, mstype.float32))
-        validator.check_typename(
-            "global_step type", global_step_type, (mstype.int32,))
+        args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type,
+                "beta": beta_type, "running_mean": running_mean_type, "gamma": gamma_type, "x": x_type}
+        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
+        validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
         return x_type
@@ -491,18 +433,13 @@ class BatchNormFold2Grad(PrimitiveWithInfer):
     def infer_shape(self, dout_shape, x_shape, gamma_shape,
                     batch_std_shape, batch_mean_shape,
                     running_std_shape, running_mean_shape, global_step_shape):
-        validator.check("batch_std shape", batch_std_shape,
-                        "batch_mean shape", batch_mean_shape)
-        validator.check("batch_std shape", batch_std_shape,
-                        "running_std shape", running_std_shape)
-        validator.check("batch_std shape", batch_std_shape,
-                        "running_mean shape", running_mean_shape)
-        validator.check("batch_std shape", batch_std_shape,
-                        "gamma shape", gamma_shape)
-        validator.check(
-            "batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel])
-        validator.check_integer("global_step shape",
-                                len(global_step_shape), 1, Rel.EQ)
+        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
+        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
+        validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape, Rel.EQ, self.name)
+        validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
+        validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel],
+                        Rel.EQ, self.name)
+        validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
         return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape
 
     def infer_dtype(self, dout_type, x_type, gamma_type,
@@ -518,8 +455,8 @@ class BatchNormFold2Grad(PrimitiveWithInfer):
                         "running_mean type", running_mean_type)
-        validator.check("batch_std_type", batch_std_type,
-                        "dout type", dout_type)
-        validator.check_typename(
-            "batch_std type", batch_std_type, (mstype.float16, mstype.float32))
-        validator.check_typename(
-            "global_step type", global_step_type, (mstype.int32,))
+        args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type,
+                "running_std": running_std_type, "running_mean": running_mean_type, "dout": dout_type}
+        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
+        validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
         return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type
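
The added lines in this patch consistently thread `self.name` through every `Validator` call so that validation errors identify the failing primitive. The snippet below is a minimal, hypothetical sketch of that pattern outside the diff: the class `ExampleQuantOp` does not exist in the patch, and it assumes the `Validator` and `Rel` helpers are importable as `mindspore._checkparam.Validator` / `mindspore._checkparam.Rel` with the call signatures shown in the added lines above.

# Hypothetical illustration only -- not part of the patch.
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.ops.primitive import PrimitiveWithInfer, prim_attr_register


class ExampleQuantOp(PrimitiveWithInfer):
    """Toy primitive showing the Validator-based attribute checks used in the patch."""

    @prim_attr_register
    def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0):
        # Each check receives self.name so a bad argument reports which primitive failed.
        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
        self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1,
                                                      Rel.INC_BOTH, self.name)
        self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)

With this pattern, a call such as `ExampleQuantOp(num_bits=0)` should raise an error that names the offending primitive, which is the behaviour the added `self.name` arguments standardize across the operators in this file.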