From: @chenlei_autodiff Reviewed-by: @gaoxiong1,@dylangeng Signed-off-by: @dylangengpull/15097/MERGE
| @@ -23,11 +23,13 @@ from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedEx | |||||
| def create_expander(expand_info): | def create_expander(expand_info): | ||||
| """Create an expander according to op name""" | """Create an expander according to op name""" | ||||
| def call_func(func, arg): | |||||
| return func(arg) | |||||
| op_name = str(expand_info['name']) | op_name = str(expand_info['name']) | ||||
| if not hasattr(expanders, op_name): | if not hasattr(expanders, op_name): | ||||
| raise GraphKernelUnsupportedException("Generator do not support op: {}".format(op_name)) | raise GraphKernelUnsupportedException("Generator do not support op: {}".format(op_name)) | ||||
| expander = getattr(expanders, op_name) | expander = getattr(expanders, op_name) | ||||
| return expander(expand_info) | |||||
| return call_func(expander, expand_info) | |||||
| def extract_expand_info(kernel_info): | def extract_expand_info(kernel_info): | ||||
| @@ -66,19 +66,18 @@ class Expander: | |||||
| class ExpanderInfoValidator: | class ExpanderInfoValidator: | ||||
| """ExpanderInfoValidator is the utility class which defines the validator decorator for expanders""" | """ExpanderInfoValidator is the utility class which defines the validator decorator for expanders""" | ||||
| # pylint: disable=W0211 | |||||
| @staticmethod | @staticmethod | ||||
| def _add_check_function(cls, func): | |||||
| def _add_check_function(kls, func): | |||||
| """ | """ | ||||
| Rewrite the function `_check` in class Expander | Rewrite the function `_check` in class Expander | ||||
| to append the new `func` after the original checks. | to append the new `func` after the original checks. | ||||
| """ | """ | ||||
| old_check = getattr(cls, "_check") | |||||
| old_check = getattr(kls, "_check") | |||||
| def new_check(obj): | def new_check(obj): | ||||
| old_check(obj) | old_check(obj) | ||||
| func(obj) | func(obj) | ||||
| setattr(cls, "_check", new_check) | |||||
| setattr(kls, "_check", new_check) | |||||
| @staticmethod | @staticmethod | ||||
| def add_format(*input_format): | def add_format(*input_format): | ||||
| @@ -112,7 +111,7 @@ class ExpanderInfoValidator: | |||||
| return wrapper | return wrapper | ||||
| @staticmethod | @staticmethod | ||||
| def check_all_formats_same(cls): | |||||
| def check_all_formats_same(kls): | |||||
| """Check that all formats are the same""" | """Check that all formats are the same""" | ||||
| def _check_format(obj): | def _check_format(obj): | ||||
| inp_formats = [inp['format'] for inp in obj.inputs] | inp_formats = [inp['format'] for inp in obj.inputs] | ||||
| @@ -122,10 +121,10 @@ class ExpanderInfoValidator: | |||||
| ','.join(inp_formats), obj.name)) | ','.join(inp_formats), obj.name)) | ||||
| def wrapper(*args, **kargs): | def wrapper(*args, **kargs): | ||||
| if not issubclass(cls, Expander): | |||||
| raise Exception("{} should be subclass of Expander.".format(cls.__name__)) | |||||
| ExpanderInfoValidator._add_check_function(cls, _check_format) | |||||
| return cls(*args, **kargs) | |||||
| if not issubclass(kls, Expander): | |||||
| raise Exception("{} should be subclass of Expander.".format(kls.__name__)) | |||||
| ExpanderInfoValidator._add_check_function(kls, _check_format) | |||||
| return kls(*args, **kargs) | |||||
| return wrapper | return wrapper | ||||
| @@ -23,7 +23,7 @@ class GeLU(Expander): | |||||
| def _expand(self, graph_builder): | def _expand(self, graph_builder): | ||||
| # cal formula are: | # cal formula are: | ||||
| # gelu(x) is 0.5 * x * (1.0 + tanh(y)) | |||||
| # gelu of x is 0.5 * x * (1.0 + tanh(y)) | |||||
| # y is sqrt(2.0 / pi) * (x + 0.044715 * x * x * x) | # y is sqrt(2.0 / pi) * (x + 0.044715 * x * x * x) | ||||
| input_x = self.inputs[0] | input_x = self.inputs[0] | ||||
| @@ -25,7 +25,7 @@ class GeLUGrad(Expander): | |||||
| def _expand(self, graph_builder): | def _expand(self, graph_builder): | ||||
| # cal formula are: | # cal formula are: | ||||
| # gelu_grad(dy, x) is dy * y' | |||||
| # gelu_grad of dy and x is dy * y' | |||||
| # y' is 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(para)) * mul_right | # y' is 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(para)) * mul_right | ||||
| # tanh_para is sqrt(2.0 / pi) * (x + 0.044715 * x * x * x) | # tanh_para is sqrt(2.0 / pi) * (x + 0.044715 * x * x * x) | ||||
| # mul_right is sqrt(2.0 / pi) * (1 + 3 * 0.044715 * x * x) | # mul_right is sqrt(2.0 / pi) * (1 + 3 * 0.044715 * x * x) | ||||
| @@ -22,7 +22,7 @@ class Sigmoid(Expander): | |||||
| def _expand(self, graph_builder): | def _expand(self, graph_builder): | ||||
| input_x = self.inputs[0] | input_x = self.inputs[0] | ||||
| # Calculate sigmoid(x) | # Calculate sigmoid(x) | ||||
| # formula is : sigmoid(x) = 1 / (1 + exp(-x)) | |||||
| # sigmoid of x is 1 / (1 + exp(-x)) | |||||
| const_one = graph_builder.value(input_x.dtype, 1.0) | const_one = graph_builder.value(input_x.dtype, 1.0) | ||||
| neg_x = graph_builder.emit('Neg', [input_x]) | neg_x = graph_builder.emit('Neg', [input_x]) | ||||
| exp_neg_x = graph_builder.emit('Exp', [neg_x]) | exp_neg_x = graph_builder.emit('Exp', [neg_x]) | ||||
| @@ -23,9 +23,8 @@ class SigmoidCrossEntropyWithLogits(Expander): | |||||
| def _expand(self, graph_builder): | def _expand(self, graph_builder): | ||||
| logits, label = self.inputs | logits, label = self.inputs | ||||
| # Calculate sigmoid_cross_entropy_with_logits(logits, label) | # Calculate sigmoid_cross_entropy_with_logits(logits, label) | ||||
| # formula is : | |||||
| # sigmoid_cross_entropy_with_logits(logits, label) | |||||
| # = -(label * log(sigmoid(logits)) + (1 - label) * log(1 - sigmoid(logits))) | |||||
| # formula of sigmoid_cross_entropy_with_logits is: | |||||
| # -(label * log(sigmoid(logits)) + (1 - label) * log(1 - sigmoid(logits))) | |||||
| const_one = graph_builder.value(logits.dtype, 1.0) | const_one = graph_builder.value(logits.dtype, 1.0) | ||||
| neg_x = graph_builder.emit('Neg', [logits]) | neg_x = graph_builder.emit('Neg', [logits]) | ||||
| exp_neg_x = graph_builder.emit('Exp', [neg_x]) | exp_neg_x = graph_builder.emit('Exp', [neg_x]) | ||||
| @@ -23,8 +23,8 @@ class SigmoidCrossEntropyWithLogitsGrad(Expander): | |||||
| def _expand(self, graph_builder): | def _expand(self, graph_builder): | ||||
| logits, label, dout = self.inputs | logits, label, dout = self.inputs | ||||
| # Calculate sigmoid_cross_entropy_with_logits_grad(logits, label, dout) | # Calculate sigmoid_cross_entropy_with_logits_grad(logits, label, dout) | ||||
| # formula is : | |||||
| # sigmoid_cross_entropy_with_logits_grad(logits, label, dout) = (sigmoid(logits) - label) * dout | |||||
| # formula of sigmoid_cross_entropy_with_logits_grad is : | |||||
| # (sigmoid(logits) - label) * dout | |||||
| const_one = graph_builder.value(logits.dtype, 1.0) | const_one = graph_builder.value(logits.dtype, 1.0) | ||||
| neg_x = graph_builder.emit('Neg', [logits]) | neg_x = graph_builder.emit('Neg', [logits]) | ||||
| exp_neg_x = graph_builder.emit('Exp', [neg_x]) | exp_neg_x = graph_builder.emit('Exp', [neg_x]) | ||||
| @@ -23,7 +23,7 @@ class SigmoidGrad(Expander): | |||||
| def _expand(self, graph_builder): | def _expand(self, graph_builder): | ||||
| input_y, dy = self.inputs | input_y, dy = self.inputs | ||||
| # Calculate sigmoid_grad(y, dy) | # Calculate sigmoid_grad(y, dy) | ||||
| # formula is : sigmoid_grad(y, dy) = (1 - y) * y * dy | |||||
| # formula of sigmoid_grad is : (1 - y) * y * dy | |||||
| const_one = graph_builder.value(input_y.dtype, 1.0) | const_one = graph_builder.value(input_y.dtype, 1.0) | ||||
| one_mins_y = graph_builder.emit('Sub', [const_one, input_y]) | one_mins_y = graph_builder.emit('Sub', [const_one, input_y]) | ||||
| y_mul_dy = graph_builder.emit('Mul', [input_y, dy]) | y_mul_dy = graph_builder.emit('Mul', [input_y, dy]) | ||||
| @@ -24,8 +24,7 @@ class SoftmaxCrossEntropyWithLogits(Expander): | |||||
| def _expand(self, graph_builder): | def _expand(self, graph_builder): | ||||
| logits, label = self.inputs | logits, label = self.inputs | ||||
| # Calculate softmax_cross_entropy_with_logits(logits, label) | # Calculate softmax_cross_entropy_with_logits(logits, label) | ||||
| # formula is : | |||||
| # softmax_cross_entropy_with_logits(logits, label) = -reduce_sum(label * log(softmax(logits))) | |||||
| # formula of softmax_cross_entropy_with_logits is : -reduce_sum(label * log(softmax(logits))) | |||||
| axis = (-1,) | axis = (-1,) | ||||
| max_x = graph_builder.emit('ReduceMax', [logits], attrs={'reduce_axis': axis, 'keep_dims': True}) | max_x = graph_builder.emit('ReduceMax', [logits], attrs={'reduce_axis': axis, 'keep_dims': True}) | ||||
| data_sub = graph_builder.emit('Sub', [logits, max_x]) | data_sub = graph_builder.emit('Sub', [logits, max_x]) | ||||
| @@ -21,7 +21,7 @@ class SqrtGrad(Expander): | |||||
| """SqrtGrad expander""" | """SqrtGrad expander""" | ||||
| def _expand(self, graph_builder): | def _expand(self, graph_builder): | ||||
| # sqrt_grad(x, dout) = dout / (2 * x) | |||||
| # formula of sqrt_grad is dout / (2 * x) | |||||
| x, dout = self.inputs | x, dout = self.inputs | ||||
| const_two = graph_builder.value(x.dtype, 2) | const_two = graph_builder.value(x.dtype, 2) | ||||
| dividend = graph_builder.emit('Mul', [x, const_two]) | dividend = graph_builder.emit('Mul', [x, const_two]) | ||||