Browse Source

!15097 [graph kernel] clean code for expanders.

From: @chenlei_autodiff
Reviewed-by: @gaoxiong1,@dylangeng
Signed-off-by: @dylangeng
pull/15097/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
ff75ce8ac4
10 changed files with 21 additions and 22 deletions
  1. +3
    -1
      mindspore/_extends/graph_kernel/expander.py
  2. +8
    -9
      mindspore/_extends/graph_kernel/expanders/_utils.py
  3. +1
    -1
      mindspore/_extends/graph_kernel/expanders/gelu.py
  4. +1
    -1
      mindspore/_extends/graph_kernel/expanders/gelu_grad.py
  5. +1
    -1
      mindspore/_extends/graph_kernel/expanders/sigmoid.py
  6. +2
    -3
      mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py
  7. +2
    -2
      mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py
  8. +1
    -1
      mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py
  9. +1
    -2
      mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py
  10. +1
    -1
      mindspore/_extends/graph_kernel/expanders/sqrt_grad.py

+ 3
- 1
mindspore/_extends/graph_kernel/expander.py View File

@@ -23,11 +23,13 @@ from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedEx

def create_expander(expand_info):
"""Create an expander according to op name"""
def call_func(func, arg):
return func(arg)
op_name = str(expand_info['name'])
if not hasattr(expanders, op_name):
raise GraphKernelUnsupportedException("Generator do not support op: {}".format(op_name))
expander = getattr(expanders, op_name)
return expander(expand_info)
return call_func(expander, expand_info)


def extract_expand_info(kernel_info):


+ 8
- 9
mindspore/_extends/graph_kernel/expanders/_utils.py View File

@@ -66,19 +66,18 @@ class Expander:

class ExpanderInfoValidator:
"""ExpanderInfoValidator is the utility class which defines the validator decorator for expanders"""
# pylint: disable=W0211
@staticmethod
def _add_check_function(cls, func):
def _add_check_function(kls, func):
"""
Rewrite the function `_check` in class Expander
to append the new `func` after the original checks.
"""
old_check = getattr(cls, "_check")
old_check = getattr(kls, "_check")

def new_check(obj):
old_check(obj)
func(obj)
setattr(cls, "_check", new_check)
setattr(kls, "_check", new_check)

@staticmethod
def add_format(*input_format):
@@ -112,7 +111,7 @@ class ExpanderInfoValidator:
return wrapper

@staticmethod
def check_all_formats_same(cls):
def check_all_formats_same(kls):
"""Check that all formats are the same"""
def _check_format(obj):
inp_formats = [inp['format'] for inp in obj.inputs]
@@ -122,10 +121,10 @@ class ExpanderInfoValidator:
','.join(inp_formats), obj.name))

def wrapper(*args, **kargs):
if not issubclass(cls, Expander):
raise Exception("{} should be subclass of Expander.".format(cls.__name__))
ExpanderInfoValidator._add_check_function(cls, _check_format)
return cls(*args, **kargs)
if not issubclass(kls, Expander):
raise Exception("{} should be subclass of Expander.".format(kls.__name__))
ExpanderInfoValidator._add_check_function(kls, _check_format)
return kls(*args, **kargs)

return wrapper



+ 1
- 1
mindspore/_extends/graph_kernel/expanders/gelu.py View File

@@ -23,7 +23,7 @@ class GeLU(Expander):

def _expand(self, graph_builder):
# cal formula are:
# gelu(x) is 0.5 * x * (1.0 + tanh(y))
# gelu of x is 0.5 * x * (1.0 + tanh(y))
# y is sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)

input_x = self.inputs[0]


+ 1
- 1
mindspore/_extends/graph_kernel/expanders/gelu_grad.py View File

@@ -25,7 +25,7 @@ class GeLUGrad(Expander):

def _expand(self, graph_builder):
# cal formula are:
# gelu_grad(dy, x) is dy * y'
# gelu_grad of dy and x is dy * y'
# y' is 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(para)) * mul_right
# tanh_para is sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)
# mul_right is sqrt(2.0 / pi) * (1 + 3 * 0.044715 * x * x)


