From: @dayschan
Reviewed-by:
Signed-off-by:
tags/v1.2.0-rc1

graph_kernel: refactor the op expanders from `expand_*` functions into `Expander` subclasses, and move the common build/validation logic into a new `expanders/_utils.py`.
--- a/mindspore/_extends/graph_kernel/expander.py
+++ b/mindspore/_extends/graph_kernel/expander.py
@@ -18,6 +18,16 @@ import json.decoder as jd
 import traceback
 from mindspore import log as logger
 import mindspore._extends.graph_kernel.expanders as expanders
+from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException
+
+
+def create_expander(expand_info):
+    """Create an expander according to op name"""
+    op_name = str(expand_info['name'])
+    if not hasattr(expanders, op_name):
+        raise GraphKernelUnsupportedException("Generator does not support op: {}".format(op_name))
+    expander = getattr(expanders, op_name)
+    return expander(expand_info)
 
 
 def extract_expand_info(kernel_info):
@@ -46,20 +56,8 @@ def get_op_expander(json_str: str):
         kernel_info = json.loads(json_str)
         expand_info = extract_expand_info(kernel_info)
-        processor = expand_info['process']
-        op_name = str(expand_info['name']).lower()
-        expand_op_func_name = 'expand_' + op_name
-        if not hasattr(expanders, expand_op_func_name):
-            logger.error("Generator do not support op: {}".format(op_name))
-            return None
-        expand_op_func = getattr(expanders, expand_op_func_name)
-        # generate graph desc.
-        graph = expand_op_func(expand_info)
-        if graph is None:
-            logger.error("Failed to generate graph of: {}".format(op_name))
-            return None
-        graph.set_processor(processor)
+        expander = create_expander(expand_info)
+        graph = expander.run()
 
         # dump graph to json desc.
         desc = graph.dump()
@@ -69,3 +67,6 @@ def get_op_expander(json_str: str):
         logger.error("Failed to generate graph kernel op")
         logger.error(traceback.format_exc())
         return None
+    except GraphKernelUnsupportedException as e:
+        logger.info(e.message)
+        return ""
--- a/mindspore/_extends/graph_kernel/expanders/__init__.py
+++ b/mindspore/_extends/graph_kernel/expanders/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,24 +14,24 @@
 # ============================================================================
 """expanders init"""
 
-from .gelu import expand_gelu
-from .gelu_grad import expand_gelugrad
-from .layernorm import expand_layernorm
-from .softmax import expand_softmax
-from .square import expand_square
-from .bias_add import expand_biasadd
-from .bias_add_grad import expand_biasaddgrad
-from .fused_adam import expand_fusedadam
-from .fused_adam_weight_decay import expand_fusedadamweightdecay
-from .reduce_mean import expand_reducemean
-from .tanh_grad import expand_tanhgrad
-from .maximum_grad import expand_maximumgrad
-from .minimum_grad import expand_minimumgrad
-from .dropout_grad import expand_dropoutgrad
-from .layernorm_grad import expand_layernormgrad
-from .logsoftmax import expand_logsoftmax
-from .logsoftmax_grad import expand_logsoftmaxgrad
-from .gkdropout import expand_gkdropout
-from .tile import expand_tile
-from .sqrt_grad import expand_sqrtgrad
-from .clip_by_norm_no_div_sum import expand_clipbynormnodivsum
+from .bias_add import BiasAdd
+from .bias_add_grad import BiasAddGrad
+from .clip_by_norm_no_div_sum import ClipByNormNoDivSum
+from .dropout_grad import DropoutGrad
+from .fused_adam import FusedAdam
+from .fused_adam_weight_decay import FusedAdamWeightDecay
+from .gelu import GeLU
+from .gelu_grad import GeLUGrad
+from .gkdropout import GkDropout
+from .layernorm import LayerNorm
+from .layernorm_grad import LayerNormGrad
+from .logsoftmax import LogSoftmax
+from .logsoftmax_grad import LogSoftmaxGrad
+from .maximum_grad import MaximumGrad
+from .minimum_grad import MinimumGrad
+from .reduce_mean import ReduceMean
+from .softmax import Softmax
+from .sqrt_grad import SqrtGrad
+from .square import Square
+from .tanh_grad import TanhGrad
+from .tile import Tile
--- /dev/null
+++ b/mindspore/_extends/graph_kernel/expanders/_utils.py
@@ -0,0 +1,146 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""GraphKernel expander utils"""
+from abc import ABCMeta, abstractmethod
+from mindspore._extends.graph_kernel.model import model_builder as builder
+from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
+
+
+class Expander:
+    """
+    Expander is the base class of expanders.
+    The method `_expand` should be overridden to implement the operator detail.
+    """
+    __metaclass__ = ABCMeta
+
+    def __init__(self, expand_info):
+        self.name = expand_info["name"]
+        self.inputs = expand_info["input_desc"]
+        self.outputs = expand_info["output_desc"]
+        self.attrs = expand_info["attr"]
+        self.processor = expand_info["process"]
+
+    def run(self):
+        """
+        Expand the operator to a graph.
+        `GraphKernelUnsupportedException` would be raised if the check fails.
+        """
+        self._check()
+        graph_builder = builder.GraphBuilder()
+        with graph_builder.graph_scope(self.name) as graph_scope:
+            # transform input_desc to Tensor
+            self.inputs = [graph_builder.tensor(inp['shape'], inp['data_type'], inp['format']) for inp in self.inputs]
+            graph_scope.set_input(*self.inputs)
+            outputs = self._expand(graph_builder)
+            if isinstance(outputs, (list, tuple)):
+                graph_scope.set_output(*outputs)
+            else:
+                graph_scope.set_output(outputs)
+        graph = graph_builder.get()[0]
+        graph.set_processor(self.processor)
+        return graph
+
+    def _check(self):
+        """Check inputs"""
+
+    @abstractmethod
+    def _expand(self, graph_builder):
+        """Expand the operator; this function should be overridden in the subclass"""
+        raise Exception("_expand() is not implemented in {}".format(self.__class__.__name__))
+
+
+class ExpanderInfoValidator:
+    """ExpanderInfoValidator is the utility class which defines the validator decorators for expanders"""
+
+    # pylint: disable=W0211
+    @staticmethod
+    def _add_check_function(cls, func):
+        """
+        Rewrite the function `_check` in class Expander
+        to append the new `func` after the original checks.
+        """
+        old_check = getattr(cls, "_check")
+
+        def new_check(obj):
+            old_check(obj)
+            func(obj)
+
+        setattr(cls, "_check", new_check)
+
+    @staticmethod
+    def add_format(*input_format):
+        """
+        Add a new supported format for the operator.
+        This function adds a list `__supported_formats` into the expander,
+        saving the whitelist of formats that this op supports.
+        It also rewrites the `_check` function to check the formats.
+        """
+        format_list_name = "__supported_formats"
+
+        def _check_format(obj):
+            inp_formats = [inp['format'] for inp in obj.inputs]
+            for formats in getattr(obj, format_list_name):
+                if len(formats) != len(inp_formats):
+                    raise GKException("length of registered format doesn't match the inputs of {}".format(obj.name))
+                if all([fmt == inp for fmt, inp in zip(formats, inp_formats)]):
+                    return
+            raise GKException("Unregistered format ({}) for op {}".format(','.join(inp_formats), obj.name))
+
+        def wrapper(cls):
+            if not issubclass(cls, Expander):
+                raise Exception("{} should be subclass of Expander.".format(cls.__name__))
+            if not hasattr(cls, format_list_name):
+                setattr(cls, format_list_name, list())
+                ExpanderInfoValidator._add_check_function(cls, _check_format)
+            getattr(cls, format_list_name).append(input_format)
+            return cls
+
+        return wrapper
+
+    @staticmethod
+    def check_all_formats_same(cls):
+        """Check that all formats are the same"""
+        def _check_format(obj):
+            inp_formats = [inp['format'] for inp in obj.inputs]
+            if all([fmt == inp_formats[0] for fmt in inp_formats[1:]]):
+                return
+            raise GKException("[check_all_formats_same] unmatched formats ({}) for op {}".format(
+                ','.join(inp_formats), obj.name))
+
+        def wrapper(*args, **kargs):
+            if not issubclass(cls, Expander):
+                raise Exception("{} should be subclass of Expander.".format(cls.__name__))
+            ExpanderInfoValidator._add_check_function(cls, _check_format)
+            return cls(*args, **kargs)
+
+        return wrapper
+
+    @staticmethod
+    def check_attrs(*args):
+        """Check the attrs exist"""
+        def _check_attr(obj):
+            for a in args:
+                if a not in obj.attrs:
+                    raise GKException("attr '{}' does not exist.".format(a))
+
+        def wrapper(cls):
+            if not issubclass(cls, Expander):
+                raise Exception("{} should be subclass of Expander.".format(cls.__name__))
+            ExpanderInfoValidator._add_check_function(cls, _check_attr)
+            return cls
+
+        return wrapper
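The decorators compose: each `add_format` call whitelists one tuple of input formats, and `check_attrs`/`check_all_formats_same` chain extra validators onto `_check`, so `run()` fails fast with `GraphKernelUnsupportedException` before any graph is built. A sketch of how a new expander plugs in (`HypotheticalOp` and its 'alpha' attribute are made up for illustration; the real registrations follow in the files below):

    from mindspore._extends.graph_kernel.model.model import DataFormat as DF
    from ._utils import Expander, ExpanderInfoValidator as VLD


    @VLD.add_format(DF.DEFAULT, DF.DEFAULT)   # one whitelisted format tuple per decorator
    @VLD.check_attrs('alpha')                 # reject the kernel if 'alpha' is missing
    class HypotheticalOp(Expander):
        """HypotheticalOp expander (illustrative only)"""

        def _expand(self, graph_builder):
            input_x, input_y = self.inputs    # run() has already turned the descs into Tensors
            alpha = graph_builder.value(input_x.dtype, self.attrs['alpha'])
            scaled = graph_builder.emit('Mul', [input_x, alpha])
            return graph_builder.emit('Add', [scaled, input_y])

The class must also be exported from expanders/__init__.py under the exact op name, since create_expander resolves it with getattr(expanders, op_name).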
--- a/mindspore/_extends/graph_kernel/expanders/bias_add.py
+++ b/mindspore/_extends/graph_kernel/expanders/bias_add.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,50 +13,34 @@
 # limitations under the License.
 # ===========================================================================
 """generate json desc for bias_add"""
-from mindspore._extends.graph_kernel.model import model_builder as builder
+from mindspore._extends.graph_kernel.model.model import DataFormat as DF
+from ._utils import Expander, ExpanderInfoValidator as VLD
 
 
-def expand_biasadd(expand_info):
+@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
+@VLD.add_format(DF.NCHW, DF.DEFAULT)
+@VLD.add_format(DF.NHWC, DF.DEFAULT)
+class BiasAdd(Expander):
     """BiasAdd expander"""
-    # get op info.
-    input_desc_0 = expand_info['input_desc'][0]
-    input_desc_1 = expand_info['input_desc'][1]
-    graph_builder = builder.GraphBuilder()
 
-    # generate a graph.
-    with graph_builder.graph_scope('main') as graph_scope:
-        # create tensor input.
-        input_x = graph_builder.tensor(
-            input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
-        input_y = graph_builder.tensor(
-            input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
-        graph_scope.set_input(input_x, input_y)
-        if input_x.data_format == "NCHW":
-            input_y_expand = graph_builder.emit(
-                'ExpandDims', [input_y], attrs={'axis': 1})
-            input_y_expand = graph_builder.emit(
-                'ExpandDims', [input_y_expand], attrs={'axis': 2})
+    def _expand(self, graph_builder):
+        input_x, input_y = self.inputs
+        if input_x.data_format == DF.NCHW:
+            input_y_expand = graph_builder.emit('ExpandDims', [input_y], attrs={'axis': 1})
+            input_y_expand = graph_builder.emit('ExpandDims', [input_y_expand], attrs={'axis': 2})
             result = graph_builder.emit('Add', [input_x, input_y_expand])
-        elif input_x.data_format == "DefaultFormat":
+        elif input_x.data_format == DF.DEFAULT:
             if len(input_x.shape) == 2:
                 result = graph_builder.emit('Add', [input_x, input_y])
             elif len(input_x.shape) == 3:
-                input_y_expand = graph_builder.emit(
-                    'ExpandDims', [input_y], attrs={'axis': 1})
-                result = graph_builder.emit(
-                    'Add', [input_x, input_y_expand])
-            else:
-                input_y_expand = graph_builder.emit(
-                    'ExpandDims', [input_y], attrs={'axis': 1})
-                input_y_expand = graph_builder.emit(
-                    'ExpandDims', [input_y_expand], attrs={'axis': 2})
-                result = graph_builder.emit(
-                    'Add', [input_x, input_y_expand])
-        else:
+                input_y_expand = graph_builder.emit('ExpandDims', [input_y], attrs={'axis': 1})
+                result = graph_builder.emit('Add', [input_x, input_y_expand])
+            else:  # len == 4
+                input_y_expand = graph_builder.emit('ExpandDims', [input_y], attrs={'axis': 1})
+                input_y_expand = graph_builder.emit('ExpandDims', [input_y_expand], attrs={'axis': 2})
+                result = graph_builder.emit('Add', [input_x, input_y_expand])
+        else:  # NHWC
            result = graph_builder.emit('Add', [input_x, input_y])
 
-        # set graph output.
-        graph_scope.set_output(result)
-
-    graph = graph_builder.get()[0]
-    return graph
+        return result
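The chained ExpandDims emits simply reshape the 1-D bias so that Add broadcasts it along the channel axis. In NumPy terms (a reference sketch, not the builder API):

    import numpy as np

    x = np.zeros((2, 3, 4, 5), dtype=np.float32)   # NCHW input
    b = np.arange(3, dtype=np.float32)             # per-channel bias
    # ExpandDims(axis=1) twice: (3,) -> (3, 1) -> (3, 1, 1); Add then broadcasts over H and W
    out = x + b.reshape(3, 1, 1)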
--- a/mindspore/_extends/graph_kernel/expanders/bias_add_grad.py
+++ b/mindspore/_extends/graph_kernel/expanders/bias_add_grad.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,26 +13,25 @@
 # limitations under the License.
 # ===========================================================================
 """generate json desc for bias_add"""
-from mindspore._extends.graph_kernel.model import model_builder as builder
+from mindspore._extends.graph_kernel.model.model import DataFormat as DF
+from ._utils import Expander, ExpanderInfoValidator as VLD
 
 
-def expand_biasaddgrad(expand_info):
+@VLD.add_format(DF.DEFAULT)
+@VLD.add_format(DF.NHWC)
+@VLD.add_format(DF.NCHW)
+class BiasAddGrad(Expander):
     """BiasAddGrad expander"""
-    # get op info.
-    input_desc_0 = expand_info['input_desc'][0]
-    graph_builder = builder.GraphBuilder()
 
-    # generate a graph.
-    with graph_builder.graph_scope('main') as graph_scope:
-        # create tensor input.
-        input_x = graph_builder.tensor(
-            input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
-        graph_scope.set_input(input_x)
+    def _expand(self, graph_builder):
+        input_x = self.inputs[0]
 
         reduce_axis = ()
         if input_x.data_format == 'NHWC':
             reduce_axis = (0, 1, 2)
         elif input_x.data_format == 'NCHW':
             reduce_axis = (0, 2, 3)
-        # Default format shape's length maybe equal 2 to 4, so different shape's length reduce axis are differnet
+        # DefaultFormat shape's length should be from 2 to 4
         else:
             if len(input_x.shape) == 2:
                 reduce_axis = (0,)
@@ -41,8 +40,4 @@ def expand_biasaddgrad(expand_info):
             else:
                 reduce_axis = (0, 2, 3)
         result = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
-
-        # set graph output.
-        graph_scope.set_output(result)
-
-    graph = graph_builder.get()[0]
-    return graph
+        return result
--- a/mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py
+++ b/mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,27 +13,15 @@
 # limitations under the License.
 # ===========================================================================
 """generate json desc for ClipByNormNoDivSum"""
-from mindspore._extends.graph_kernel.model import model_builder as builder
+from ._utils import Expander, ExpanderInfoValidator as VLD
 
 
-def expand_clipbynormnodivsum(expand_info):
+@VLD.check_all_formats_same
+class ClipByNormNoDivSum(Expander):
     """ClipByNormNoDivSum expander"""
-    # get op info.
-    input_desc_0 = expand_info['input_desc'][0]
-    input_desc_1 = expand_info['input_desc'][1]
-    input_desc_2 = expand_info['input_desc'][2]
-    input_desc_3 = expand_info['input_desc'][3]
-    graph_builder = builder.GraphBuilder()
 
-    # generate a graph.
-    with graph_builder.graph_scope('main') as graph_scope:
-        # create tensor input.
-        input_x0 = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
-        input_x1 = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
-        input_x2 = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
-        input_x3 = graph_builder.tensor(input_desc_3['shape'], input_desc_3['data_type'], input_desc_3['format'])
-        graph_scope.set_input(input_x0, input_x1, input_x2, input_x3)
+    def _expand(self, graph_builder):
+        input_x0, input_x1, input_x2, input_x3 = self.inputs
 
         # cal result
         greater_res = graph_builder.emit('Greater', [input_x0, input_x1], attrs={'fusion': 'SelectGT_000'})
@@ -44,8 +32,4 @@ def expand_clipbynormnodivsum(expand_info):
                                          attrs={'fusion': 'SelectGT_000_end'})
         result = graph_builder.emit('Maximum', [select_res1, input_x3])
 
-        # set graph output.
-        graph_scope.set_output(result)
-
-    graph = graph_builder.get()[0]
-    return graph
+        return result
--- a/mindspore/_extends/graph_kernel/expanders/dropout_grad.py
+++ b/mindspore/_extends/graph_kernel/expanders/dropout_grad.py
@@ -13,27 +13,18 @@
 # limitations under the License.
 # ===========================================================================
 """generate json desc for DropoutGrad"""
-from mindspore._extends.graph_kernel.model import model_builder as builder
+from ._utils import Expander, ExpanderInfoValidator as VLD
 
 
-def expand_dropoutgrad(expand_info):
+@VLD.check_all_formats_same
+@VLD.check_attrs('keep_prob')
+class DropoutGrad(Expander):
     """DropoutGrad expander"""
-    # get op info.
-    dy_desc = expand_info['input_desc'][0]
-    mask_desc = expand_info['input_desc'][1]
-    keep_prob = expand_info['attr']['keep_prob']
-    graph_builder = builder.GraphBuilder()
 
-    with graph_builder.graph_scope('main') as graph_scope:
-        # create tensor input.
-        input_dy = graph_builder.tensor(dy_desc['shape'], dy_desc['data_type'], dy_desc['format'])
-        input_mask = graph_builder.tensor(mask_desc['shape'], mask_desc['data_type'], mask_desc['format'])
-        graph_scope.set_input(input_dy, input_mask)
+    def _expand(self, graph_builder):
+        input_dy, input_mask = self.inputs
+        keep_prob = self.attrs['keep_prob']
+
         r_keep_prob = graph_builder.value(input_dy.dtype, 1.0 / keep_prob)
-        # create op.
         result = graph_builder.emit('Mul', [input_dy, r_keep_prob])
         result = graph_builder.emit('Mul', [result, input_mask])
 
-        # set graph output.
-        graph_scope.set_output(result)
-
-    graph = graph_builder.get()[0]
-    return graph
+        return result
--- a/mindspore/_extends/graph_kernel/expanders/fused_adam.py
+++ b/mindspore/_extends/graph_kernel/expanders/fused_adam.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,40 +13,16 @@
 # limitations under the License.
 # ===========================================================================
 """generate json desc for fused_adam"""
-from mindspore._extends.graph_kernel.model import model_builder as builder
+from ._utils import Expander, ExpanderInfoValidator as VLD
 
 
-def expand_fusedadam(expand_info):
-    """FusedAdma expander"""
-    # get op info.
-    input_desc_0 = expand_info['input_desc'][0]
-    input_desc_1 = expand_info['input_desc'][1]
-    input_desc_2 = expand_info['input_desc'][2]
-    input_desc_3 = expand_info['input_desc'][3]
-    input_desc_4 = expand_info['input_desc'][4]
-    input_desc_5 = expand_info['input_desc'][5]
-    input_desc_6 = expand_info['input_desc'][6]
-    input_desc_7 = expand_info['input_desc'][7]
-    input_desc_8 = expand_info['input_desc'][8]
-    input_desc_9 = expand_info['input_desc'][9]
-    graph_builder = builder.GraphBuilder()
+@VLD.check_all_formats_same
+class FusedAdam(Expander):
+    """FusedAdam expander"""
 
-    # generate a graph.
-    with graph_builder.graph_scope('main') as graph_scope:
-        # create tensor input.
-        beta_1 = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
-        one_sub_beta_1 = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
-        beta_2 = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
-        one_sub_beta_2 = graph_builder.tensor(input_desc_3['shape'], input_desc_3['data_type'], input_desc_3['format'])
-        eps = graph_builder.tensor(input_desc_4['shape'], input_desc_4['data_type'], input_desc_4['format'])
-        lr = graph_builder.tensor(input_desc_5['shape'], input_desc_5['data_type'], input_desc_5['format'])
-        param = graph_builder.tensor(input_desc_6['shape'], input_desc_6['data_type'], input_desc_6['format'])
-        m = graph_builder.tensor(input_desc_7['shape'], input_desc_7['data_type'], input_desc_7['format'])
-        v = graph_builder.tensor(input_desc_8['shape'], input_desc_8['data_type'], input_desc_8['format'])
-        gradient = graph_builder.tensor(input_desc_9['shape'], input_desc_9['data_type'], input_desc_9['format'])
-        graph_scope.set_input(beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, eps, lr, param, m, v, gradient)
+    def _expand(self, graph_builder):
+        beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, eps, lr, param, m, v, gradient = self.inputs
 
-        # compute result
         beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m])
         one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient])
         next_m = graph_builder.emit('Add', [beta_1_mul_m, one_sub_beta_1_mul_grad])
@@ -60,12 +36,9 @@ def expand_fusedadam(expand_info):
         update_with_lr = graph_builder.emit('Mul', [lr, update])
         next_para = graph_builder.emit('Sub', [param, update_with_lr])
 
-        param_result = graph_builder.emit('InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True})
+        param_result = graph_builder.emit(
+            'InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True})
         param_result = graph_builder.emit('InplaceAssign', [m, next_m, param_result], attrs={'fake_output': True})
         param_result = graph_builder.emit('InplaceAssign', [v, next_v, param_result], attrs={'fake_output': True})
 
-        # set graph output.
-        graph_scope.set_output(param_result)
-
-    graph = graph_builder.get()[0]
-    return graph
+        return param_result
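The elided middle of this hunk computes next_v and update in the same style; the whole expansion is the standard Adam step, roughly (a NumPy sketch under that assumption, with the three InplaceAssign emits modeled as plain reassignment):

    next_m = beta_1 * m + one_sub_beta_1 * gradient
    next_v = beta_2 * v + one_sub_beta_2 * gradient * gradient
    update = next_m / (np.sqrt(next_v) + eps)
    param = param - lr * update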
--- a/mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py
+++ b/mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,41 +13,15 @@
 # limitations under the License.
 # ===========================================================================
 """generate json desc for fused_adam_weight_decay"""
-from mindspore._extends.graph_kernel.model import model_builder as builder
+from ._utils import Expander, ExpanderInfoValidator as VLD
 
 
-def expand_fusedadamweightdecay(expand_info):
-    """FusedAdmaWeightDecay expander"""
-    # get op info.
-    input_desc_0 = expand_info['input_desc'][0]
-    input_desc_1 = expand_info['input_desc'][1]
-    input_desc_2 = expand_info['input_desc'][2]
-    input_desc_3 = expand_info['input_desc'][3]
-    input_desc_4 = expand_info['input_desc'][4]
-    input_desc_5 = expand_info['input_desc'][5]
-    input_desc_6 = expand_info['input_desc'][6]
-    input_desc_7 = expand_info['input_desc'][7]
-    input_desc_8 = expand_info['input_desc'][8]
-    input_desc_9 = expand_info['input_desc'][9]
-    input_desc_10 = expand_info['input_desc'][10]
-    graph_builder = builder.GraphBuilder()
+@VLD.check_all_formats_same
+class FusedAdamWeightDecay(Expander):
+    """FusedAdamWeightDecay expander"""
 
-    # generate a graph.
-    with graph_builder.graph_scope('main') as graph_scope:
-        # create tensor input.
-        beta_1 = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
-        one_sub_beta_1 = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
-        beta_2 = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
-        one_sub_beta_2 = graph_builder.tensor(input_desc_3['shape'], input_desc_3['data_type'], input_desc_3['format'])
-        eps = graph_builder.tensor(input_desc_4['shape'], input_desc_4['data_type'], input_desc_4['format'])
-        lr = graph_builder.tensor(input_desc_5['shape'], input_desc_5['data_type'], input_desc_5['format'])
-        param = graph_builder.tensor(input_desc_6['shape'], input_desc_6['data_type'], input_desc_6['format'])
-        m = graph_builder.tensor(input_desc_7['shape'], input_desc_7['data_type'], input_desc_7['format'])
-        v = graph_builder.tensor(input_desc_8['shape'], input_desc_8['data_type'], input_desc_8['format'])
-        gradient = graph_builder.tensor(input_desc_9['shape'], input_desc_9['data_type'], input_desc_9['format'])
-        weight_decay = graph_builder.tensor(input_desc_10['shape'], input_desc_10['data_type'], input_desc_10['format'])
-        graph_scope.set_input(beta_1, one_sub_beta_1, beta_2, one_sub_beta_2,
-                              eps, lr, param, m, v, gradient, weight_decay)
+    def _expand(self, graph_builder):
+        beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, eps, lr, param, m, v, gradient, weight_decay = self.inputs
 
         # compute result
         beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m])
@@ -65,12 +39,9 @@ def expand_fusedadamweightdecay(expand_info):
         update_with_lr = graph_builder.emit('Mul', [lr, update])
         next_para = graph_builder.emit('Sub', [param, update_with_lr])
 
-        para_result = graph_builder.emit('InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True})
+        para_result = graph_builder.emit(
+            'InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True})
         para_result = graph_builder.emit('InplaceAssign', [m, next_m, para_result], attrs={'fake_output': True})
         para_result = graph_builder.emit('InplaceAssign', [v, next_v, para_result], attrs={'fake_output': True})
 
-        # set graph output.
-        graph_scope.set_output(para_result)
-
-    graph = graph_builder.get()[0]
-    return graph
+        return para_result
--- a/mindspore/_extends/graph_kernel/expanders/gelu.py
+++ b/mindspore/_extends/graph_kernel/expanders/gelu.py
@@ -13,49 +13,36 @@
 # limitations under the License.
 # ===========================================================================
 """generate json desc for gelu"""
-from mindspore._extends.graph_kernel.model import model_builder as builder
-
-CSVALUE = 0.044715
-CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564  # np.sqrt(2/np.pi)
-ONE = 1.0
-HALF = 0.5
+from ._utils import Expander
 
 
-def expand_gelu(expand_info):
+class GeLU(Expander):
     """GeLU expander"""
-    # cal formula are:
-    # gelu(x) is 0.5 * x * (1.0 + tanh(y))
-    # y is sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)
+    CSVALUE = 0.044715
+    CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564  # np.sqrt(2/np.pi)
 
-    # get op info.
-    input_desc = expand_info['input_desc'][0]
-    graph_builder = builder.GraphBuilder()
+    def _expand(self, graph_builder):
+        # cal formula are:
+        # gelu(x) is 0.5 * x * (1.0 + tanh(y))
+        # y is sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)
 
-    # generate a graph.
-    with graph_builder.graph_scope('main') as graph_scope:
-        # create tensor input.
-        input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
-        graph_scope.set_input(input_x)
+        input_x = self.inputs[0]
 
         # cal y
         mul_0 = graph_builder.emit('Mul', [input_x, input_x])
         pow_0 = graph_builder.emit('Mul', [mul_0, input_x])
-        const_csvalue = graph_builder.value(pow_0.dtype, CSVALUE)
+        const_csvalue = graph_builder.value(pow_0.dtype, self.CSVALUE)
         mul_1 = graph_builder.emit('Mul', [pow_0, const_csvalue])
         tanh_res = graph_builder.emit('Add', [input_x, mul_1])
-        const_csvalue_sqrt_two_div_pi = graph_builder.value(tanh_res.dtype, CSVALUE_SQRT_TWO_DIV_PI)
+        const_csvalue_sqrt_two_div_pi = graph_builder.value(tanh_res.dtype, self.CSVALUE_SQRT_TWO_DIV_PI)
         y = graph_builder.emit('Mul', [tanh_res, const_csvalue_sqrt_two_div_pi])
 
         # cal gelu(x)
         tanh_y = graph_builder.emit('Tanh', [y])
-        const_one = graph_builder.value(tanh_y.dtype, ONE)
-        const_half = graph_builder.value(tanh_y.dtype, HALF)
+        const_one = graph_builder.value(tanh_y.dtype, 1)
+        const_half = graph_builder.value(tanh_y.dtype, 0.5)
         tanh_y_add_one = graph_builder.emit('Add', [tanh_y, const_one])
         mul_x = graph_builder.emit('Mul', [input_x, tanh_y_add_one])
         result = graph_builder.emit('Mul', [const_half, mul_x])
 
-        # set graph output.
-        graph_scope.set_output(result)
-
-    graph = graph_builder.get()[0]
-    return graph
+        return result
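For reference, the expanded graph matches this NumPy restatement of the tanh approximation spelled out in the comments above:

    import numpy as np

    def gelu_ref(x):
        # gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
        y = 0.7978845608028564 * (x + 0.044715 * x ** 3)
        return 0.5 * x * (1.0 + np.tanh(y))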
--- a/mindspore/_extends/graph_kernel/expanders/gelu_grad.py
+++ b/mindspore/_extends/graph_kernel/expanders/gelu_grad.py
@@ -13,43 +13,31 @@
 # limitations under the License.
 # ===========================================================================
 """generate json desc for gelugrad"""
-from mindspore._extends.graph_kernel.model import model_builder as builder
-
-CSVALUE = 0.044715
-CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564  # np.sqrt(2/np.pi)
-CSVALUE_TRI = 0.134141  # CSVALUE * 3
-ONE = 1.0
-HALF = 0.5
+from ._utils import Expander, ExpanderInfoValidator as VLD
 
 
-def expand_gelugrad(expand_info):
+@VLD.check_all_formats_same
+class GeLUGrad(Expander):
     """GeLUGrad expander"""
-    # cal formula are:
-    # gelu_grad(dy, x) is dy * y'
-    # y' is 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(para)) * mul_right
-    # tanh_para is sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)
-    # mul_right is sqrt(2.0 / pi) * (1 + 3 * 0.044715 * x * x)
+    CSVALUE = 0.044715
+    CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564  # np.sqrt(2/np.pi)
+    CSVALUE_TRI = 0.134141  # CSVALUE * 3
 
-    # get op info.
-    input_desc_0 = expand_info['input_desc'][0]
-    input_desc_1 = expand_info['input_desc'][1]
-    input_desc_2 = expand_info['input_desc'][2]
-    graph_builder = builder.GraphBuilder()
+    def _expand(self, graph_builder):
+        # cal formula are:
+        # gelu_grad(dy, x) is dy * y'
+        # y' is 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(para)) * mul_right
+        # tanh_para is sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)
+        # mul_right is sqrt(2.0 / pi) * (1 + 3 * 0.044715 * x * x)
 
-    # generate a graph.
-    with graph_builder.graph_scope('main') as graph_scope:
-        # create tensor input.
-        input_dy = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
-        input_x = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
-        input_y = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
-        graph_scope.set_input(input_dy, input_x, input_y)
+        input_dy, input_x, _ = self.inputs
 
         # create some const var
-        const_csvalue = graph_builder.value(input_dy.dtype, CSVALUE)
-        const_csvalue_sqrt_two_div_pi = graph_builder.value(input_dy.dtype, CSVALUE_SQRT_TWO_DIV_PI)
-        const_csvalue_tri = graph_builder.value(input_dy.dtype, CSVALUE_TRI)
-        const_one = graph_builder.value(input_dy.dtype, ONE)
-        const_half = graph_builder.value(input_dy.dtype, HALF)
+        const_csvalue = graph_builder.value(input_dy.dtype, self.CSVALUE)
+        const_csvalue_sqrt_two_div_pi = graph_builder.value(input_dy.dtype, self.CSVALUE_SQRT_TWO_DIV_PI)
+        const_csvalue_tri = graph_builder.value(input_dy.dtype, self.CSVALUE_TRI)
+        const_one = graph_builder.value(input_dy.dtype, 1)
+        const_half = graph_builder.value(input_dy.dtype, 0.5)
 
         # cal mul_right
         mul_double = graph_builder.emit('Mul', [input_x, input_x])
@@ -79,8 +67,4 @@ def expand_gelugrad(expand_info):
         result_tmp = graph_builder.emit('Add', [half_mul_tanh_res_add_one, mul_final])
         result = graph_builder.emit('Mul', [input_dy, result_tmp])
 
-        # set graph output.
-        graph_scope.set_output(result)
-
-    graph = graph_builder.get()[0]
-    return graph
+        return result
--- a/mindspore/_extends/graph_kernel/expanders/gkdropout.py
+++ b/mindspore/_extends/graph_kernel/expanders/gkdropout.py
@@ -12,35 +12,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ===========================================================================
-"""generate json desc for GkDropOut"""
-from mindspore._extends.graph_kernel.model import model_builder as builder
+"""generate json desc for GkDropout"""
+from ._utils import Expander, ExpanderInfoValidator as VLD
 
 
-def expand_gkdropout(expand_info):
-    """GkDropOut expander"""
-    # get op info.
-    input_desc = expand_info['input_desc'][0]
-    maks_desc = expand_info['input_desc'][1]
-    keep_prob = expand_info['attr']['keep_prob']
+@VLD.check_all_formats_same
+@VLD.check_attrs('keep_prob')
+class GkDropout(Expander):
+    """GkDropout expander"""
 
-    graph_builder = builder.GraphBuilder()
-    with graph_builder.graph_scope('main') as graph_scope:
-        # create tensor input.
-        input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
-        input_mask = graph_builder.tensor(maks_desc['shape'], maks_desc['data_type'], maks_desc['format'])
-        graph_scope.set_input(input_x, input_mask)
+    def _expand(self, graph_builder):
+        input_x, input_mask = self.inputs
+        keep_prob = self.attrs['keep_prob']
 
-        keep_prob_v = graph_builder.value(input_x.dtype, keep_prob)
         r_keep_prob = graph_builder.value(input_x.dtype, 1.0 / keep_prob)
+        keep_prob = graph_builder.value(input_x.dtype, keep_prob)
 
         if input_mask.dtype != input_x.dtype:
             input_mask = graph_builder.emit('Cast', [input_mask], attrs={'dst_type': input_x.dtype})
-        mask = graph_builder.emit('LessEqual', [input_mask, keep_prob_v])  # output is bool type
+        mask = graph_builder.emit('LessEqual', [input_mask, keep_prob])  # output is bool type
         mask = graph_builder.emit('Cast', [mask], attrs={'dst_type': input_x.dtype})
 
         # compute result
         result = graph_builder.emit('Mul', [r_keep_prob, input_x])
         result = graph_builder.emit('Mul', [result, mask])
 
-        # set graph output.
-        graph_scope.set_output(result, mask)
-
-    graph = graph_builder.get()[0]
-    return graph
+        return result, mask
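The mask input carries uniform random values, so the expansion is equivalent to this NumPy sketch: elements whose random value is at most keep_prob survive, and survivors are rescaled by 1/keep_prob so the expected value is unchanged:

    import numpy as np

    def gkdropout_ref(x, uniform_mask, keep_prob):
        mask = (uniform_mask <= keep_prob).astype(x.dtype)   # LessEqual + Cast
        return x * mask / keep_prob, mask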
--- a/mindspore/_extends/graph_kernel/expanders/layernorm.py
+++ b/mindspore/_extends/graph_kernel/expanders/layernorm.py
@@ -13,38 +13,31 @@
 # limitations under the License.
 # ===========================================================================
 """generate json desc for LayerNorm"""
-from mindspore._extends.graph_kernel.model import model_builder as builder
+from mindspore._extends.graph_kernel.model.model import DataFormat as DF
+from ._utils import Expander, ExpanderInfoValidator as VLD
 
 
-def expand_layernorm(expand_info):
+@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
+@VLD.check_attrs('begin_norm_axis', 'begin_params_axis', 'epsilon')
+class LayerNorm(Expander):
     """LayerNorm expander"""
-    # get op info.
-    input_desc_0 = expand_info['input_desc'][0]
-    input_desc_1 = expand_info['input_desc'][1]
-    input_desc_2 = expand_info['input_desc'][2]
-    attrs = expand_info['attr']
-    begin_norm_axis = attrs['begin_norm_axis']
-    epsilon = attrs['epsilon']
 
-    graph_builder = builder.GraphBuilder()
-    with graph_builder.graph_scope('main') as graph_scope:
-        # create tensor input.
-        input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
-        input_gamma = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
-        input_beta = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
+    def _expand(self, graph_builder):
+        input_x, input_gamma, input_beta = self.inputs
+        begin_norm_axis = self.attrs['begin_norm_axis']
+        epsilon = self.attrs['epsilon']
 
         # Calculate the scaling ratio of the average
-        shape_x = input_desc_0['shape']
         if begin_norm_axis < 0:
-            begin_norm_axis += len(shape_x)
+            begin_norm_axis += len(input_x.shape)
         reduce_axis = ()
-        for i, _ in enumerate(shape_x):
+        for i, _ in enumerate(input_x.shape):
             if i > begin_norm_axis or i == begin_norm_axis:
                 reduce_axis = reduce_axis + (i,)
 
         reduce_elts = 1.0
         for i in reduce_axis:
-            reduce_elts *= shape_x[i]
+            reduce_elts *= input_x.shape[i]
         mean_cof = 1.0 / reduce_elts
         mean_cof_v = graph_builder.value(input_x.dtype, mean_cof)
 
@@ -70,8 +63,4 @@ def expand_layernorm(expand_info):
         scale_mul = graph_builder.emit('Mul', [input_gamma, normalize_mul])
         res = graph_builder.emit('Add', [scale_mul, input_beta])
 
-        # set graph output.
-        graph_scope.set_output(res, mean, variance)
-
-    graph = graph_builder.get()[0]
-    return graph
+        return res, mean, variance
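The elided middle of the hunk computes mean, variance, and the normalization with the usual LayerNorm algebra; as a NumPy reference (a sketch under that assumption):

    import numpy as np

    def layernorm_ref(x, gamma, beta, begin_norm_axis, epsilon):
        axes = tuple(range(begin_norm_axis, x.ndim))   # the reduce_axis computed above
        mean = x.mean(axis=axes, keepdims=True)
        var = ((x - mean) ** 2).mean(axis=axes, keepdims=True)
        res = gamma * (x - mean) / np.sqrt(var + epsilon) + beta
        return res, mean, var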
--- a/mindspore/_extends/graph_kernel/expanders/layernorm_grad.py
+++ b/mindspore/_extends/graph_kernel/expanders/layernorm_grad.py
@@ -13,42 +13,30 @@
 # limitations under the License.
 # ===========================================================================
 """generate json desc for LayerNormGrad"""
-from mindspore._extends.graph_kernel.model import model_builder as builder
+from mindspore._extends.graph_kernel.model.model import DataFormat as DF
+from ._utils import Expander, ExpanderInfoValidator as VLD
 
 
-def expand_layernormgrad(expand_info):
+@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
+@VLD.check_attrs('begin_norm_axis', 'begin_params_axis')
+class LayerNormGrad(Expander):
     """LayerNormGrad expander"""
-    # get op info.
-    x_desc = expand_info['input_desc'][0]
-    dy_desc = expand_info['input_desc'][1]
-    var_desc = expand_info['input_desc'][2]
-    mean_desc = expand_info['input_desc'][3]
-    gamma_desc = expand_info['input_desc'][4]
-    attrs = expand_info['attr']
-    begin_norm_axis = attrs['begin_norm_axis']
-    begin_params_axis = attrs['begin_params_axis']
-    epsilon = attrs['epsilon'] if 'epsilon' in attrs else 1e-11
 
-    shape_x = x_desc['shape']
-    if begin_norm_axis < 0:
-        begin_norm_axis += len(shape_x)
-    if begin_params_axis < 0:
-        begin_params_axis += len(shape_x)
-    norm_axis = tuple(range(begin_norm_axis, len(shape_x)))
-    param_axis = tuple(range(0, begin_params_axis))
-    reduce_size = 1.0
-    for i in norm_axis:
-        reduce_size *= shape_x[i]
+    def _expand(self, graph_builder):
+        x, dy, variance, mean, gamma = self.inputs
+        begin_norm_axis = self.attrs['begin_norm_axis']
+        begin_params_axis = self.attrs['begin_params_axis']
+        epsilon = self.attrs['epsilon'] if 'epsilon' in self.attrs else 1e-11
 
-    graph_builder = builder.GraphBuilder()
-    with graph_builder.graph_scope('main') as graph_scope:
-        # create input tensors.
-        x = graph_builder.tensor(x_desc['shape'], x_desc['data_type'], x_desc['format'])
-        dy = graph_builder.tensor(dy_desc['shape'], dy_desc['data_type'], dy_desc['format'])
-        variance = graph_builder.tensor(var_desc['shape'], var_desc['data_type'], var_desc['format'])
-        mean = graph_builder.tensor(mean_desc['shape'], mean_desc['data_type'], mean_desc['format'])
-        gamma = graph_builder.tensor(gamma_desc['shape'], gamma_desc['data_type'], gamma_desc['format'])
-        graph_scope.set_input(x, dy, variance, mean, gamma)
+        if begin_norm_axis < 0:
+            begin_norm_axis += len(x.shape)
+        if begin_params_axis < 0:
+            begin_params_axis += len(x.shape)
+        norm_axis = tuple(range(begin_norm_axis, len(x.shape)))
+        param_axis = tuple(range(0, begin_params_axis))
+        reduce_size = 1.0
+        for i in norm_axis:
+            reduce_size *= x.shape[i]
 
         # set some constant val.
         eps = graph_builder.value(x.dtype, epsilon)
@@ -99,8 +87,4 @@ def expand_layernormgrad(expand_info):
         dx_tmp = graph_builder.emit('Add', [dx_1, dx_2])
         dx = graph_builder.emit('Add', [dx_tmp, dx_3])
 
-        # set graph output.
-        graph_scope.set_output(dx, dg, db)
-
-    graph = graph_builder.get()[0]
-    return graph
+        return dx, dg, db
--- a/mindspore/_extends/graph_kernel/expanders/logsoftmax.py
+++ b/mindspore/_extends/graph_kernel/expanders/logsoftmax.py
@@ -13,24 +13,21 @@
 # limitations under the License.
 # ===========================================================================
 """generate json desc for LogSoftmax"""
-from mindspore._extends.graph_kernel.model import model_builder as builder
+from mindspore._extends.graph_kernel.model.model import DataFormat as DF
+from ._utils import Expander, ExpanderInfoValidator as VLD
 
 
-def expand_logsoftmax(expand_info):
+@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
+@VLD.check_attrs('axis')
+class LogSoftmax(Expander):
     """LogSoftmax expander"""
-    # get op info.
-    input_desc = expand_info['input_desc'][0]
-    axis = expand_info['attr']['axis']
-    graph_builder = builder.GraphBuilder()
-    if isinstance(axis, int):
-        axis = (axis,)
 
-    # generate a graph.
-    with graph_builder.graph_scope('main') as graph_scope:
-        # create tensor input.
-        input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
-        graph_scope.set_input(input_x)
-        # cal logsoftmax.
+    def _expand(self, graph_builder):
+        input_x = self.inputs[0]
+        axis = self.attrs['axis']
+        if isinstance(axis, int):
+            axis = (axis,)
+
         max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True})
         data_sub = graph_builder.emit('Sub', [input_x, max_x])
         data_exp = graph_builder.emit('Exp', [data_sub])
@@ -38,8 +35,4 @@ def expand_logsoftmax(expand_info):
         log_expsum = graph_builder.emit('Log', [data_expsum])
         result = graph_builder.emit('Sub', [data_sub, log_expsum])
 
-        # set graph output.
-        graph_scope.set_output(result)
-
-    graph = graph_builder.get()[0]
-    return graph
+        return result
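The ReduceMax/Sub before Exp is the usual numerical-stabilization trick; the expansion computes, in NumPy terms:

    import numpy as np

    def logsoftmax_ref(x, axis):
        # log_softmax(x) = (x - max) - log(sum(exp(x - max)))
        shifted = x - x.max(axis=axis, keepdims=True)
        return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))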
--- a/mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py
+++ b/mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py
@@ -13,34 +13,24 @@
 # limitations under the License.
 # ===========================================================================
 """generate json desc for LogSoftmaxGrad"""
-from mindspore._extends.graph_kernel.model import model_builder as builder
+from mindspore._extends.graph_kernel.model.model import DataFormat as DF
+from ._utils import Expander, ExpanderInfoValidator as VLD
 
 
-def expand_logsoftmaxgrad(expand_info):
+@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
+@VLD.check_attrs('axis')
+class LogSoftmaxGrad(Expander):
     """LogSoftmaxGrad expander"""
-    # get op info.
-    input_desc_0 = expand_info['input_desc'][0]
-    input_desc_1 = expand_info['input_desc'][1]
-    axis = expand_info['attr']['axis']
-    graph_builder = builder.GraphBuilder()
-    if isinstance(axis, int):
-        axis = (axis,)
 
-    # generate a graph.
-    with graph_builder.graph_scope('main') as graph_scope:
-        # create tensor input.
-        input_logits = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
-        input_dy = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
-        graph_scope.set_input(input_logits, input_dy)
+    def _expand(self, graph_builder):
+        input_logits, input_dy = self.inputs
+        axis = self.attrs['axis']
+        if isinstance(axis, int):
+            axis = (axis,)
 
-        # cal logsoftmaxgrad.
         softmax = graph_builder.emit('Exp', [input_logits])
         dy_sum = graph_builder.emit('ReduceSum', [input_dy], attrs={'reduce_axis': axis, 'keep_dims': True})
         mul_result = graph_builder.emit('Mul', [softmax, dy_sum])
         result = graph_builder.emit('Sub', [input_dy, mul_result])
 
-        # set graph output.
-        graph_scope.set_output(result)
-
-    graph = graph_builder.get()[0]
-    return graph
+        return result
| @@ -13,40 +13,26 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # =========================================================================== | # =========================================================================== | ||||
| """generate json desc for maximum_grad""" | """generate json desc for maximum_grad""" | ||||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||||
| from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException | |||||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||||
| def expand_maximumgrad(expand_info): | |||||
| @VLD.check_all_formats_same | |||||
| class MaximumGrad(Expander): | |||||
| """MaximumGrad expander""" | """MaximumGrad expander""" | ||||
| # get op info. | |||||
| input_desc_0 = expand_info['input_desc'][0] | |||||
| input_desc_1 = expand_info['input_desc'][1] | |||||
| input_desc_2 = expand_info['input_desc'][2] | |||||
| attrs = expand_info['attr'] | |||||
| grad_x = attrs['grad_x'] if 'grad_x' in attrs else True | |||||
| grad_y = attrs['grad_y'] if 'grad_y' in attrs else True | |||||
| graph_builder = builder.GraphBuilder() | |||||
| with graph_builder.graph_scope('main') as graph_scope: | |||||
| # create tensor input. | |||||
| input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format']) | |||||
| input_y = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format']) | |||||
| input_dout = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format']) | |||||
| graph_scope.set_input(input_x, input_y, input_dout) | |||||
| x_dtype = input_x.dtype | |||||
| # cal result | |||||
| def _check(self): | |||||
| if not self.attrs.get('grad_x', True) and not self.attrs.get('grad_y', True): | |||||
| raise GKException("both grad_x and grad_y are False.") | |||||
| return super()._check() | |||||
| def _expand(self, graph_builder): | |||||
| input_x, input_y, input_dout = self.inputs | |||||
| ge_result = graph_builder.emit('GreaterEqual', [input_x, input_y]) | ge_result = graph_builder.emit('GreaterEqual', [input_x, input_y]) | ||||
| ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': x_dtype}) | |||||
| ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': input_x.dtype}) | |||||
| dx = graph_builder.emit('Mul', [ge_result, input_dout]) | dx = graph_builder.emit('Mul', [ge_result, input_dout]) | ||||
| dy = graph_builder.emit('Sub', [input_dout, dx]) | dy = graph_builder.emit('Sub', [input_dout, dx]) | ||||
| # set graph output according to grad_x and grad_y | |||||
| if grad_x and grad_y: | |||||
| graph_scope.set_output(dx, dy) | |||||
| if grad_x and not grad_y: | |||||
| graph_scope.set_output(dx) | |||||
| if grad_y and not grad_x: | |||||
| graph_scope.set_output(dy) | |||||
| graph = graph_builder.get()[0] | |||||
| return graph | |||||
| # output two results, regardless of grad_x and grad_y | |||||
| return dx, dy | |||||
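The `GreaterEqual`/`Cast`/`Mul` sequence routes the incoming gradient to x wherever x >= y, and the final `Sub` recovers dy without a second comparison; note that ties send the whole gradient to x, and dx + dy always equals dout by construction. A small NumPy sketch (test values are arbitrary):

```python
import numpy as np

def maximum_grad_ref(x, y, dout):
    # GreaterEqual -> Cast -> Mul, then dy = dout - dx
    ge = (x >= y).astype(x.dtype)
    dx = ge * dout
    dy = dout - dx   # cheaper than a second comparison; ties route all grad to x
    return dx, dy

x = np.array([1.0, 3.0, 2.0], np.float32)
y = np.array([2.0, 2.0, 2.0], np.float32)
dout = np.ones(3, np.float32)
dx, dy = maximum_grad_ref(x, y, dout)
assert np.allclose(dx, [0.0, 1.0, 1.0]) and np.allclose(dy, [1.0, 0.0, 0.0])
assert np.allclose(dx + dy, dout)  # gradient mass is conserved
```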
| @@ -13,41 +13,26 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # =========================================================================== | # =========================================================================== | ||||
| """generate json desc for minimum_grad""" | """generate json desc for minimum_grad""" | ||||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||||
| from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException | |||||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||||
| def expand_minimumgrad(expand_info): | |||||
| @VLD.check_all_formats_same | |||||
| class MinimumGrad(Expander): | |||||
| """MinimumGrad expander""" | """MinimumGrad expander""" | ||||
| # get op info. | |||||
| input_desc_0 = expand_info['input_desc'][0] | |||||
| input_desc_1 = expand_info['input_desc'][1] | |||||
| input_desc_2 = expand_info['input_desc'][2] | |||||
| attrs = expand_info['attr'] | |||||
| grad_x = attrs['grad_x'] if 'grad_x' in attrs else True | |||||
| grad_y = attrs['grad_y'] if 'grad_y' in attrs else True | |||||
| graph_builder = builder.GraphBuilder() | |||||
| with graph_builder.graph_scope('main') as graph_scope: | |||||
| # create tensor input. | |||||
| input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format']) | |||||
| input_y = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format']) | |||||
| input_dout = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format']) | |||||
| graph_scope.set_input(input_x, input_y, input_dout) | |||||
| x_dtype = input_x.dtype | |||||
| def _check(self): | |||||
| if not self.attrs.get('grad_x', True) and not self.attrs.get('grad_y', True): | |||||
| raise GKException("both grad_x and grad_y are False.") | |||||
| return super()._check() | |||||
| def _expand(self, graph_builder): | |||||
| input_x, input_y, input_dout = self.inputs | |||||
| # calculate the result | |||||
| le_result = graph_builder.emit('LessEqual', [input_x, input_y]) | le_result = graph_builder.emit('LessEqual', [input_x, input_y]) | ||||
| le_result = graph_builder.emit('Cast', [le_result], attrs={'dst_type': x_dtype}) | |||||
| le_result = graph_builder.emit('Cast', [le_result], attrs={'dst_type': input_x.dtype}) | |||||
| dx = graph_builder.emit('Mul', [le_result, input_dout]) | dx = graph_builder.emit('Mul', [le_result, input_dout]) | ||||
| dy = graph_builder.emit('Sub', [input_dout, dx]) | dy = graph_builder.emit('Sub', [input_dout, dx]) | ||||
| # set graph output according to grad_x and grad_y | |||||
| if grad_x and grad_y: | |||||
| graph_scope.set_output(dx, dy) | |||||
| if grad_x and not grad_y: | |||||
| graph_scope.set_output(dx) | |||||
| if grad_y and not grad_x: | |||||
| graph_scope.set_output(dy) | |||||
| graph = graph_builder.get()[0] | |||||
| return graph | |||||
| # output two results, regardless of grad_x and grad_y | |||||
| return dx, dy | |||||
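MinimumGrad is the mirror image: only the comparison flips to `LessEqual`. A quick NumPy check of the split (values arbitrary):

```python
import numpy as np

x = np.array([1.0, 3.0, 2.0], np.float32)
y = np.array([2.0, 2.0, 2.0], np.float32)
dout = np.array([0.1, 0.2, 0.3], np.float32)
dx = (x <= y).astype(np.float32) * dout  # LessEqual -> Cast -> Mul
dy = dout - dx                           # Sub
assert np.allclose(dx, [0.1, 0.0, 0.3]) and np.allclose(dy, [0.0, 0.2, 0.0])
```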
| @@ -13,45 +13,30 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # =========================================================================== | # =========================================================================== | ||||
| """generate json desc for reduce_mean""" | """generate json desc for reduce_mean""" | ||||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||||
| from mindspore._extends.graph_kernel.model.model import DataFormat as DF | |||||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||||
| def expand_reducemean(expand_info): | |||||
| @VLD.add_format(DF.DEFAULT) | |||||
| @VLD.check_attrs('axis', 'keep_dims') | |||||
| class ReduceMean(Expander): | |||||
| """ReduceMean expander""" | """ReduceMean expander""" | ||||
| # get op info. | |||||
| input_desc = expand_info['input_desc'][0] | |||||
| attrs = expand_info['attr'] | |||||
| axis = attrs['axis'] | |||||
| keep_dims = attrs['keep_dims'] | |||||
| graph_builder = builder.GraphBuilder() | |||||
| with graph_builder.graph_scope('main') as graph_scope: | |||||
| # create tensor input. | |||||
| input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format']) | |||||
| x_shape = input_x.shape | |||||
| graph_scope.set_input(input_x) | |||||
| def _expand(self, graph_builder): | |||||
| x = self.inputs[0] | |||||
| axis = self.attrs['axis'] | |||||
| keep_dims = self.attrs['keep_dims'] | |||||
| # cal reduce_mean, when axis = None, reduce axis are all | |||||
| all_shape = 1.0 | |||||
| real_axis = [] | |||||
| # cal reduce_mean, when axis is None, reduce all axes. | |||||
| if not axis: | if not axis: | ||||
| for i, shape in enumerate(x_shape): | |||||
| real_axis.append(i) | |||||
| all_shape *= shape | |||||
| else: | |||||
| for idx in axis: | |||||
| all_shape *= x_shape[idx] | |||||
| axis = list(range(len(x.shape))) | |||||
| reduce_size = 1.0 | |||||
| for idx in axis: | |||||
| reduce_size *= x.shape[idx] | |||||
| all_shape_value = graph_builder.value(input_x.dtype, all_shape) | |||||
| reduce_size_value = graph_builder.value(x.dtype, reduce_size) | |||||
| if not axis: | |||||
| sum_x = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': real_axis, 'keep_dims': keep_dims}) | |||||
| else: | |||||
| sum_x = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': axis, 'keep_dims': keep_dims}) | |||||
| result = graph_builder.emit('RealDiv', [sum_x, all_shape_value]) | |||||
| # set graph output. | |||||
| graph_scope.set_output(result) | |||||
| sum_x = graph_builder.emit('ReduceSum', [x], attrs={'reduce_axis': axis, 'keep_dims': keep_dims}) | |||||
| result = graph_builder.emit('RealDiv', [sum_x, reduce_size_value]) | |||||
| graph = graph_builder.get()[0] | |||||
| return graph | |||||
| return result | |||||
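The refactor folds the old duplicated axis handling into a single path: an empty or None axis is normalized to all axes up front, and the mean is a `ReduceSum` divided by the product of the reduced extents. A NumPy mirror of that logic (shapes are illustrative):

```python
import numpy as np

def reduce_mean_ref(x, axis=None, keep_dims=False):
    # normalize axis, ReduceSum, then divide by the product of reduced dims
    if not axis:
        axis = list(range(x.ndim))
    reduce_size = 1.0
    for idx in axis:
        reduce_size *= x.shape[idx]
    sum_x = np.sum(x, axis=tuple(axis), keepdims=keep_dims)
    return sum_x / reduce_size

x = np.random.rand(2, 3, 4).astype(np.float32)
assert np.allclose(reduce_mean_ref(x, [1, 2]), x.mean(axis=(1, 2)), atol=1e-6)
assert np.allclose(reduce_mean_ref(x), x.mean(), atol=1e-6)
```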
| @@ -13,26 +13,23 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # =========================================================================== | # =========================================================================== | ||||
| """generate json desc for softmax""" | """generate json desc for softmax""" | ||||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||||
| from mindspore._extends.graph_kernel.model.model import DataFormat as DF | |||||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||||
| def expand_softmax(expand_info): | |||||
| @VLD.add_format(DF.DEFAULT) | |||||
| @VLD.check_attrs('axis') | |||||
| class Softmax(Expander): | |||||
| """Softmax expander""" | """Softmax expander""" | ||||
| input_desc = expand_info['input_desc'][0] | |||||
| axis = expand_info['attr']['axis'] | |||||
| graph_builder = builder.GraphBuilder() | |||||
| with graph_builder.graph_scope('main') as graph_scope: | |||||
| # create tensor input. | |||||
| input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format']) | |||||
| # cal softmax. | |||||
| def _expand(self, graph_builder): | |||||
| input_x = self.inputs[0] | |||||
| axis = self.attrs['axis'] | |||||
| max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True}) | max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True}) | ||||
| data_sub = graph_builder.emit('Sub', [input_x, max_x]) | data_sub = graph_builder.emit('Sub', [input_x, max_x]) | ||||
| data_exp = graph_builder.emit('Exp', [data_sub]) | data_exp = graph_builder.emit('Exp', [data_sub]) | ||||
| data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True}) | data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True}) | ||||
| result = graph_builder.emit('RealDiv', [data_exp, data_expsum]) | result = graph_builder.emit('RealDiv', [data_exp, data_expsum]) | ||||
| # set graph output. | |||||
| graph_scope.set_output(result) | |||||
| graph = graph_builder.get()[0] | |||||
| return graph | |||||
| return result | |||||
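The `ReduceMax`/`Sub` pair is the standard overflow guard: shifting by the row max keeps every exponent at or below zero without changing the result. A NumPy sketch showing why the shift matters (the large logits are a deliberately extreme example):

```python
import numpy as np

def softmax_ref(x, axis=-1):
    # ReduceMax -> Sub keeps exponent arguments <= 0, avoiding overflow
    # for large logits; the result is mathematically unchanged.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / np.sum(e, axis=axis, keepdims=True)

x = np.array([1000.0, 1001.0, 1002.0], np.float32)  # naive exp(x) overflows in float32
assert np.allclose(softmax_ref(x).sum(), 1.0, atol=1e-6)
```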
| @@ -13,33 +13,17 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # =========================================================================== | # =========================================================================== | ||||
| """generate json desc for sqrtgrad""" | """generate json desc for sqrtgrad""" | ||||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||||
| def expand_sqrtgrad(expand_info): | |||||
| @VLD.check_all_formats_same | |||||
| class SqrtGrad(Expander): | |||||
| """SqrtGrad expander""" | """SqrtGrad expander""" | ||||
| # cal formula are: | |||||
| # sqrt_grad(x, dout) is dout / (2 * x) | |||||
| # get op info. | |||||
| input_desc_0 = expand_info['input_desc'][0] | |||||
| input_desc_1 = expand_info['input_desc'][1] | |||||
| graph_builder = builder.GraphBuilder() | |||||
| # generate a graph. | |||||
| with graph_builder.graph_scope('main') as graph_scope: | |||||
| # create tensor input. | |||||
| input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format']) | |||||
| input_dout = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format']) | |||||
| graph_scope.set_input(input_x, input_dout) | |||||
| # cal result | |||||
| const_two = graph_builder.value(input_x.dtype, 2) | |||||
| dividend = graph_builder.emit('Mul', [input_x, const_two]) | |||||
| result = graph_builder.emit('RealDiv', [input_dout, dividend]) | |||||
| # set graph output. | |||||
| graph_scope.set_output(result) | |||||
| graph = graph_builder.get()[0] | |||||
| return graph | |||||
| def _expand(self, graph_builder): | |||||
| # sqrt_grad(x, dout) = dout / (2 * x) | |||||
| x, dout = self.inputs | |||||
| const_two = graph_builder.value(x.dtype, 2) | |||||
| dividend = graph_builder.emit('Mul', [x, const_two]) | |||||
| result = graph_builder.emit('RealDiv', [dout, dividend]) | |||||
| return result | |||||
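As is conventional for SqrtGrad, the first input is assumed to be the forward output y = sqrt(t), so dout / (2 * x) is exactly the chain rule d sqrt(t)/dt = 1 / (2 * sqrt(t)). A finite-difference sanity check in NumPy (grid and eps are arbitrary):

```python
import numpy as np

t = np.linspace(0.5, 4.0, 8)
y = np.sqrt(t)                      # forward output, the 'x' of SqrtGrad
dout = np.ones_like(t)
dx = dout / (2 * y)                 # dout / (2 * x)
eps = 1e-6
num = (np.sqrt(t + eps) - np.sqrt(t - eps)) / (2 * eps)
assert np.allclose(dx, num, atol=1e-5)
```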
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -13,24 +13,13 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # =========================================================================== | # =========================================================================== | ||||
| """generate json desc for square""" | """generate json desc for square""" | ||||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||||
| from ._utils import Expander | |||||
| def expand_square(expand_info): | |||||
| class Square(Expander): | |||||
| """Square expander""" | """Square expander""" | ||||
| # get op info. | |||||
| input_desc = expand_info['input_desc'][0] | |||||
| graph_builder = builder.GraphBuilder() | |||||
| # generate a graph. | |||||
| with graph_builder.graph_scope('main') as graph_scope: | |||||
| # create tensor input. | |||||
| input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format']) | |||||
| # create op. | |||||
| result = graph_builder.emit('Mul', [input_x, input_x]) | |||||
| # set graph output. | |||||
| graph_scope.set_output(result) | |||||
| graph = graph_builder.get()[0] | |||||
| return graph | |||||
| def _expand(self, graph_builder): | |||||
| x = self.inputs[0] | |||||
| result = graph_builder.emit('Mul', [x, x]) | |||||
| return result | |||||
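Expanding Square as a self-multiply keeps the graph within the basic elementwise ops, which presumably lets it fuse with neighboring emitted ops rather than remain an opaque kernel; the equivalence is trivial:

```python
import numpy as np

x = np.random.rand(4).astype(np.float32)
assert np.allclose(x * x, np.square(x))
```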
| @@ -13,34 +13,19 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # =========================================================================== | # =========================================================================== | ||||
| """generate json desc for tanh_grad""" | """generate json desc for tanh_grad""" | ||||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||||
| ONE = 1.0 | |||||
| def expand_tanhgrad(expand_info): | |||||
| @VLD.check_all_formats_same | |||||
| class TanhGrad(Expander): | |||||
| """TanhGrad expander""" | """TanhGrad expander""" | ||||
| # get op info. | |||||
| input_desc_0 = expand_info['input_desc'][0] | |||||
| input_desc_1 = expand_info['input_desc'][1] | |||||
| graph_builder = builder.GraphBuilder() | |||||
| # generate a graph. | |||||
| with graph_builder.graph_scope('main') as graph_scope: | |||||
| # create tensor input. | |||||
| input_y = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format']) | |||||
| input_dy = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format']) | |||||
| const_one = graph_builder.value(input_y.dtype, ONE) | |||||
| graph_scope.set_input(input_y, input_dy) | |||||
| def _expand(self, graph_builder): | |||||
| input_y, input_dy = self.inputs | |||||
| # calculate the result | |||||
| const_one = graph_builder.value(input_y.dtype, 1) | |||||
| double_y = graph_builder.emit('Mul', [input_y, input_y]) | double_y = graph_builder.emit('Mul', [input_y, input_y]) | ||||
| one_sub_double_y = graph_builder.emit('Sub', [const_one, double_y]) | one_sub_double_y = graph_builder.emit('Sub', [const_one, double_y]) | ||||
| result = graph_builder.emit('Mul', [input_dy, one_sub_double_y]) | result = graph_builder.emit('Mul', [input_dy, one_sub_double_y]) | ||||
| # set graph output. | |||||
| graph_scope.set_output(result) | |||||
| graph = graph_builder.get()[0] | |||||
| return graph | |||||
| return result | |||||
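Here too the first input is the forward output y = tanh(t), and dy * (1 - y^2) is the chain rule for tanh. A NumPy finite-difference check (grid and eps arbitrary):

```python
import numpy as np

t = np.linspace(-2.0, 2.0, 9)
y = np.tanh(t)                      # forward output
dout = np.ones_like(t)
dx = dout * (1.0 - y * y)           # dy * (1 - y^2)
eps = 1e-6
num = (np.tanh(t + eps) - np.tanh(t - eps)) / (2 * eps)
assert np.allclose(dx, num, atol=1e-5)
```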
| @@ -14,25 +14,22 @@ | |||||
| # =========================================================================== | # =========================================================================== | ||||
| """generate json desc for Tile""" | """generate json desc for Tile""" | ||||
| from mindspore._extends.graph_kernel.model import model_builder as builder | from mindspore._extends.graph_kernel.model import model_builder as builder | ||||
| from mindspore._extends.graph_kernel.model.model import DataFormat as DF | |||||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||||
| def expand_tile(expand_info): | |||||
| @VLD.add_format(DF.DEFAULT) | |||||
| @VLD.check_attrs('multiples') | |||||
| class Tile(Expander): | |||||
| """Tile expander""" | """Tile expander""" | ||||
| input_desc = expand_info['input_desc'][0] | |||||
| multiples = expand_info['attr']['multiples'] | |||||
| output_shape, _, _, shape_compatible = builder.get_tile_output_shape(input_desc['shape'], multiples) | |||||
| graph_builder = builder.GraphBuilder() | |||||
| with graph_builder.graph_scope('main') as graph_scope: | |||||
| # create tensor input. | |||||
| input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format']) | |||||
| # create op. | |||||
| def _expand(self, graph_builder): | |||||
| input_x = self.inputs[0] | |||||
| multiples = self.attrs['multiples'] | |||||
| output_shape, _, _, shape_compatible = builder.get_tile_output_shape(self.inputs[0].shape, multiples) | |||||
| if shape_compatible: | if shape_compatible: | ||||
| result = graph_builder.emit('BroadcastTo', [input_x], attrs={'shape': output_shape}) | result = graph_builder.emit('BroadcastTo', [input_x], attrs={'shape': output_shape}) | ||||
| else: | else: | ||||
| result = graph_builder.emit('Tile', [input_x], attrs={'multiples': multiples}) | result = graph_builder.emit('Tile', [input_x], attrs={'multiples': multiples}) | ||||
| # set graph output. | |||||
| graph_scope.set_output(result) | |||||
| graph = graph_builder.get()[0] | |||||
| return graph | |||||
| return result | |||||
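When the requested multiples can be realized as a plain broadcast (which is what `shape_compatible` appears to signal, e.g. every tiled axis has extent 1), the expansion emits `BroadcastTo`, which backends generally handle more cheaply than a real copy loop; otherwise a genuine `Tile` is kept. A NumPy illustration of the two cases (shapes are illustrative; `get_tile_output_shape` itself is not reproduced here):

```python
import numpy as np

x = np.arange(3, dtype=np.float32).reshape(1, 3)
# multiples (4, 1): the only tiled axis has size 1, so the result equals
# a plain broadcast to the output shape (4, 3)
assert np.array_equal(np.tile(x, (4, 1)), np.broadcast_to(x, (4, 3)))
# multiples (1, 2) tile a non-unit axis; no broadcast can reproduce that
assert np.tile(x, (1, 2)).shape == (1, 6)
```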
| @@ -55,6 +55,25 @@ class DataFormat: | |||||
| NDHWC = "NDHWC" | NDHWC = "NDHWC" | ||||
| class DataType: | |||||
| """Data Type""" | |||||
| FLOAT = "float" | |||||
| FLOAT16 = "float16" | |||||
| FLOAT32 = "float32" | |||||
| FLOAT64 = "float64" | |||||
| INT = "int" | |||||
| INT8 = "int8" | |||||
| INT16 = "int16" | |||||
| INT32 = "int32" | |||||
| INT64 = "int64" | |||||
| UINT = "uint" | |||||
| UINT8 = "uint8" | |||||
| UINT16 = "uint16" | |||||
| UINT32 = "uint32" | |||||
| UINT64 = "uint64" | |||||
| BOOL = "bool" | |||||
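These strings line up with NumPy's canonical dtype names, which makes round-tripping between kernel descriptors and NumPy arrays a direct name comparison; the correspondence below is an observation, not an API guarantee:

```python
import numpy as np

# NumPy's canonical names match the descriptor strings defined above
assert np.dtype(np.float32).name == "float32"   # DataType.FLOAT32
assert np.dtype(np.int64).name == "int64"       # DataType.INT64
assert np.dtype(np.bool_).name == "bool"        # DataType.BOOL
```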
| class Config: | class Config: | ||||
| R0 = 8.0 | R0 = 8.0 | ||||
| UB_SIZE = 256 * 1024 | UB_SIZE = 256 * 1024 | ||||
| @@ -508,3 +527,9 @@ class AddControlBuddy(GraphVisitor): | |||||
| for owner in self.buddies: | for owner in self.buddies: | ||||
| for op in self.buddies[owner]: | for op in self.buddies[owner]: | ||||
| owner.add_buddy(op.output) | owner.add_buddy(op.output) | ||||
| class GraphKernelUnsupportedException(Exception): | |||||
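| """Raised when an op cannot be expanded by the graph kernel expanders.""" | |||||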
| def __init__(self, message): | |||||
| super().__init__() | |||||
| self.message = message | |||||
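A self-contained sketch of the intended control flow: unsupported ops raise, the caller catches the exception and logs at INFO level, and expansion is simply skipped rather than failed (the class is restated here for runnability; `check_supported` and its op list are purely hypothetical):

```python
class GraphKernelUnsupportedException(Exception):
    def __init__(self, message):
        super().__init__()
        self.message = message

def check_supported(op_name, supported=("Softmax", "Square", "Tile")):
    # hypothetical helper mirroring the raise sites in the expanders above
    if op_name not in supported:
        raise GraphKernelUnsupportedException(
            "Generator does not support op: {}".format(op_name))

try:
    check_supported("MyCustomOp")
except GraphKernelUnsupportedException as e:
    print(e.message)  # callers log this and skip expansion instead of failing
```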
| @@ -90,7 +90,7 @@ FuncGraphPtr GraphKernelExpander::CreateExpandFuncGraph(const CNodePtr &node) { | |||||
| } | } | ||||
| std::string kernel_desc_str = py::cast<std::string>(ret); | std::string kernel_desc_str = py::cast<std::string>(ret); | ||||
| if (kernel_desc_str.empty()) { | if (kernel_desc_str.empty()) { | ||||
| MS_LOG(ERROR) << "Jump expand node: " << node->fullname_with_scope(); | |||||
| MS_LOG(INFO) << "Jump expand node: " << node->fullname_with_scope(); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| // decode json to func_graph. | // decode json to func_graph. | ||||
| @@ -131,11 +131,8 @@ AnfNodePtr GraphKernelExpander::CreateExpandGraphKernel(const FuncGraphPtr &func | |||||
| kernel::GetFuncGraphOutputNodes(new_func_graph, &outputs); | kernel::GetFuncGraphOutputNodes(new_func_graph, &outputs); | ||||
| auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs, outputs); | auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs, outputs); | ||||
| SetNewKernelInfo(graph_kernel_node, new_func_graph, inputs, outputs, AnfAlgo::GetProcessor(node)); | SetNewKernelInfo(graph_kernel_node, new_func_graph, inputs, outputs, AnfAlgo::GetProcessor(node)); | ||||
| std::string graph_kernel_flag; | |||||
| std::for_each(kernel_nodes.begin(), kernel_nodes.end(), [&graph_kernel_flag](const AnfNodePtr &node) { | |||||
| static_cast<void>(graph_kernel_flag.append(AnfAlgo::GetCNodeName(node)).append("_")); | |||||
| }); | |||||
| MS_LOG(DEBUG) << "Expand node: " << node->fullname_with_scope() << " with: " << graph_kernel_flag; | |||||
| MS_LOG(DEBUG) << "Expand node: " << node->fullname_with_scope() | |||||
| << " with: " << graph_kernel_node->fullname_with_scope(); | |||||
| return graph_kernel_node; | return graph_kernel_node; | ||||
| } | } | ||||
| @@ -152,14 +149,12 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| MS_LOG(INFO) << "Expand process node: " << node->fullname_with_scope(); | |||||
| MS_LOG(INFO) << "Expanding node: " << node->fullname_with_scope(); | |||||
| auto new_func_graph = CreateExpandFuncGraph(node); | auto new_func_graph = CreateExpandFuncGraph(node); | ||||
| if (new_func_graph == nullptr) { | if (new_func_graph == nullptr) { | ||||
| MS_LOG(ERROR) << "Decode fused nodes failed, " << node->fullname_with_scope(); | |||||
| continue; | continue; | ||||
| } | } | ||||
| mng->AddFuncGraph(new_func_graph); | mng->AddFuncGraph(new_func_graph); | ||||
| MS_LOG(DEBUG) << "decode fused nodes success."; | |||||
| auto graph_kernel_node = CreateExpandGraphKernel(func_graph, new_func_graph, node); | auto graph_kernel_node = CreateExpandGraphKernel(func_graph, new_func_graph, node); | ||||
| new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(AnfAlgo::GetCNodeName(node))); | new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(AnfAlgo::GetCNodeName(node))); | ||||