From: @dayschan Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -18,6 +18,16 @@ import json.decoder as jd | |||
| import traceback | |||
| from mindspore import log as logger | |||
| import mindspore._extends.graph_kernel.expanders as expanders | |||
| from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException | |||
| def create_expander(expand_info): | |||
| """Create an expander according to op name""" | |||
| op_name = str(expand_info['name']) | |||
| if not hasattr(expanders, op_name): | |||
| raise GraphKernelUnsupportedException("Generator do not support op: {}".format(op_name)) | |||
| expander = getattr(expanders, op_name) | |||
| return expander(expand_info) | |||
| def extract_expand_info(kernel_info): | |||
| @@ -46,20 +56,8 @@ def get_op_expander(json_str: str): | |||
| kernel_info = json.loads(json_str) | |||
| expand_info = extract_expand_info(kernel_info) | |||
| processor = expand_info['process'] | |||
| op_name = str(expand_info['name']).lower() | |||
| expand_op_func_name = 'expand_' + op_name | |||
| if not hasattr(expanders, expand_op_func_name): | |||
| logger.error("Generator do not support op: {}".format(op_name)) | |||
| return None | |||
| expand_op_func = getattr(expanders, expand_op_func_name) | |||
| # generate graph desc. | |||
| graph = expand_op_func(expand_info) | |||
| if graph is None: | |||
| logger.error("Failed to generate graph of: {}".format(op_name)) | |||
| return None | |||
| graph.set_processor(processor) | |||
| expander = create_expander(expand_info) | |||
| graph = expander.run() | |||
| # dump graph to json desc. | |||
| desc = graph.dump() | |||
| @@ -69,3 +67,6 @@ def get_op_expander(json_str: str): | |||
| logger.error("Failed to generate graph kernel op") | |||
| logger.error(traceback.format_exc()) | |||
| return None | |||
| except GraphKernelUnsupportedException as e: | |||
| logger.info(e.message) | |||
| return "" | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -14,24 +14,24 @@ | |||
| # ============================================================================ | |||
| """expanders init""" | |||
| from .gelu import expand_gelu | |||
| from .gelu_grad import expand_gelugrad | |||
| from .layernorm import expand_layernorm | |||
| from .softmax import expand_softmax | |||
| from .square import expand_square | |||
| from .bias_add import expand_biasadd | |||
| from .bias_add_grad import expand_biasaddgrad | |||
| from .fused_adam import expand_fusedadam | |||
| from .fused_adam_weight_decay import expand_fusedadamweightdecay | |||
| from .reduce_mean import expand_reducemean | |||
| from .tanh_grad import expand_tanhgrad | |||
| from .maximum_grad import expand_maximumgrad | |||
| from .minimum_grad import expand_minimumgrad | |||
| from .dropout_grad import expand_dropoutgrad | |||
| from .layernorm_grad import expand_layernormgrad | |||
| from .logsoftmax import expand_logsoftmax | |||
| from .logsoftmax_grad import expand_logsoftmaxgrad | |||
| from .gkdropout import expand_gkdropout | |||
| from .tile import expand_tile | |||
| from .sqrt_grad import expand_sqrtgrad | |||
| from .clip_by_norm_no_div_sum import expand_clipbynormnodivsum | |||
| from .bias_add import BiasAdd | |||
| from .bias_add_grad import BiasAddGrad | |||
| from .clip_by_norm_no_div_sum import ClipByNormNoDivSum | |||
| from .dropout_grad import DropoutGrad | |||
| from .fused_adam import FusedAdam | |||
| from .fused_adam_weight_decay import FusedAdamWeightDecay | |||
| from .gelu import GeLU | |||
| from .gelu_grad import GeLUGrad | |||
| from .gkdropout import GkDropout | |||
| from .layernorm import LayerNorm | |||
| from .layernorm_grad import LayerNormGrad | |||
| from .logsoftmax import LogSoftmax | |||
| from .logsoftmax_grad import LogSoftmaxGrad | |||
| from .maximum_grad import MaximumGrad | |||
| from .minimum_grad import MinimumGrad | |||
| from .reduce_mean import ReduceMean | |||
| from .softmax import Softmax | |||
| from .sqrt_grad import SqrtGrad | |||
| from .square import Square | |||
| from .tanh_grad import TanhGrad | |||
| from .tile import Tile | |||
| @@ -0,0 +1,146 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # =========================================================================== | |||
| """GraphKernel expander utils""" | |||
| from abc import ABCMeta, abstractmethod | |||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||
| from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException | |||
| class Expander: | |||
| """ | |||
| Expander is the base class of expanders. | |||
| The method `_expand` should be overridden to implement the operator detail. | |||
| """ | |||
| __metaclass__ = ABCMeta | |||
| def __init__(self, expand_info): | |||
| self.name = expand_info["name"] | |||
| self.inputs = expand_info["input_desc"] | |||
| self.outputs = expand_info["output_desc"] | |||
| self.attrs = expand_info["attr"] | |||
| self.processor = expand_info["process"] | |||
| def run(self): | |||
| """ | |||
| Expand the operator to a graph. | |||
| `GraphKernelUnsupportedException` would be raised if check failed. | |||
| """ | |||
| self._check() | |||
| graph_builder = builder.GraphBuilder() | |||
| with graph_builder.graph_scope(self.name) as graph_scope: | |||
| # transform input_desc to Tensor | |||
| self.inputs = [graph_builder.tensor(inp['shape'], inp['data_type'], inp['format']) for inp in self.inputs] | |||
| graph_scope.set_input(*self.inputs) | |||
| outputs = self._expand(graph_builder) | |||
| if isinstance(outputs, (list, tuple)): | |||
| graph_scope.set_output(*outputs) | |||
| else: | |||
| graph_scope.set_output(outputs) | |||
| graph = graph_builder.get()[0] | |||
| graph.set_processor(self.processor) | |||
| return graph | |||
| def _check(self): | |||
| """Check inputs""" | |||
| @abstractmethod | |||
| def _expand(self, graph_builder): | |||
| """Expand operator, this function should be overridden in subclass""" | |||
| raise Exception("_expand() is not implemented in {}".format(self.__class__.__name__)) | |||
| class ExpanderInfoValidator: | |||
| """ExpanderInfoValidator is the utility class which defines the validator decorator for expanders""" | |||
| # pylint: disable=W0211 | |||
| @staticmethod | |||
| def _add_check_function(cls, func): | |||
| """ | |||
| Rewrite the function `_check` in class Expander | |||
| to append the new `func` after the original checks. | |||
| """ | |||
| old_check = getattr(cls, "_check") | |||
| def new_check(obj): | |||
| old_check(obj) | |||
| func(obj) | |||
| setattr(cls, "_check", new_check) | |||
| @staticmethod | |||
| def add_format(*input_format): | |||
| """ | |||
| Add new supported format for the operator | |||
| this function will add a list `__supported_formats` into the expander, | |||
| saving the whitelist of formats that this op supports. | |||
| it also rewrites the `_check` function to check the formats. | |||
| """ | |||
| format_list_name = "__supported_formats" | |||
| def _check_format(obj): | |||
| inp_formats = [inp['format'] for inp in obj.inputs] | |||
| for formats in getattr(obj, format_list_name): | |||
| if len(formats) != len(inp_formats): | |||
| raise GKException("length of registered format doesn't match with the input of {}".format(obj.name)) | |||
| if all([fmt == inp for fmt, inp in zip(formats, inp_formats)]): | |||
| return | |||
| raise GKException("Unregistered format ({}) for op {}".format(','.join(inp_formats), obj.name)) | |||
| def wrapper(cls): | |||
| if not issubclass(cls, Expander): | |||
| raise Exception("{} should be subclass of Expander.".format(cls.__name__)) | |||
| if not hasattr(cls, format_list_name): | |||
| setattr(cls, format_list_name, list()) | |||
| ExpanderInfoValidator._add_check_function(cls, _check_format) | |||
| getattr(cls, format_list_name).append(input_format) | |||
| return cls | |||
| return wrapper | |||
| @staticmethod | |||
| def check_all_formats_same(cls): | |||
| """Check that all formats are the same""" | |||
| def _check_format(obj): | |||
| inp_formats = [inp['format'] for inp in obj.inputs] | |||
| if all([fmt == inp_formats[0] for fmt in inp_formats[1:]]): | |||
| return | |||
| raise GKException("[check_all_formats_same] unmatched formats ({}) for op {}".format( | |||
| ','.join(inp_formats), obj.name)) | |||
| def wrapper(*args, **kargs): | |||
| if not issubclass(cls, Expander): | |||
| raise Exception("{} should be subclass of Expander.".format(cls.__name__)) | |||
| ExpanderInfoValidator._add_check_function(cls, _check_format) | |||
| return cls(*args, **kargs) | |||
| return wrapper | |||
| @staticmethod | |||
| def check_attrs(*args): | |||
| """Check the attrs exist""" | |||
| def _check_attr(obj): | |||
| for a in args: | |||
| if a not in obj.attrs: | |||
| raise GKException("attr '{}' does not exist.".format(a)) | |||
| def wrapper(cls): | |||
| if not issubclass(cls, Expander): | |||
| raise Exception("{} should be subclass of Expander.".format(cls.__name__)) | |||
| ExpanderInfoValidator._add_check_function(cls, _check_attr) | |||
| return cls | |||
| return wrapper | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,50 +13,34 @@ | |||
| # limitations under the License. | |||
| # =========================================================================== | |||
| """generate json desc for bias_add""" | |||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||
| from mindspore._extends.graph_kernel.model.model import DataFormat as DF | |||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||
| def expand_biasadd(expand_info): | |||
| @VLD.add_format(DF.DEFAULT, DF.DEFAULT) | |||
| @VLD.add_format(DF.NCHW, DF.DEFAULT) | |||
| @VLD.add_format(DF.NHWC, DF.DEFAULT) | |||
| class BiasAdd(Expander): | |||
| """BiasAdd expander""" | |||
| # get op info. | |||
| input_desc_0 = expand_info['input_desc'][0] | |||
| input_desc_1 = expand_info['input_desc'][1] | |||
| graph_builder = builder.GraphBuilder() | |||
| # generate a graph. | |||
| with graph_builder.graph_scope('main') as graph_scope: | |||
| # create tensor input. | |||
| input_x = graph_builder.tensor( | |||
| input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format']) | |||
| input_y = graph_builder.tensor( | |||
| input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format']) | |||
| graph_scope.set_input(input_x, input_y) | |||
| if input_x.data_format == "NCHW": | |||
| input_y_expand = graph_builder.emit( | |||
| 'ExpandDims', [input_y], attrs={'axis': 1}) | |||
| input_y_expand = graph_builder.emit( | |||
| 'ExpandDims', [input_y_expand], attrs={'axis': 2}) | |||
| def _expand(self, graph_builder): | |||
| input_x, input_y = self.inputs | |||
| if input_x.data_format == DF.NCHW: | |||
| input_y_expand = graph_builder.emit('ExpandDims', [input_y], attrs={'axis': 1}) | |||
| input_y_expand = graph_builder.emit('ExpandDims', [input_y_expand], attrs={'axis': 2}) | |||
| result = graph_builder.emit('Add', [input_x, input_y_expand]) | |||
| elif input_x.data_format == "DefaultFormat": | |||
| elif input_x.data_format == DF.DEFAULT: | |||
| if len(input_x.shape) == 2: | |||
| result = graph_builder.emit('Add', [input_x, input_y]) | |||
| elif len(input_x.shape) == 3: | |||
| input_y_expand = graph_builder.emit( | |||
| 'ExpandDims', [input_y], attrs={'axis': 1}) | |||
| result = graph_builder.emit( | |||
| 'Add', [input_x, input_y_expand]) | |||
| else: | |||
| input_y_expand = graph_builder.emit( | |||
| 'ExpandDims', [input_y], attrs={'axis': 1}) | |||
| input_y_expand = graph_builder.emit( | |||
| 'ExpandDims', [input_y_expand], attrs={'axis': 2}) | |||
| result = graph_builder.emit( | |||
| 'Add', [input_x, input_y_expand]) | |||
| else: | |||
| input_y_expand = graph_builder.emit('ExpandDims', [input_y], attrs={'axis': 1}) | |||
| result = graph_builder.emit('Add', [input_x, input_y_expand]) | |||
| else: # len == 4 | |||
| input_y_expand = graph_builder.emit('ExpandDims', [input_y], attrs={'axis': 1}) | |||
| input_y_expand = graph_builder.emit('ExpandDims', [input_y_expand], attrs={'axis': 2}) | |||
| result = graph_builder.emit('Add', [input_x, input_y_expand]) | |||
| else: # NHWC | |||
| result = graph_builder.emit('Add', [input_x, input_y]) | |||
| # set graph output. | |||
| graph_scope.set_output(result) | |||
| graph = graph_builder.get()[0] | |||
| return graph | |||
| return result | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,26 +13,25 @@ | |||
| # limitations under the License. | |||
| # =========================================================================== | |||
| """generate json desc for bias_add""" | |||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||
| from mindspore._extends.graph_kernel.model.model import DataFormat as DF | |||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||
| def expand_biasaddgrad(expand_info): | |||
| @VLD.add_format(DF.DEFAULT) | |||
| @VLD.add_format(DF.NHWC) | |||
| @VLD.add_format(DF.NCHW) | |||
| class BiasAddGrad(Expander): | |||
| """BiasAddGrad expander""" | |||
| # get op info. | |||
| input_desc_0 = expand_info['input_desc'][0] | |||
| graph_builder = builder.GraphBuilder() | |||
| # generate a graph. | |||
| with graph_builder.graph_scope('main') as graph_scope: | |||
| # create tensor input. | |||
| input_x = graph_builder.tensor( | |||
| input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format']) | |||
| graph_scope.set_input(input_x) | |||
| def _expand(self, graph_builder): | |||
| input_x = self.inputs[0] | |||
| reduce_axis = () | |||
| if input_x.data_format == 'NHWC': | |||
| reduce_axis = (0, 1, 2) | |||
| elif input_x.data_format == 'NCHW': | |||
| reduce_axis = (0, 2, 3) | |||
| # Default format shape's length maybe equal 2 to 4, so different shape's length reduce axis are differnet | |||
| # DefaultFormat shape's length should be from 2 to 4 | |||
| else: | |||
| if len(input_x.shape) == 2: | |||
| reduce_axis = (0,) | |||
| @@ -41,8 +40,4 @@ def expand_biasaddgrad(expand_info): | |||
| else: | |||
| reduce_axis = (0, 2, 3) | |||
| result = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False}) | |||
| # set graph output. | |||
| graph_scope.set_output(result) | |||
| graph = graph_builder.get()[0] | |||
| return graph | |||
| return result | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,27 +13,15 @@ | |||
| # limitations under the License. | |||
| # =========================================================================== | |||
| """generate json desc for ClipByNormNoDivSum""" | |||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||
| def expand_clipbynormnodivsum(expand_info): | |||
| @VLD.check_all_formats_same | |||
| class ClipByNormNoDivSum(Expander): | |||
| """ClipByNormNoDivSum expander""" | |||
| # get op info. | |||
| input_desc_0 = expand_info['input_desc'][0] | |||
| input_desc_1 = expand_info['input_desc'][1] | |||
| input_desc_2 = expand_info['input_desc'][2] | |||
| input_desc_3 = expand_info['input_desc'][3] | |||
| graph_builder = builder.GraphBuilder() | |||
| # generate a graph. | |||
| with graph_builder.graph_scope('main') as graph_scope: | |||
| # create tensor input. | |||
| input_x0 = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format']) | |||
| input_x1 = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format']) | |||
| input_x2 = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format']) | |||
| input_x3 = graph_builder.tensor(input_desc_3['shape'], input_desc_3['data_type'], input_desc_3['format']) | |||
| graph_scope.set_input(input_x0, input_x1, input_x2, input_x3) | |||
| def _expand(self, graph_builder): | |||
| input_x0, input_x1, input_x2, input_x3 = self.inputs | |||
| # cal result | |||
| greater_res = graph_builder.emit('Greater', [input_x0, input_x1], attrs={'fusion': 'SelectGT_000'}) | |||
| @@ -44,8 +32,4 @@ def expand_clipbynormnodivsum(expand_info): | |||
| attrs={'fusion': 'SelectGT_000_end'}) | |||
| result = graph_builder.emit('Maximum', [select_res1, input_x3]) | |||
| # set graph output. | |||
| graph_scope.set_output(result) | |||
| graph = graph_builder.get()[0] | |||
| return graph | |||
| return result | |||
| @@ -13,27 +13,18 @@ | |||
| # limitations under the License. | |||
| # =========================================================================== | |||
| """generate json desc for DropoutGrad""" | |||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||
| def expand_dropoutgrad(expand_info): | |||
| @VLD.check_all_formats_same | |||
| @VLD.check_attrs('keep_prob') | |||
| class DropoutGrad(Expander): | |||
| """DropoutGrad expander""" | |||
| # get op info. | |||
| dy_desc = expand_info['input_desc'][0] | |||
| mask_desc = expand_info['input_desc'][1] | |||
| keep_prob = expand_info['attr']['keep_prob'] | |||
| graph_builder = builder.GraphBuilder() | |||
| with graph_builder.graph_scope('main') as graph_scope: | |||
| # create tensor input. | |||
| input_dy = graph_builder.tensor(dy_desc['shape'], dy_desc['data_type'], dy_desc['format']) | |||
| input_mask = graph_builder.tensor(mask_desc['shape'], mask_desc['data_type'], mask_desc['format']) | |||
| graph_scope.set_input(input_dy, input_mask) | |||
| def _expand(self, graph_builder): | |||
| input_dy, input_mask = self.inputs | |||
| keep_prob = self.attrs['keep_prob'] | |||
| r_keep_prob = graph_builder.value(input_dy.dtype, 1.0 / keep_prob) | |||
| # create op. | |||
| result = graph_builder.emit('Mul', [input_dy, r_keep_prob]) | |||
| result = graph_builder.emit('Mul', [result, input_mask]) | |||
| # set graph output. | |||
| graph_scope.set_output(result) | |||
| graph = graph_builder.get()[0] | |||
| return graph | |||
| return result | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,40 +13,16 @@ | |||
| # limitations under the License. | |||
| # =========================================================================== | |||
| """generate json desc for fused_adam""" | |||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||
| def expand_fusedadam(expand_info): | |||
| """FusedAdma expander""" | |||
| # get op info. | |||
| input_desc_0 = expand_info['input_desc'][0] | |||
| input_desc_1 = expand_info['input_desc'][1] | |||
| input_desc_2 = expand_info['input_desc'][2] | |||
| input_desc_3 = expand_info['input_desc'][3] | |||
| input_desc_4 = expand_info['input_desc'][4] | |||
| input_desc_5 = expand_info['input_desc'][5] | |||
| input_desc_6 = expand_info['input_desc'][6] | |||
| input_desc_7 = expand_info['input_desc'][7] | |||
| input_desc_8 = expand_info['input_desc'][8] | |||
| input_desc_9 = expand_info['input_desc'][9] | |||
| graph_builder = builder.GraphBuilder() | |||
| @VLD.check_all_formats_same | |||
| class FusedAdam(Expander): | |||
| """FusedAdam expander""" | |||
| # generate a graph. | |||
| with graph_builder.graph_scope('main') as graph_scope: | |||
| # create tensor input. | |||
| beta_1 = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format']) | |||
| one_sub_beta_1 = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format']) | |||
| beta_2 = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format']) | |||
| one_sub_beta_2 = graph_builder.tensor(input_desc_3['shape'], input_desc_3['data_type'], input_desc_3['format']) | |||
| eps = graph_builder.tensor(input_desc_4['shape'], input_desc_4['data_type'], input_desc_4['format']) | |||
| lr = graph_builder.tensor(input_desc_5['shape'], input_desc_5['data_type'], input_desc_5['format']) | |||
| param = graph_builder.tensor(input_desc_6['shape'], input_desc_6['data_type'], input_desc_6['format']) | |||
| m = graph_builder.tensor(input_desc_7['shape'], input_desc_7['data_type'], input_desc_7['format']) | |||
| v = graph_builder.tensor(input_desc_8['shape'], input_desc_8['data_type'], input_desc_8['format']) | |||
| gradient = graph_builder.tensor(input_desc_9['shape'], input_desc_9['data_type'], input_desc_9['format']) | |||
| graph_scope.set_input(beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, eps, lr, param, m, v, gradient) | |||
| def _expand(self, graph_builder): | |||
| beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, eps, lr, param, m, v, gradient = self.inputs | |||
| # compute result | |||
| beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m]) | |||
| one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient]) | |||
| next_m = graph_builder.emit('Add', [beta_1_mul_m, one_sub_beta_1_mul_grad]) | |||
| @@ -60,12 +36,9 @@ def expand_fusedadam(expand_info): | |||
| update_with_lr = graph_builder.emit('Mul', [lr, update]) | |||
| next_para = graph_builder.emit('Sub', [param, update_with_lr]) | |||
| param_result = graph_builder.emit('InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True}) | |||
| param_result = graph_builder.emit( | |||
| 'InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True}) | |||
| param_result = graph_builder.emit('InplaceAssign', [m, next_m, param_result], attrs={'fake_output': True}) | |||
| param_result = graph_builder.emit('InplaceAssign', [v, next_v, param_result], attrs={'fake_output': True}) | |||
| # set graph output. | |||
| graph_scope.set_output(param_result) | |||
| graph = graph_builder.get()[0] | |||
| return graph | |||
| return param_result | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,41 +13,15 @@ | |||
| # limitations under the License. | |||
| # =========================================================================== | |||
| """generate json desc for fused_adam_weight_decay""" | |||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||
| def expand_fusedadamweightdecay(expand_info): | |||
| """FusedAdmaWeightDecay expander""" | |||
| # get op info. | |||
| input_desc_0 = expand_info['input_desc'][0] | |||
| input_desc_1 = expand_info['input_desc'][1] | |||
| input_desc_2 = expand_info['input_desc'][2] | |||
| input_desc_3 = expand_info['input_desc'][3] | |||
| input_desc_4 = expand_info['input_desc'][4] | |||
| input_desc_5 = expand_info['input_desc'][5] | |||
| input_desc_6 = expand_info['input_desc'][6] | |||
| input_desc_7 = expand_info['input_desc'][7] | |||
| input_desc_8 = expand_info['input_desc'][8] | |||
| input_desc_9 = expand_info['input_desc'][9] | |||
| input_desc_10 = expand_info['input_desc'][10] | |||
| graph_builder = builder.GraphBuilder() | |||
| @VLD.check_all_formats_same | |||
| class FusedAdamWeightDecay(Expander): | |||
| """FusedAdamWeightDecay expander""" | |||
| # generate a graph. | |||
| with graph_builder.graph_scope('main') as graph_scope: | |||
| # create tensor input. | |||
| beta_1 = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format']) | |||
| one_sub_beta_1 = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format']) | |||
| beta_2 = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format']) | |||
| one_sub_beta_2 = graph_builder.tensor(input_desc_3['shape'], input_desc_3['data_type'], input_desc_3['format']) | |||
| eps = graph_builder.tensor(input_desc_4['shape'], input_desc_4['data_type'], input_desc_4['format']) | |||
| lr = graph_builder.tensor(input_desc_5['shape'], input_desc_5['data_type'], input_desc_5['format']) | |||
| param = graph_builder.tensor(input_desc_6['shape'], input_desc_6['data_type'], input_desc_6['format']) | |||
| m = graph_builder.tensor(input_desc_7['shape'], input_desc_7['data_type'], input_desc_7['format']) | |||
| v = graph_builder.tensor(input_desc_8['shape'], input_desc_8['data_type'], input_desc_8['format']) | |||
| gradient = graph_builder.tensor(input_desc_9['shape'], input_desc_9['data_type'], input_desc_9['format']) | |||
| weight_decay = graph_builder.tensor(input_desc_10['shape'], input_desc_10['data_type'], input_desc_10['format']) | |||
| graph_scope.set_input(beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, | |||
| eps, lr, param, m, v, gradient, weight_decay) | |||
| def _expand(self, graph_builder): | |||
| beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, eps, lr, param, m, v, gradient, weight_decay = self.inputs | |||
| # compute result | |||
| beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m]) | |||
| @@ -65,12 +39,9 @@ def expand_fusedadamweightdecay(expand_info): | |||
| update_with_lr = graph_builder.emit('Mul', [lr, update]) | |||
| next_para = graph_builder.emit('Sub', [param, update_with_lr]) | |||
| para_result = graph_builder.emit('InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True}) | |||
| para_result = graph_builder.emit( | |||
| 'InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True}) | |||
| para_result = graph_builder.emit('InplaceAssign', [m, next_m, para_result], attrs={'fake_output': True}) | |||
| para_result = graph_builder.emit('InplaceAssign', [v, next_v, para_result], attrs={'fake_output': True}) | |||
| # set graph output. | |||
| graph_scope.set_output(para_result) | |||
| graph = graph_builder.get()[0] | |||
| return graph | |||
| return para_result | |||
| @@ -13,49 +13,36 @@ | |||
| # limitations under the License. | |||
| # =========================================================================== | |||
| """generate json desc for gelu""" | |||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||
| from ._utils import Expander | |||
| CSVALUE = 0.044715 | |||
| CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564 # np.sqrt(2/np.pi) | |||
| ONE = 1.0 | |||
| HALF = 0.5 | |||
| def expand_gelu(expand_info): | |||
| class GeLU(Expander): | |||
| """GeLU expander""" | |||
| # cal formula are: | |||
| # gelu(x) is 0.5 * x * (1.0 + tanh(y)) | |||
| # y is sqrt(2.0 / pi) * (x + 0.044715 * x * x * x) | |||
| CSVALUE = 0.044715 | |||
| CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564 # np.sqrt(2/np.pi) | |||
| # get op info. | |||
| input_desc = expand_info['input_desc'][0] | |||
| graph_builder = builder.GraphBuilder() | |||
| def _expand(self, graph_builder): | |||
| # cal formula are: | |||
| # gelu(x) is 0.5 * x * (1.0 + tanh(y)) | |||
| # y is sqrt(2.0 / pi) * (x + 0.044715 * x * x * x) | |||
| # generate a graph. | |||
| with graph_builder.graph_scope('main') as graph_scope: | |||
| # create tensor input. | |||
| input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format']) | |||
| graph_scope.set_input(input_x) | |||
| input_x = self.inputs[0] | |||
| # cal y | |||
| mul_0 = graph_builder.emit('Mul', [input_x, input_x]) | |||
| pow_0 = graph_builder.emit('Mul', [mul_0, input_x]) | |||
| const_csvalue = graph_builder.value(pow_0.dtype, CSVALUE) | |||
| const_csvalue = graph_builder.value(pow_0.dtype, self.CSVALUE) | |||
| mul_1 = graph_builder.emit('Mul', [pow_0, const_csvalue]) | |||
| tanh_res = graph_builder.emit('Add', [input_x, mul_1]) | |||
| const_csvalue_sqrt_two_div_pi = graph_builder.value(tanh_res.dtype, CSVALUE_SQRT_TWO_DIV_PI) | |||
| const_csvalue_sqrt_two_div_pi = graph_builder.value(tanh_res.dtype, self.CSVALUE_SQRT_TWO_DIV_PI) | |||
| y = graph_builder.emit('Mul', [tanh_res, const_csvalue_sqrt_two_div_pi]) | |||
| # cal gelu(x) | |||
| tanh_y = graph_builder.emit('Tanh', [y]) | |||
| const_one = graph_builder.value(tanh_y.dtype, ONE) | |||
| const_half = graph_builder.value(tanh_y.dtype, HALF) | |||
| const_one = graph_builder.value(tanh_y.dtype, 1) | |||
| const_half = graph_builder.value(tanh_y.dtype, 0.5) | |||
| tanh_y_add_one = graph_builder.emit('Add', [tanh_y, const_one]) | |||
| mul_x = graph_builder.emit('Mul', [input_x, tanh_y_add_one]) | |||
| result = graph_builder.emit('Mul', [const_half, mul_x]) | |||
| # set graph output. | |||
| graph_scope.set_output(result) | |||
| graph = graph_builder.get()[0] | |||
| return graph | |||
| return result | |||
| @@ -13,43 +13,31 @@ | |||
| # limitations under the License. | |||
| # =========================================================================== | |||
| """generate json desc for gelugrad""" | |||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||
| CSVALUE = 0.044715 | |||
| CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564 # np.sqrt(2/np.pi) | |||
| CSVALUE_TRI = 0.134141 # CSVALUE * 3 | |||
| ONE = 1.0 | |||
| HALF = 0.5 | |||
| def expand_gelugrad(expand_info): | |||
| @VLD.check_all_formats_same | |||
| class GeLUGrad(Expander): | |||
| """GeLUGrad expander""" | |||
| # cal formula are: | |||
| # gelu_grad(dy, x) is dy * y' | |||
| # y' is 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(para)) * mul_right | |||
| # tanh_para is sqrt(2.0 / pi) * (x + 0.044715 * x * x * x) | |||
| # mul_right is sqrt(2.0 / pi) * (1 + 3 * 0.044715 * x * x) | |||
| CSVALUE = 0.044715 | |||
| CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564 # np.sqrt(2/np.pi) | |||
| CSVALUE_TRI = 0.134141 # CSVALUE * 3 | |||
| # get op info. | |||
| input_desc_0 = expand_info['input_desc'][0] | |||
| input_desc_1 = expand_info['input_desc'][1] | |||
| input_desc_2 = expand_info['input_desc'][2] | |||
| graph_builder = builder.GraphBuilder() | |||
| def _expand(self, graph_builder): | |||
| # cal formula are: | |||
| # gelu_grad(dy, x) is dy * y' | |||
| # y' is 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(para)) * mul_right | |||
| # tanh_para is sqrt(2.0 / pi) * (x + 0.044715 * x * x * x) | |||
| # mul_right is sqrt(2.0 / pi) * (1 + 3 * 0.044715 * x * x) | |||
| # generate a graph. | |||
| with graph_builder.graph_scope('main') as graph_scope: | |||
| # create tensor input. | |||
| input_dy = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format']) | |||
| input_x = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format']) | |||
| input_y = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format']) | |||
| graph_scope.set_input(input_dy, input_x, input_y) | |||
| input_dy, input_x, _ = self.inputs | |||
| # create some const var | |||
| const_csvalue = graph_builder.value(input_dy.dtype, CSVALUE) | |||
| const_csvalue_sqrt_two_div_pi = graph_builder.value(input_dy.dtype, CSVALUE_SQRT_TWO_DIV_PI) | |||
| const_csvalue_tri = graph_builder.value(input_dy.dtype, CSVALUE_TRI) | |||
| const_one = graph_builder.value(input_dy.dtype, ONE) | |||
| const_half = graph_builder.value(input_dy.dtype, HALF) | |||
| const_csvalue = graph_builder.value(input_dy.dtype, self.CSVALUE) | |||
| const_csvalue_sqrt_two_div_pi = graph_builder.value(input_dy.dtype, self.CSVALUE_SQRT_TWO_DIV_PI) | |||
| const_csvalue_tri = graph_builder.value(input_dy.dtype, self.CSVALUE_TRI) | |||
| const_one = graph_builder.value(input_dy.dtype, 1) | |||
| const_half = graph_builder.value(input_dy.dtype, 0.5) | |||
| # cal mul_right | |||
| mul_double = graph_builder.emit('Mul', [input_x, input_x]) | |||
| @@ -79,8 +67,4 @@ def expand_gelugrad(expand_info): | |||
| result_tmp = graph_builder.emit('Add', [half_mul_tanh_res_add_one, mul_final]) | |||
| result = graph_builder.emit('Mul', [input_dy, result_tmp]) | |||
| # set graph output. | |||
| graph_scope.set_output(result) | |||
| graph = graph_builder.get()[0] | |||
| return graph | |||
| return result | |||
| @@ -12,35 +12,29 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # =========================================================================== | |||
| """generate json desc for GkDropOut""" | |||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||
| """generate json desc for GkDropout""" | |||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||
| def expand_gkdropout(expand_info): | |||
| """GkDropOut expander""" | |||
| # get op info. | |||
| input_desc = expand_info['input_desc'][0] | |||
| maks_desc = expand_info['input_desc'][1] | |||
| keep_prob = expand_info['attr']['keep_prob'] | |||
| @VLD.check_all_formats_same | |||
| @VLD.check_attrs('keep_prob') | |||
| class GkDropout(Expander): | |||
| """GkDropout expander""" | |||
| def _expand(self, graph_builder): | |||
| input_x, input_mask = self.inputs | |||
| keep_prob = self.attrs['keep_prob'] | |||
| graph_builder = builder.GraphBuilder() | |||
| with graph_builder.graph_scope('main') as graph_scope: | |||
| # create tensor input. | |||
| input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format']) | |||
| input_mask = graph_builder.tensor(maks_desc['shape'], maks_desc['data_type'], maks_desc['format']) | |||
| graph_scope.set_input(input_x, input_mask) | |||
| keep_prob_v = graph_builder.value(input_x.dtype, keep_prob) | |||
| r_keep_prob = graph_builder.value(input_x.dtype, 1.0 / keep_prob) | |||
| keep_prob = graph_builder.value(input_x.dtype, keep_prob) | |||
| if input_mask.dtype != input_x.dtype: | |||
| input_mask = graph_builder.emit('Cast', [input_mask], attrs={'dst_type': input_x.dtype}) | |||
| mask = graph_builder.emit('LessEqual', [input_mask, keep_prob_v]) # output is bool type | |||
| mask = graph_builder.emit('LessEqual', [input_mask, keep_prob]) # output is bool type | |||
| mask = graph_builder.emit('Cast', [mask], attrs={'dst_type': input_x.dtype}) | |||
| # compute result | |||
| result = graph_builder.emit('Mul', [r_keep_prob, input_x]) | |||
| result = graph_builder.emit('Mul', [result, mask]) | |||
| # set graph output. | |||
| graph_scope.set_output(result, mask) | |||
| graph = graph_builder.get()[0] | |||
| return graph | |||
| return result, mask | |||
| @@ -13,38 +13,31 @@ | |||
| # limitations under the License. | |||
| # =========================================================================== | |||
| """generate json desc for LayerNorm""" | |||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||
| from mindspore._extends.graph_kernel.model.model import DataFormat as DF | |||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||
| def expand_layernorm(expand_info): | |||
| @VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT) | |||
| @VLD.check_attrs('begin_norm_axis', 'begin_params_axis', 'epsilon') | |||
| class LayerNorm(Expander): | |||
| """LayerNorm expander""" | |||
| # get op info. | |||
| input_desc_0 = expand_info['input_desc'][0] | |||
| input_desc_1 = expand_info['input_desc'][1] | |||
| input_desc_2 = expand_info['input_desc'][2] | |||
| attrs = expand_info['attr'] | |||
| begin_norm_axis = attrs['begin_norm_axis'] | |||
| epsilon = attrs['epsilon'] | |||
| graph_builder = builder.GraphBuilder() | |||
| with graph_builder.graph_scope('main') as graph_scope: | |||
| # create tensor input. | |||
| input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format']) | |||
| input_gamma = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format']) | |||
| input_beta = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format']) | |||
| def _expand(self, graph_builder): | |||
| input_x, input_gamma, input_beta = self.inputs | |||
| begin_norm_axis = self.attrs['begin_norm_axis'] | |||
| epsilon = self.attrs['epsilon'] | |||
| # Calculate the scaling ratio of the average | |||
| shape_x = input_desc_0['shape'] | |||
| if begin_norm_axis < 0: | |||
| begin_norm_axis += len(shape_x) | |||
| begin_norm_axis += len(input_x.shape) | |||
| reduce_axis = () | |||
| for i, _ in enumerate(shape_x): | |||
| for i, _ in enumerate(input_x.shape): | |||
| if i > begin_norm_axis or i == begin_norm_axis: | |||
| reduce_axis = reduce_axis + (i,) | |||
| reduce_elts = 1.0 | |||
| for i in reduce_axis: | |||
| reduce_elts *= shape_x[i] | |||
| reduce_elts *= input_x.shape[i] | |||
| mean_cof = 1.0 / reduce_elts | |||
| mean_cof_v = graph_builder.value(input_x.dtype, mean_cof) | |||
| @@ -70,8 +63,4 @@ def expand_layernorm(expand_info): | |||
| scale_mul = graph_builder.emit('Mul', [input_gamma, normalize_mul]) | |||
| res = graph_builder.emit('Add', [scale_mul, input_beta]) | |||
| # set graph output. | |||
| graph_scope.set_output(res, mean, variance) | |||
| graph = graph_builder.get()[0] | |||
| return graph | |||
| return res, mean, variance | |||
| @@ -13,42 +13,30 @@ | |||
| # limitations under the License. | |||
| # =========================================================================== | |||
| """generate json desc for LayerNormGrad""" | |||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||
| from mindspore._extends.graph_kernel.model.model import DataFormat as DF | |||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||
| def expand_layernormgrad(expand_info): | |||
| @VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT) | |||
| @VLD.check_attrs('begin_norm_axis', 'begin_params_axis') | |||
| class LayerNormGrad(Expander): | |||
| """LayerNormGrad expander""" | |||
| # get op info. | |||
| x_desc = expand_info['input_desc'][0] | |||
| dy_desc = expand_info['input_desc'][1] | |||
| var_desc = expand_info['input_desc'][2] | |||
| mean_desc = expand_info['input_desc'][3] | |||
| gamma_desc = expand_info['input_desc'][4] | |||
| attrs = expand_info['attr'] | |||
| begin_norm_axis = attrs['begin_norm_axis'] | |||
| begin_params_axis = attrs['begin_params_axis'] | |||
| epsilon = attrs['epsilon'] if 'epsilon' in attrs else 1e-11 | |||
| shape_x = x_desc['shape'] | |||
| if begin_norm_axis < 0: | |||
| begin_norm_axis += len(shape_x) | |||
| if begin_params_axis < 0: | |||
| begin_params_axis += len(shape_x) | |||
| norm_axis = tuple(range(begin_norm_axis, len(shape_x))) | |||
| param_axis = tuple(range(0, begin_params_axis)) | |||
| reduce_size = 1.0 | |||
| for i in norm_axis: | |||
| reduce_size *= shape_x[i] | |||
| def _expand(self, graph_builder): | |||
| x, dy, variance, mean, gamma = self.inputs | |||
| begin_norm_axis = self.attrs['begin_norm_axis'] | |||
| begin_params_axis = self.attrs['begin_params_axis'] | |||
| epsilon = self.attrs['epsilon'] if 'epsilon' in self.attrs else 1e-11 | |||
| graph_builder = builder.GraphBuilder() | |||
| with graph_builder.graph_scope('main') as graph_scope: | |||
| # create input tensors. | |||
| x = graph_builder.tensor(x_desc['shape'], x_desc['data_type'], x_desc['format']) | |||
| dy = graph_builder.tensor(dy_desc['shape'], dy_desc['data_type'], dy_desc['format']) | |||
| variance = graph_builder.tensor(var_desc['shape'], var_desc['data_type'], var_desc['format']) | |||
| mean = graph_builder.tensor(mean_desc['shape'], mean_desc['data_type'], mean_desc['format']) | |||
| gamma = graph_builder.tensor(gamma_desc['shape'], gamma_desc['data_type'], gamma_desc['format']) | |||
| graph_scope.set_input(x, dy, variance, mean, gamma) | |||
| if begin_norm_axis < 0: | |||
| begin_norm_axis += len(x.shape) | |||
| if begin_params_axis < 0: | |||
| begin_params_axis += len(x.shape) | |||
| norm_axis = tuple(range(begin_norm_axis, len(x.shape))) | |||
| param_axis = tuple(range(0, begin_params_axis)) | |||
| reduce_size = 1.0 | |||
| for i in norm_axis: | |||
| reduce_size *= x.shape[i] | |||
| # set some constant val. | |||
| eps = graph_builder.value(x.dtype, epsilon) | |||
| @@ -99,8 +87,4 @@ def expand_layernormgrad(expand_info): | |||
| dx_tmp = graph_builder.emit('Add', [dx_1, dx_2]) | |||
| dx = graph_builder.emit('Add', [dx_tmp, dx_3]) | |||
| # set graph output. | |||
| graph_scope.set_output(dx, dg, db) | |||
| graph = graph_builder.get()[0] | |||
| return graph | |||
| return dx, dg, db | |||
| @@ -13,24 +13,21 @@ | |||
| # limitations under the License. | |||
| # =========================================================================== | |||
| """generate json desc for LogSoftmax""" | |||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||
| from mindspore._extends.graph_kernel.model.model import DataFormat as DF | |||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||
| def expand_logsoftmax(expand_info): | |||
| @VLD.add_format(DF.DEFAULT, DF.DEFAULT) | |||
| @VLD.check_attrs('axis') | |||
| class LogSoftmax(Expander): | |||
| """LogSoftmax expander""" | |||
| # get op info. | |||
| input_desc = expand_info['input_desc'][0] | |||
| axis = expand_info['attr']['axis'] | |||
| graph_builder = builder.GraphBuilder() | |||
| if isinstance(axis, int): | |||
| axis = (axis,) | |||
| # generate a graph. | |||
| with graph_builder.graph_scope('main') as graph_scope: | |||
| # create tensor input. | |||
| input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format']) | |||
| graph_scope.set_input(input_x) | |||
| # cal logsoftmax. | |||
| def _expand(self, graph_builder): | |||
| input_x = self.inputs[0] | |||
| axis = self.attrs['axis'] | |||
| if isinstance(axis, int): | |||
| axis = (axis,) | |||
| max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True}) | |||
| data_sub = graph_builder.emit('Sub', [input_x, max_x]) | |||
| data_exp = graph_builder.emit('Exp', [data_sub]) | |||
| @@ -38,8 +35,4 @@ def expand_logsoftmax(expand_info): | |||
| log_expsum = graph_builder.emit('Log', [data_expsum]) | |||
| result = graph_builder.emit('Sub', [data_sub, log_expsum]) | |||
| # set graph output. | |||
| graph_scope.set_output(result) | |||
| graph = graph_builder.get()[0] | |||
| return graph | |||
| return result | |||
| @@ -13,34 +13,24 @@ | |||
| # limitations under the License. | |||
| # =========================================================================== | |||
| """generate json desc for LogSoftmaxGrad""" | |||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||
| from mindspore._extends.graph_kernel.model.model import DataFormat as DF | |||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||
| def expand_logsoftmaxgrad(expand_info): | |||
| @VLD.add_format(DF.DEFAULT, DF.DEFAULT) | |||
| @VLD.check_attrs('axis') | |||
| class LogSoftmaxGrad(Expander): | |||
| """LogSoftmaxGrad expander""" | |||
| # get op info. | |||
| input_desc_0 = expand_info['input_desc'][0] | |||
| input_desc_1 = expand_info['input_desc'][1] | |||
| axis = expand_info['attr']['axis'] | |||
| graph_builder = builder.GraphBuilder() | |||
| if isinstance(axis, int): | |||
| axis = (axis,) | |||
| # generate a graph. | |||
| with graph_builder.graph_scope('main') as graph_scope: | |||
| # create tensor input. | |||
| input_logits = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format']) | |||
| input_dy = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format']) | |||
| graph_scope.set_input(input_logits, input_dy) | |||
| def _expand(self, graph_builder): | |||
| input_logits, input_dy = self.inputs | |||
| axis = self.attrs['axis'] | |||
| if isinstance(axis, int): | |||
| axis = (axis,) | |||
| # cal logsoftmaxgrad. | |||
| softmax = graph_builder.emit('Exp', [input_logits]) | |||
| dy_sum = graph_builder.emit('ReduceSum', [input_dy], attrs={'reduce_axis': axis, 'keep_dims': True}) | |||
| mul_result = graph_builder.emit('Mul', [softmax, dy_sum]) | |||
| result = graph_builder.emit('Sub', [input_dy, mul_result]) | |||
| # set graph output. | |||
| graph_scope.set_output(result) | |||
| graph = graph_builder.get()[0] | |||
| return graph | |||
| return result | |||
| @@ -13,40 +13,26 @@ | |||
| # limitations under the License. | |||
| # =========================================================================== | |||
| """generate json desc for maximum_grad""" | |||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||
| from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException | |||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||
| def expand_maximumgrad(expand_info): | |||
| @VLD.check_all_formats_same | |||
| class MaximumGrad(Expander): | |||
| """MaximumGrad expander""" | |||
| # get op info. | |||
| input_desc_0 = expand_info['input_desc'][0] | |||
| input_desc_1 = expand_info['input_desc'][1] | |||
| input_desc_2 = expand_info['input_desc'][2] | |||
| attrs = expand_info['attr'] | |||
| grad_x = attrs['grad_x'] if 'grad_x' in attrs else True | |||
| grad_y = attrs['grad_y'] if 'grad_y' in attrs else True | |||
| graph_builder = builder.GraphBuilder() | |||
| with graph_builder.graph_scope('main') as graph_scope: | |||
| # create tensor input. | |||
| input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format']) | |||
| input_y = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format']) | |||
| input_dout = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format']) | |||
| graph_scope.set_input(input_x, input_y, input_dout) | |||
| x_dtype = input_x.dtype | |||
| # cal result | |||
| def _check(self): | |||
| if not self.attrs.get('grad_x', True) and not self.attrs.get('grad_y', True): | |||
| raise GKException("both grad_x and grad_y are False.") | |||
| return super()._check() | |||
| def _expand(self, graph_builder): | |||
| input_x, input_y, input_dout = self.inputs | |||
| ge_result = graph_builder.emit('GreaterEqual', [input_x, input_y]) | |||
| ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': x_dtype}) | |||
| ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': input_x.dtype}) | |||
| dx = graph_builder.emit('Mul', [ge_result, input_dout]) | |||
| dy = graph_builder.emit('Sub', [input_dout, dx]) | |||
| # set graph output according to grad_x and grad_y | |||
| if grad_x and grad_y: | |||
| graph_scope.set_output(dx, dy) | |||
| if grad_x and not grad_y: | |||
| graph_scope.set_output(dx) | |||
| if grad_y and not grad_x: | |||
| graph_scope.set_output(dy) | |||
| graph = graph_builder.get()[0] | |||
| return graph | |||
| # output two results, regardless of grad_x and grad_y | |||
| return dx, dy | |||
| @@ -13,41 +13,26 @@ | |||
| # limitations under the License. | |||
| # =========================================================================== | |||
| """generate json desc for minimum_grad""" | |||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||
| from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException | |||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||
| def expand_minimumgrad(expand_info): | |||
| @VLD.check_all_formats_same | |||
| class MinimumGrad(Expander): | |||
| """MinimumGrad expander""" | |||
| # get op info. | |||
| input_desc_0 = expand_info['input_desc'][0] | |||
| input_desc_1 = expand_info['input_desc'][1] | |||
| input_desc_2 = expand_info['input_desc'][2] | |||
| attrs = expand_info['attr'] | |||
| grad_x = attrs['grad_x'] if 'grad_x' in attrs else True | |||
| grad_y = attrs['grad_y'] if 'grad_y' in attrs else True | |||
| graph_builder = builder.GraphBuilder() | |||
| with graph_builder.graph_scope('main') as graph_scope: | |||
| # create tensor input. | |||
| input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format']) | |||
| input_y = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format']) | |||
| input_dout = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format']) | |||
| graph_scope.set_input(input_x, input_y, input_dout) | |||
| x_dtype = input_x.dtype | |||
| def _check(self): | |||
| if not self.attrs.get('grad_x', True) and not self.attrs.get('grad_y', True): | |||
| raise GKException("both grad_x and grad_y are False.") | |||
| return super()._check() | |||
| def _expand(self, graph_builder): | |||
| input_x, input_y, input_dout = self.inputs | |||
| # cal result | |||
| le_result = graph_builder.emit('LessEqual', [input_x, input_y]) | |||
| le_result = graph_builder.emit('Cast', [le_result], attrs={'dst_type': x_dtype}) | |||
| le_result = graph_builder.emit('Cast', [le_result], attrs={'dst_type': input_x.dtype}) | |||
| dx = graph_builder.emit('Mul', [le_result, input_dout]) | |||
| dy = graph_builder.emit('Sub', [input_dout, dx]) | |||
| # set graph output according to grad_x and grad_y | |||
| if grad_x and grad_y: | |||
| graph_scope.set_output(dx, dy) | |||
| if grad_x and not grad_y: | |||
| graph_scope.set_output(dx) | |||
| if grad_y and not grad_x: | |||
| graph_scope.set_output(dy) | |||
| graph = graph_builder.get()[0] | |||
| return graph | |||
| # output two results, regardless of grad_x and grad_y | |||
| return dx, dy | |||
| @@ -13,45 +13,30 @@ | |||
| # limitations under the License. | |||
| # =========================================================================== | |||
| """generate json desc for reduce_mean""" | |||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||
| from mindspore._extends.graph_kernel.model.model import DataFormat as DF | |||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||
| def expand_reducemean(expand_info): | |||
| @VLD.add_format(DF.DEFAULT) | |||
| @VLD.check_attrs('axis', 'keep_dims') | |||
| class ReduceMean(Expander): | |||
| """ReduceMean expander""" | |||
| # get op info. | |||
| input_desc = expand_info['input_desc'][0] | |||
| attrs = expand_info['attr'] | |||
| axis = attrs['axis'] | |||
| keep_dims = attrs['keep_dims'] | |||
| graph_builder = builder.GraphBuilder() | |||
| with graph_builder.graph_scope('main') as graph_scope: | |||
| # create tensor input. | |||
| input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format']) | |||
| x_shape = input_x.shape | |||
| graph_scope.set_input(input_x) | |||
| def _expand(self, graph_builder): | |||
| x = self.inputs[0] | |||
| axis = self.attrs['axis'] | |||
| keep_dims = self.attrs['keep_dims'] | |||
| # cal reduce_mean, when axis = None, reduce axis are all | |||
| all_shape = 1.0 | |||
| real_axis = [] | |||
| # cal reduce_mean, when axis is None, reduce all axes. | |||
| if not axis: | |||
| for i, shape in enumerate(x_shape): | |||
| real_axis.append(i) | |||
| all_shape *= shape | |||
| else: | |||
| for idx in axis: | |||
| all_shape *= x_shape[idx] | |||
| axis = list(range(len(x.shape))) | |||
| reduce_size = 1.0 | |||
| for idx in axis: | |||
| reduce_size *= x.shape[idx] | |||
| all_shape_value = graph_builder.value(input_x.dtype, all_shape) | |||
| reduce_size_value = graph_builder.value(x.dtype, reduce_size) | |||
| if not axis: | |||
| sum_x = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': real_axis, 'keep_dims': keep_dims}) | |||
| else: | |||
| sum_x = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': axis, 'keep_dims': keep_dims}) | |||
| result = graph_builder.emit('RealDiv', [sum_x, all_shape_value]) | |||
| # set graph output. | |||
| graph_scope.set_output(result) | |||
| sum_x = graph_builder.emit('ReduceSum', [x], attrs={'reduce_axis': axis, 'keep_dims': keep_dims}) | |||
| result = graph_builder.emit('RealDiv', [sum_x, reduce_size_value]) | |||
| graph = graph_builder.get()[0] | |||
| return graph | |||
| return result | |||
| @@ -13,26 +13,23 @@ | |||
| # limitations under the License. | |||
| # =========================================================================== | |||
| """generate json desc for softmax""" | |||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||
| from mindspore._extends.graph_kernel.model.model import DataFormat as DF | |||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||
| def expand_softmax(expand_info): | |||
| @VLD.add_format(DF.DEFAULT) | |||
| @VLD.check_attrs('axis') | |||
| class Softmax(Expander): | |||
| """Softmax expander""" | |||
| input_desc = expand_info['input_desc'][0] | |||
| axis = expand_info['attr']['axis'] | |||
| graph_builder = builder.GraphBuilder() | |||
| with graph_builder.graph_scope('main') as graph_scope: | |||
| # create tensor input. | |||
| input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format']) | |||
| # cal softmax. | |||
| def _expand(self, graph_builder): | |||
| input_x = self.inputs[0] | |||
| axis = self.attrs['axis'] | |||
| max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True}) | |||
| data_sub = graph_builder.emit('Sub', [input_x, max_x]) | |||
| data_exp = graph_builder.emit('Exp', [data_sub]) | |||
| data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True}) | |||
| result = graph_builder.emit('RealDiv', [data_exp, data_expsum]) | |||
| # set graph output. | |||
| graph_scope.set_output(result) | |||
| graph = graph_builder.get()[0] | |||
| return graph | |||
| return result | |||
| @@ -13,33 +13,17 @@ | |||
| # limitations under the License. | |||
| # =========================================================================== | |||
| """generate json desc for sqrtgrad""" | |||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||
| def expand_sqrtgrad(expand_info): | |||
| @VLD.check_all_formats_same | |||
| class SqrtGrad(Expander): | |||
| """SqrtGrad expander""" | |||
| # cal formula are: | |||
| # sqrt_grad(x, dout) is dout / (2 * x) | |||
| # get op info. | |||
| input_desc_0 = expand_info['input_desc'][0] | |||
| input_desc_1 = expand_info['input_desc'][1] | |||
| graph_builder = builder.GraphBuilder() | |||
| # generate a graph. | |||
| with graph_builder.graph_scope('main') as graph_scope: | |||
| # create tensor input. | |||
| input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format']) | |||
| input_dout = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format']) | |||
| graph_scope.set_input(input_x, input_dout) | |||
| # cal result | |||
| const_two = graph_builder.value(input_x.dtype, 2) | |||
| dividend = graph_builder.emit('Mul', [input_x, const_two]) | |||
| result = graph_builder.emit('RealDiv', [input_dout, dividend]) | |||
| # set graph output. | |||
| graph_scope.set_output(result) | |||
| graph = graph_builder.get()[0] | |||
| return graph | |||
| def _expand(self, graph_builder): | |||
| # sqrt_grad(x, dout) = dout / (2 * x) | |||
| x, dout = self.inputs | |||
| const_two = graph_builder.value(x.dtype, 2) | |||
| dividend = graph_builder.emit('Mul', [x, const_two]) | |||
| result = graph_builder.emit('RealDiv', [dout, dividend]) | |||
| return result | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,24 +13,13 @@ | |||
| # limitations under the License. | |||
| # =========================================================================== | |||
| """generate json desc for square""" | |||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||
| from ._utils import Expander | |||
| def expand_square(expand_info): | |||
| class Square(Expander): | |||
| """Square expander""" | |||
| # get op info. | |||
| input_desc = expand_info['input_desc'][0] | |||
| graph_builder = builder.GraphBuilder() | |||
| # generate a graph. | |||
| with graph_builder.graph_scope('main') as graph_scope: | |||
| # create tensor input. | |||
| input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format']) | |||
| # create op. | |||
| result = graph_builder.emit('Mul', [input_x, input_x]) | |||
| # set graph output. | |||
| graph_scope.set_output(result) | |||
| graph = graph_builder.get()[0] | |||
| return graph | |||
| def _expand(self, graph_builder): | |||
| x = self.inputs[0] | |||
| result = graph_builder.emit('Mul', [x, x]) | |||
| return result | |||
| @@ -13,34 +13,19 @@ | |||
| # limitations under the License. | |||
| # =========================================================================== | |||
| """generate json desc for tanh_grad""" | |||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||
| ONE = 1.0 | |||
| def expand_tanhgrad(expand_info): | |||
| @VLD.check_all_formats_same | |||
| class TanhGrad(Expander): | |||
| """TanhGrad expander""" | |||
| # get op info. | |||
| input_desc_0 = expand_info['input_desc'][0] | |||
| input_desc_1 = expand_info['input_desc'][1] | |||
| graph_builder = builder.GraphBuilder() | |||
| # generate a graph. | |||
| with graph_builder.graph_scope('main') as graph_scope: | |||
| # create tensor input. | |||
| input_y = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format']) | |||
| input_dy = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format']) | |||
| const_one = graph_builder.value(input_y.dtype, ONE) | |||
| graph_scope.set_input(input_y, input_dy) | |||
| def _expand(self, graph_builder): | |||
| input_y, input_dy = self.inputs | |||
| # cal result | |||
| const_one = graph_builder.value(input_y.dtype, 1) | |||
| double_y = graph_builder.emit('Mul', [input_y, input_y]) | |||
| one_sub_double_y = graph_builder.emit('Sub', [const_one, double_y]) | |||
| result = graph_builder.emit('Mul', [input_dy, one_sub_double_y]) | |||
| # set graph output. | |||
| graph_scope.set_output(result) | |||
| graph = graph_builder.get()[0] | |||
| return graph | |||
| return result | |||
| @@ -14,25 +14,22 @@ | |||
| # =========================================================================== | |||
| """generate json desc for Tile""" | |||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||
| from mindspore._extends.graph_kernel.model.model import DataFormat as DF | |||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||
| def expand_tile(expand_info): | |||
| @VLD.add_format(DF.DEFAULT) | |||
| @VLD.check_attrs('multiples') | |||
| class Tile(Expander): | |||
| """Tile expander""" | |||
| input_desc = expand_info['input_desc'][0] | |||
| multiples = expand_info['attr']['multiples'] | |||
| output_shape, _, _, shape_compatible = builder.get_tile_output_shape(input_desc['shape'], multiples) | |||
| graph_builder = builder.GraphBuilder() | |||
| with graph_builder.graph_scope('main') as graph_scope: | |||
| # create tensor input. | |||
| input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format']) | |||
| # create op. | |||
| def _expand(self, graph_builder): | |||
| input_x = self.inputs[0] | |||
| multiples = self.attrs['multiples'] | |||
| output_shape, _, _, shape_compatible = builder.get_tile_output_shape(self.inputs[0].shape, multiples) | |||
| if shape_compatible: | |||
| result = graph_builder.emit('BroadcastTo', [input_x], attrs={'shape': output_shape}) | |||
| else: | |||
| result = graph_builder.emit('Tile', [input_x], attrs={'multiples': multiples}) | |||
| # set graph output. | |||
| graph_scope.set_output(result) | |||
| graph = graph_builder.get()[0] | |||
| return graph | |||
| return result | |||
| @@ -55,6 +55,25 @@ class DataFormat: | |||
| NDHWC = "NDHWC" | |||
| class DataType: | |||
| """Data Type""" | |||
| FLOAT = "float" | |||
| FLOAT16 = "float16" | |||
| FLOAT32 = "float32" | |||
| FLOAT64 = "float64" | |||
| INT = "int" | |||
| INT8 = "int8" | |||
| INT16 = "int16" | |||
| INT32 = "int32" | |||
| INT64 = "int64" | |||
| UINT = "uint" | |||
| UINT8 = "uint8" | |||
| UINT16 = "uint16" | |||
| UINT32 = "uint32" | |||
| UINT64 = "uint64" | |||
| BOOL = "bool" | |||
| class Config: | |||
| R0 = 8.0 | |||
| UB_SIZE = 256 * 1024 | |||
| @@ -508,3 +527,9 @@ class AddControlBuddy(GraphVisitor): | |||
| for owner in self.buddies: | |||
| for op in self.buddies[owner]: | |||
| owner.add_buddy(op.output) | |||
| class GraphKernelUnsupportedException(Exception): | |||
| def __init__(self, message): | |||
| super().__init__() | |||
| self.message = message | |||
| @@ -90,7 +90,7 @@ FuncGraphPtr GraphKernelExpander::CreateExpandFuncGraph(const CNodePtr &node) { | |||
| } | |||
| std::string kernel_desc_str = py::cast<std::string>(ret); | |||
| if (kernel_desc_str.empty()) { | |||
| MS_LOG(ERROR) << "Jump expand node: " << node->fullname_with_scope(); | |||
| MS_LOG(INFO) << "Jump expand node: " << node->fullname_with_scope(); | |||
| return nullptr; | |||
| } | |||
| // decode json to func_graph. | |||
| @@ -131,11 +131,8 @@ AnfNodePtr GraphKernelExpander::CreateExpandGraphKernel(const FuncGraphPtr &func | |||
| kernel::GetFuncGraphOutputNodes(new_func_graph, &outputs); | |||
| auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs, outputs); | |||
| SetNewKernelInfo(graph_kernel_node, new_func_graph, inputs, outputs, AnfAlgo::GetProcessor(node)); | |||
| std::string graph_kernel_flag; | |||
| std::for_each(kernel_nodes.begin(), kernel_nodes.end(), [&graph_kernel_flag](const AnfNodePtr &node) { | |||
| static_cast<void>(graph_kernel_flag.append(AnfAlgo::GetCNodeName(node)).append("_")); | |||
| }); | |||
| MS_LOG(DEBUG) << "Expand node: " << node->fullname_with_scope() << " with: " << graph_kernel_flag; | |||
| MS_LOG(DEBUG) << "Expand node: " << node->fullname_with_scope() | |||
| << " with: " << graph_kernel_node->fullname_with_scope(); | |||
| return graph_kernel_node; | |||
| } | |||
| @@ -152,14 +149,12 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) { | |||
| continue; | |||
| } | |||
| MS_LOG(INFO) << "Expand process node: " << node->fullname_with_scope(); | |||
| MS_LOG(INFO) << "Expanding node: " << node->fullname_with_scope(); | |||
| auto new_func_graph = CreateExpandFuncGraph(node); | |||
| if (new_func_graph == nullptr) { | |||
| MS_LOG(ERROR) << "Decode fused nodes failed, " << node->fullname_with_scope(); | |||
| continue; | |||
| } | |||
| mng->AddFuncGraph(new_func_graph); | |||
| MS_LOG(DEBUG) << "decode fused nodes success."; | |||
| auto graph_kernel_node = CreateExpandGraphKernel(func_graph, new_func_graph, node); | |||
| new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(AnfAlgo::GetCNodeName(node))); | |||