+ 1
- 1
mindspore/_extends/graph_kernel/expanders/sigmoid.py View File

@@ -22,7 +22,7 @@ class Sigmoid(Expander):
def _expand(self, graph_builder):
input_x = self.inputs[0]
# Calculate sigmoid(x)
# formula is : sigmoid(x) = 1 / (1 + exp(-x))
# sigmoid of x is 1 / (1 + exp(-x))
const_one = graph_builder.value(input_x.dtype, 1.0)
neg_x = graph_builder.emit('Neg', [input_x])
exp_neg_x = graph_builder.emit('Exp', [neg_x])


+ 2
- 3
mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py View File

@@ -23,9 +23,8 @@ class SigmoidCrossEntropyWithLogits(Expander):
def _expand(self, graph_builder):
logits, label = self.inputs
# Calculate sigmoid_cross_entropy_with_logits(logits, label)
# formula is :
# sigmoid_cross_entropy_with_logits(logits, label)
# = -(label * log(sigmoid(logits)) + (1 - label) * log(1 - sigmoid(logits)))
# formula of sigmoid_cross_entropy_with_logits is:
# -(label * log(sigmoid(logits)) + (1 - label) * log(1 - sigmoid(logits)))
const_one = graph_builder.value(logits.dtype, 1.0)
neg_x = graph_builder.emit('Neg', [logits])
exp_neg_x = graph_builder.emit('Exp', [neg_x])


+ 2
- 2
mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py View File

@@ -23,8 +23,8 @@ class SigmoidCrossEntropyWithLogitsGrad(Expander):
def _expand(self, graph_builder):
logits, label, dout = self.inputs
# Calculate sigmoid_cross_entropy_with_logits_grad(logits, label, dout)
# formula is :
# sigmoid_cross_entropy_with_logits_grad(logits, label, dout) = (sigmoid(logits) - label) * dout
# formula of sigmoid_cross_entropy_with_logits_grad is :
# (sigmoid(logits) - label) * dout
const_one = graph_builder.value(logits.dtype, 1.0)
neg_x = graph_builder.emit('Neg', [logits])
exp_neg_x = graph_builder.emit('Exp', [neg_x])


+ 1
- 1
mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py View File

@@ -23,7 +23,7 @@ class SigmoidGrad(Expander):
def _expand(self, graph_builder):
input_y, dy = self.inputs
# Calculate sigmoid_grad(y, dy)
# formula is : sigmoid_grad(y, dy) = (1 - y) * y * dy
# formula of sigmoid_grad is : (1 - y) * y * dy
const_one = graph_builder.value(input_y.dtype, 1.0)
one_mins_y = graph_builder.emit('Sub', [const_one, input_y])
y_mul_dy = graph_builder.emit('Mul', [input_y, dy])


+ 1
- 2
mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py View File

@@ -24,8 +24,7 @@ class SoftmaxCrossEntropyWithLogits(Expander):
def _expand(self, graph_builder):
logits, label = self.inputs
# Calculate softmax_cross_entropy_with_logits(logits, label)
# formula is :
# softmax_cross_entropy_with_logits(logits, label) = -reduce_sum(label * log(softmax(logits)))
# formula of softmax_cross_entropy_with_logits is : -reduce_sum(label * log(softmax(logits)))
axis = (-1,)
max_x = graph_builder.emit('ReduceMax', [logits], attrs={'reduce_axis': axis, 'keep_dims': True})
data_sub = graph_builder.emit('Sub', [logits, max_x])


+ 1
- 1
mindspore/_extends/graph_kernel/expanders/sqrt_grad.py View File

@@ -21,7 +21,7 @@ class SqrtGrad(Expander):
"""SqrtGrad expander"""

def _expand(self, graph_builder):
# sqrt_grad(x, dout) = dout / (2 * x)
# formula of sqrt_grad is dout / (2 * x)
x, dout = self.inputs
const_two = graph_builder.value(x.dtype, 2)
dividend = graph_builder.emit('Mul', [x, const_two])


Loading…
Cancel
Save