diff --git a/mindspore/_extends/graph_kernel/expander.py b/mindspore/_extends/graph_kernel/expander.py index 115e1f0617..9ea3f115c8 100644 --- a/mindspore/_extends/graph_kernel/expander.py +++ b/mindspore/_extends/graph_kernel/expander.py @@ -23,11 +23,13 @@ from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedEx def create_expander(expand_info): """Create an expander according to op name""" + def call_func(func, arg): + return func(arg) op_name = str(expand_info['name']) if not hasattr(expanders, op_name): raise GraphKernelUnsupportedException("Generator do not support op: {}".format(op_name)) expander = getattr(expanders, op_name) - return expander(expand_info) + return call_func(expander, expand_info) def extract_expand_info(kernel_info): diff --git a/mindspore/_extends/graph_kernel/expanders/_utils.py b/mindspore/_extends/graph_kernel/expanders/_utils.py index 803f7ade8e..52753be0c7 100644 --- a/mindspore/_extends/graph_kernel/expanders/_utils.py +++ b/mindspore/_extends/graph_kernel/expanders/_utils.py @@ -66,19 +66,18 @@ class Expander: class ExpanderInfoValidator: """ExpanderInfoValidator is the utility class which defines the validator decorator for expanders""" - # pylint: disable=W0211 @staticmethod - def _add_check_function(cls, func): + def _add_check_function(kls, func): """ Rewrite the function `_check` in class Expander to append the new `func` after the original checks. """ - old_check = getattr(cls, "_check") + old_check = getattr(kls, "_check") def new_check(obj): old_check(obj) func(obj) - setattr(cls, "_check", new_check) + setattr(kls, "_check", new_check) @staticmethod def add_format(*input_format): @@ -112,7 +111,7 @@ class ExpanderInfoValidator: return wrapper @staticmethod - def check_all_formats_same(cls): + def check_all_formats_same(kls): """Check that all formats are the same""" def _check_format(obj): inp_formats = [inp['format'] for inp in obj.inputs] @@ -122,10 +121,10 @@ class ExpanderInfoValidator: ','.join(inp_formats), obj.name)) def wrapper(*args, **kargs): - if not issubclass(cls, Expander): - raise Exception("{} should be subclass of Expander.".format(cls.__name__)) - ExpanderInfoValidator._add_check_function(cls, _check_format) - return cls(*args, **kargs) + if not issubclass(kls, Expander): + raise Exception("{} should be subclass of Expander.".format(kls.__name__)) + ExpanderInfoValidator._add_check_function(kls, _check_format) + return kls(*args, **kargs) return wrapper diff --git a/mindspore/_extends/graph_kernel/expanders/gelu.py b/mindspore/_extends/graph_kernel/expanders/gelu.py index 8242413c4d..5c2f1c4ee2 100644 --- a/mindspore/_extends/graph_kernel/expanders/gelu.py +++ b/mindspore/_extends/graph_kernel/expanders/gelu.py @@ -23,7 +23,7 @@ class GeLU(Expander): def _expand(self, graph_builder): # cal formula are: - # gelu(x) is 0.5 * x * (1.0 + tanh(y)) + # gelu of x is 0.5 * x * (1.0 + tanh(y)) # y is sqrt(2.0 / pi) * (x + 0.044715 * x * x * x) input_x = self.inputs[0] diff --git a/mindspore/_extends/graph_kernel/expanders/gelu_grad.py b/mindspore/_extends/graph_kernel/expanders/gelu_grad.py index cec69d6719..dd00c1f5f3 100644 --- a/mindspore/_extends/graph_kernel/expanders/gelu_grad.py +++ b/mindspore/_extends/graph_kernel/expanders/gelu_grad.py @@ -25,7 +25,7 @@ class GeLUGrad(Expander): def _expand(self, graph_builder): # cal formula are: - # gelu_grad(dy, x) is dy * y' + # gelu_grad of dy and x is dy * y' # y' is 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(para)) * mul_right # tanh_para is sqrt(2.0 / pi) * (x + 0.044715 * x * x * x) # mul_right is sqrt(2.0 / pi) * (1 + 3 * 0.044715 * x * x) diff --git a/mindspore/_extends/graph_kernel/expanders/sigmoid.py b/mindspore/_extends/graph_kernel/expanders/sigmoid.py index 3202d9f9f2..8f18d2c92e 100644 --- a/mindspore/_extends/graph_kernel/expanders/sigmoid.py +++ b/mindspore/_extends/graph_kernel/expanders/sigmoid.py @@ -22,7 +22,7 @@ class Sigmoid(Expander): def _expand(self, graph_builder): input_x = self.inputs[0] # Calculate sigmoid(x) - # formula is : sigmoid(x) = 1 / (1 + exp(-x)) + # sigmoid of x is 1 / (1 + exp(-x)) const_one = graph_builder.value(input_x.dtype, 1.0) neg_x = graph_builder.emit('Neg', [input_x]) exp_neg_x = graph_builder.emit('Exp', [neg_x]) diff --git a/mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py b/mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py index abe1a00b9f..2d5d4f5614 100644 --- a/mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +++ b/mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py @@ -23,9 +23,8 @@ class SigmoidCrossEntropyWithLogits(Expander): def _expand(self, graph_builder): logits, label = self.inputs # Calculate sigmoid_cross_entropy_with_logits(logits, label) - # formula is : - # sigmoid_cross_entropy_with_logits(logits, label) - # = -(label * log(sigmoid(logits)) + (1 - label) * log(1 - sigmoid(logits))) + # formula of sigmoid_cross_entropy_with_logits is: + # -(label * log(sigmoid(logits)) + (1 - label) * log(1 - sigmoid(logits))) const_one = graph_builder.value(logits.dtype, 1.0) neg_x = graph_builder.emit('Neg', [logits]) exp_neg_x = graph_builder.emit('Exp', [neg_x]) diff --git a/mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py b/mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py index e853ce8d1b..63c1f3936f 100644 --- a/mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +++ b/mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py @@ -23,8 +23,8 @@ class SigmoidCrossEntropyWithLogitsGrad(Expander): def _expand(self, graph_builder): logits, label, dout = self.inputs # Calculate sigmoid_cross_entropy_with_logits_grad(logits, label, dout) - # formula is : - # sigmoid_cross_entropy_with_logits_grad(logits, label, dout) = (sigmoid(logits) - label) * dout + # formula of sigmoid_cross_entropy_with_logits_grad is : + # (sigmoid(logits) - label) * dout const_one = graph_builder.value(logits.dtype, 1.0) neg_x = graph_builder.emit('Neg', [logits]) exp_neg_x = graph_builder.emit('Exp', [neg_x]) diff --git a/mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py b/mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py index 90b63ffce3..ad73567cdc 100644 --- a/mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +++ b/mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py @@ -23,7 +23,7 @@ class SigmoidGrad(Expander): def _expand(self, graph_builder): input_y, dy = self.inputs # Calculate sigmoid_grad(y, dy) - # formula is : sigmoid_grad(y, dy) = (1 - y) * y * dy + # formula of sigmoid_grad is : (1 - y) * y * dy const_one = graph_builder.value(input_y.dtype, 1.0) one_mins_y = graph_builder.emit('Sub', [const_one, input_y]) y_mul_dy = graph_builder.emit('Mul', [input_y, dy]) diff --git a/mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py b/mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py index b9f4de6a1a..f2e2afb2b5 100644 --- a/mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +++ b/mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py @@ -24,8 +24,7 @@ class SoftmaxCrossEntropyWithLogits(Expander): def _expand(self, graph_builder): logits, label = self.inputs # Calculate softmax_cross_entropy_with_logits(logits, label) - # formula is : - # softmax_cross_entropy_with_logits(logits, label) = -reduce_sum(label * log(softmax(logits))) + # formula of softmax_cross_entropy_with_logits is : -reduce_sum(label * log(softmax(logits))) axis = (-1,) max_x = graph_builder.emit('ReduceMax', [logits], attrs={'reduce_axis': axis, 'keep_dims': True}) data_sub = graph_builder.emit('Sub', [logits, max_x]) diff --git a/mindspore/_extends/graph_kernel/expanders/sqrt_grad.py b/mindspore/_extends/graph_kernel/expanders/sqrt_grad.py index aaea624114..9f072e7a87 100644 --- a/mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +++ b/mindspore/_extends/graph_kernel/expanders/sqrt_grad.py @@ -21,7 +21,7 @@ class SqrtGrad(Expander): """SqrtGrad expander""" def _expand(self, graph_builder): - # sqrt_grad(x, dout) = dout / (2 * x) + # formula of sqrt_grad is dout / (2 * x) x, dout = self.inputs const_two = graph_builder.value(x.dtype, 2) dividend = graph_builder.emit('Mul', [x, const_two])