From e0e6c39eae4273350678aef1360d91b9e0625b7d Mon Sep 17 00:00:00 2001 From: dayschan Date: Thu, 4 Feb 2021 20:35:41 +0800 Subject: [PATCH] Refactor GraphKernelExpander (1st submission) Decoupled from the front-end interfaces. 1. Removed the call to "Renormalize". Completed the infer-format in model_builder. Only used the device shape and device format to infer an abstract shape without considering padding. 2. Removed the call to python's Primitive interfaces. The "Renormalize" relies on the PrimitivePy, so they can be removed together. After that, the functions "ConstAttrToInput", "DeleteAttrInInput" and related can be removed. 3. Reuse the AkgKernelJsonGenerator in GraphKernelExpander. 1) set the attribute "extract_opinfo_from_anf" to true, so that the generator can handle the basic operator with anf info. 2) added a function "extract_expand_info" in expander.py to convert the json into a more friendly format. The attrs was converted to a dict instead of a list. 4. Scalars only support DefaultFormat. Removed the argument "format" from graph_builder.value 5. Moved the expander op list from graph_kernel_helper.cc to graph_kernel_expander.cc --- mindspore/_extends/graph_kernel/expander.py | 31 ++- .../graph_kernel/expanders/dropout_grad.py | 13 +- .../_extends/graph_kernel/expanders/gelu.py | 11 +- .../graph_kernel/expanders/gelu_grad.py | 13 +- .../graph_kernel/expanders/gkdropout.py | 17 +- .../graph_kernel/expanders/layernorm.py | 18 +- .../graph_kernel/expanders/layernorm_grad.py | 30 +- .../graph_kernel/expanders/logsoftmax.py | 8 +- .../graph_kernel/expanders/logsoftmax_grad.py | 8 +- .../graph_kernel/expanders/maximum_grad.py | 14 +- .../graph_kernel/expanders/minimum_grad.py | 13 +- .../graph_kernel/expanders/reduce_mean.py | 17 +- .../graph_kernel/expanders/softmax.py | 12 +- .../graph_kernel/expanders/sqrt_grad.py | 4 +- .../graph_kernel/expanders/tanh_grad.py | 4 +- .../_extends/graph_kernel/expanders/tile.py | 13 +- .../graph_kernel/model/model_builder.py | 31 ++- .../akg/akg_kernel_json_decoder.cc | 157 ++++++----- .../akg/akg_kernel_json_generator.cc | 28 +- .../akg/akg_kernel_json_generator.h | 1 + .../graph_kernel/add_atomic_clean_gpu.cc | 10 +- .../graph_kernel/graph_kernel_expander.cc | 134 +++------ .../graph_kernel/graph_kernel_expander.h | 5 +- .../graph_kernel/graph_kernel_helper.cc | 263 +----------------- .../graph_kernel/graph_kernel_helper.h | 2 - 25 files changed, 258 insertions(+), 599 deletions(-) diff --git a/mindspore/_extends/graph_kernel/expander.py b/mindspore/_extends/graph_kernel/expander.py index a0872e88d0..c4021c05b0 100644 --- a/mindspore/_extends/graph_kernel/expander.py +++ b/mindspore/_extends/graph_kernel/expander.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,18 +20,31 @@ from mindspore import log as logger import mindspore._extends.graph_kernel.expanders as expanders +def extract_expand_info(kernel_info): + """Convert the json into a more friendly format""" + input_desc = [] + if 'input_desc' in kernel_info and kernel_info['input_desc']: + for desc in kernel_info['input_desc']: + input_desc += desc + attrs = {} + if 'attr' in kernel_info and kernel_info['attr']: + for attr in kernel_info["attr"]: + attrs[attr["name"]] = attr["value"] + expand_info = { + "name": kernel_info["name"], + "input_desc": input_desc, + "output_desc": kernel_info["output_desc"], + "attr": attrs, + "process": kernel_info["process"], + } + return expand_info + + def get_op_expander(json_str: str): """get op expander by json info""" try: kernel_info = json.loads(json_str) - expand_info = kernel_info['expand_info'] - - if 'name' not in expand_info: - logger.error("expand info have no op name") - return None - if 'process' not in expand_info: - logger.error("expand info have no processor info") - return None + expand_info = extract_expand_info(kernel_info) processor = expand_info['process'] op_name = str(expand_info['name']).lower() diff --git a/mindspore/_extends/graph_kernel/expanders/dropout_grad.py b/mindspore/_extends/graph_kernel/expanders/dropout_grad.py index a18d2f1ff8..960ab2e2c0 100644 --- a/mindspore/_extends/graph_kernel/expanders/dropout_grad.py +++ b/mindspore/_extends/graph_kernel/expanders/dropout_grad.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,20 +21,15 @@ def expand_dropoutgrad(expand_info): # get op info. dy_desc = expand_info['input_desc'][0] mask_desc = expand_info['input_desc'][1] - keep_prob = None - for attr in expand_info['attr']: - if 'keep_prob' in attr: - keep_prob = attr['keep_prob'] - if keep_prob is None: - raise RuntimeError("keep_prob does not exist in attrs.") - # generate a graph. + keep_prob = expand_info['attr']['keep_prob'] + graph_builder = builder.GraphBuilder() with graph_builder.graph_scope('main') as graph_scope: # create tensor input. input_dy = graph_builder.tensor(dy_desc['shape'], dy_desc['data_type'], dy_desc['format']) input_mask = graph_builder.tensor(mask_desc['shape'], mask_desc['data_type'], mask_desc['format']) graph_scope.set_input(input_dy, input_mask) - r_keep_prob = graph_builder.value(input_dy.dtype, 1.0 / keep_prob, "DefaultFormat") + r_keep_prob = graph_builder.value(input_dy.dtype, 1.0 / keep_prob) # create op. result = graph_builder.emit('Mul', [input_dy, r_keep_prob]) result = graph_builder.emit('Mul', [result, input_mask]) diff --git a/mindspore/_extends/graph_kernel/expanders/gelu.py b/mindspore/_extends/graph_kernel/expanders/gelu.py index 661d0305d4..4b9760575d 100644 --- a/mindspore/_extends/graph_kernel/expanders/gelu.py +++ b/mindspore/_extends/graph_kernel/expanders/gelu.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -40,17 +40,16 @@ def expand_gelu(expand_info): # cal y mul_0 = graph_builder.emit('Mul', [input_x, input_x]) pow_0 = graph_builder.emit('Mul', [mul_0, input_x]) - const_csvalue = graph_builder.value(pow_0.dtype, CSVALUE, input_desc['format']) + const_csvalue = graph_builder.value(pow_0.dtype, CSVALUE) mul_1 = graph_builder.emit('Mul', [pow_0, const_csvalue]) tanh_res = graph_builder.emit('Add', [input_x, mul_1]) - const_csvalue_sqrt_two_div_pi = graph_builder.value( - tanh_res.dtype, CSVALUE_SQRT_TWO_DIV_PI, input_desc['format']) + const_csvalue_sqrt_two_div_pi = graph_builder.value(tanh_res.dtype, CSVALUE_SQRT_TWO_DIV_PI) y = graph_builder.emit('Mul', [tanh_res, const_csvalue_sqrt_two_div_pi]) # cal gelu(x) tanh_y = graph_builder.emit('Tanh', [y]) - const_one = graph_builder.value(tanh_y.dtype, ONE, input_desc['format']) - const_half = graph_builder.value(tanh_y.dtype, HALF, input_desc['format']) + const_one = graph_builder.value(tanh_y.dtype, ONE) + const_half = graph_builder.value(tanh_y.dtype, HALF) tanh_y_add_one = graph_builder.emit('Add', [tanh_y, const_one]) mul_x = graph_builder.emit('Mul', [input_x, tanh_y_add_one]) result = graph_builder.emit('Mul', [const_half, mul_x]) diff --git a/mindspore/_extends/graph_kernel/expanders/gelu_grad.py b/mindspore/_extends/graph_kernel/expanders/gelu_grad.py index 00372d4259..096c5ebfa4 100644 --- a/mindspore/_extends/graph_kernel/expanders/gelu_grad.py +++ b/mindspore/_extends/graph_kernel/expanders/gelu_grad.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -45,12 +45,11 @@ def expand_gelugrad(expand_info): graph_scope.set_input(input_dy, input_x, input_y) # create some const var - const_csvalue = graph_builder.value(input_dy.dtype, CSVALUE, input_desc_0['format']) - const_csvalue_sqrt_two_div_pi = graph_builder.value( - input_dy.dtype, CSVALUE_SQRT_TWO_DIV_PI, input_desc_0['format']) - const_csvalue_tri = graph_builder.value(input_dy.dtype, CSVALUE_TRI, input_desc_0['format']) - const_one = graph_builder.value(input_dy.dtype, ONE, input_desc_0['format']) - const_half = graph_builder.value(input_dy.dtype, HALF, input_desc_0['format']) + const_csvalue = graph_builder.value(input_dy.dtype, CSVALUE) + const_csvalue_sqrt_two_div_pi = graph_builder.value(input_dy.dtype, CSVALUE_SQRT_TWO_DIV_PI) + const_csvalue_tri = graph_builder.value(input_dy.dtype, CSVALUE_TRI) + const_one = graph_builder.value(input_dy.dtype, ONE) + const_half = graph_builder.value(input_dy.dtype, HALF) # cal mul_right mul_double = graph_builder.emit('Mul', [input_x, input_x]) diff --git a/mindspore/_extends/graph_kernel/expanders/gkdropout.py b/mindspore/_extends/graph_kernel/expanders/gkdropout.py index d761681c9a..e2da7b5d2b 100644 --- a/mindspore/_extends/graph_kernel/expanders/gkdropout.py +++ b/mindspore/_extends/graph_kernel/expanders/gkdropout.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,25 +21,20 @@ def expand_gkdropout(expand_info): # get op info. input_desc = expand_info['input_desc'][0] maks_desc = expand_info['input_desc'][1] - keep_prob = None - for attr in expand_info['attr']: - if 'keep_prob' in attr: - keep_prob = attr['keep_prob'] - if keep_prob is None: - raise RuntimeError("keep_prob does not exist in attrs.") - # generate a graph. + keep_prob = expand_info['attr']['keep_prob'] + graph_builder = builder.GraphBuilder() with graph_builder.graph_scope('main') as graph_scope: # create tensor input. input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format']) input_mask = graph_builder.tensor(maks_desc['shape'], maks_desc['data_type'], maks_desc['format']) graph_scope.set_input(input_x, input_mask) - keep_prob_v = graph_builder.value(input_x.dtype, keep_prob, "DefaultFormat") - r_keep_prob = graph_builder.value(input_x.dtype, 1.0 / keep_prob, "DefaultFormat") + keep_prob_v = graph_builder.value(input_x.dtype, keep_prob) + r_keep_prob = graph_builder.value(input_x.dtype, 1.0 / keep_prob) if input_mask.dtype != input_x.dtype: input_mask = graph_builder.emit('Cast', [input_mask], attrs={'dst_type': input_x.dtype}) - mask = graph_builder.emit('LessEqual', [input_mask, keep_prob_v]) # output is bool type + mask = graph_builder.emit('LessEqual', [input_mask, keep_prob_v]) # output is bool type mask = graph_builder.emit('Cast', [mask], attrs={'dst_type': input_x.dtype}) # compute result diff --git a/mindspore/_extends/graph_kernel/expanders/layernorm.py b/mindspore/_extends/graph_kernel/expanders/layernorm.py index 9089a9f06b..79cea1a7d0 100644 --- a/mindspore/_extends/graph_kernel/expanders/layernorm.py +++ b/mindspore/_extends/graph_kernel/expanders/layernorm.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,16 +23,10 @@ def expand_layernorm(expand_info): input_desc_1 = expand_info['input_desc'][1] input_desc_2 = expand_info['input_desc'][2] attrs = expand_info['attr'] - begin_norm_axis = None - epsilon = None - for item in attrs: - if 'begin_norm_axis' in item: - begin_norm_axis = item['begin_norm_axis'] - if 'epsilon' in item: - epsilon = item['epsilon'] - graph_builder = builder.GraphBuilder() + begin_norm_axis = attrs['begin_norm_axis'] + epsilon = attrs['epsilon'] - # generate a graph. + graph_builder = builder.GraphBuilder() with graph_builder.graph_scope('main') as graph_scope: # create tensor input. input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format']) @@ -52,7 +46,7 @@ def expand_layernorm(expand_info): for i in reduce_axis: reduce_elts *= shape_x[i] mean_cof = 1.0 / reduce_elts - mean_cof_v = graph_builder.value(input_x.dtype, mean_cof, input_x.data_format) + mean_cof_v = graph_builder.value(input_x.dtype, mean_cof) # Calculate mean mean_red = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': True}) @@ -67,7 +61,7 @@ def expand_layernorm(expand_info): # Calculate normalize normalize_sub = graph_builder.emit('Sub', [input_x, mean]) - epsilon_v = graph_builder.value(input_x.dtype, epsilon, input_x.data_format) + epsilon_v = graph_builder.value(input_x.dtype, epsilon) normalize_add = graph_builder.emit('Add', [variance, epsilon_v]) normlize_rsqrt = graph_builder.emit('Rsqrt', [normalize_add]) normalize_mul = graph_builder.emit('Mul', [normalize_sub, normlize_rsqrt]) diff --git a/mindspore/_extends/graph_kernel/expanders/layernorm_grad.py b/mindspore/_extends/graph_kernel/expanders/layernorm_grad.py index 35eca7e7d8..2a2dbb3f0f 100644 --- a/mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +++ b/mindspore/_extends/graph_kernel/expanders/layernorm_grad.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,16 +24,10 @@ def expand_layernormgrad(expand_info): var_desc = expand_info['input_desc'][2] mean_desc = expand_info['input_desc'][3] gamma_desc = expand_info['input_desc'][4] - begin_norm_axis = None - begin_params_axis = None - epsilon = 1e-11 - for item in expand_info['attr']: - if 'begin_norm_axis' in item: - begin_norm_axis = item['begin_norm_axis'] - if 'begin_params_axis' in item: - begin_params_axis = item['begin_params_axis'] - if 'epsilon' in item: - epsilon = item['epsilon'] + attrs = expand_info['attr'] + begin_norm_axis = attrs['begin_norm_axis'] + begin_params_axis = attrs['begin_params_axis'] + epsilon = attrs['epsilon'] if 'epsilon' in attrs else 1e-11 shape_x = x_desc['shape'] if begin_norm_axis < 0: @@ -57,13 +51,13 @@ def expand_layernormgrad(expand_info): graph_scope.set_input(x, dy, variance, mean, gamma) # set some constant val. - eps = graph_builder.value(x.dtype, epsilon, x.data_format) - const_one = graph_builder.value(x.dtype, 1.0, x.data_format) - const_neg_half = graph_builder.value(x.dtype, -0.5, x.data_format) - const_neg_two = graph_builder.value(x.dtype, -2.0, x.data_format) - const_two = graph_builder.value(x.dtype, 2.0, x.data_format) - const_neg_one = graph_builder.value(x.dtype, -1.0, x.data_format) - mean_cof = graph_builder.value(x.dtype, (1.0 / reduce_size), x.data_format) + eps = graph_builder.value(x.dtype, epsilon) + const_one = graph_builder.value(x.dtype, 1.0) + const_neg_half = graph_builder.value(x.dtype, -0.5) + const_neg_two = graph_builder.value(x.dtype, -2.0) + const_two = graph_builder.value(x.dtype, 2.0) + const_neg_one = graph_builder.value(x.dtype, -1.0) + mean_cof = graph_builder.value(x.dtype, (1.0 / reduce_size)) # cal dg db var_eps = graph_builder.emit('Add', [variance, eps]) diff --git a/mindspore/_extends/graph_kernel/expanders/logsoftmax.py b/mindspore/_extends/graph_kernel/expanders/logsoftmax.py index 7bfffd0ef8..57e57ca463 100644 --- a/mindspore/_extends/graph_kernel/expanders/logsoftmax.py +++ b/mindspore/_extends/graph_kernel/expanders/logsoftmax.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,11 +20,7 @@ def expand_logsoftmax(expand_info): """LogSoftmax expander""" # get op info. input_desc = expand_info['input_desc'][0] - attrs = expand_info['attr'] - axis = None - for item in attrs: - if 'axis' in item: - axis = item['axis'] + axis = expand_info['attr']['axis'] graph_builder = builder.GraphBuilder() if isinstance(axis, int): axis = (axis,) diff --git a/mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py b/mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py index 6a4d925f82..350e39c431 100644 --- a/mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +++ b/mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,11 +21,7 @@ def expand_logsoftmaxgrad(expand_info): # get op info. input_desc_0 = expand_info['input_desc'][0] input_desc_1 = expand_info['input_desc'][1] - attrs = expand_info['attr'] - axis = None - for item in attrs: - if 'axis' in item: - axis = item['axis'] + axis = expand_info['attr']['axis'] graph_builder = builder.GraphBuilder() if isinstance(axis, int): diff --git a/mindspore/_extends/graph_kernel/expanders/maximum_grad.py b/mindspore/_extends/graph_kernel/expanders/maximum_grad.py index 1625c5976c..e98c2100e2 100644 --- a/mindspore/_extends/graph_kernel/expanders/maximum_grad.py +++ b/mindspore/_extends/graph_kernel/expanders/maximum_grad.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,16 +23,10 @@ def expand_maximumgrad(expand_info): input_desc_1 = expand_info['input_desc'][1] input_desc_2 = expand_info['input_desc'][2] attrs = expand_info['attr'] - grad_x = None - grad_y = None - for item in attrs: - if 'grad_x' in item: - grad_x = item['grad_x'] - if 'grad_y' in item: - grad_y = item['grad_y'] - graph_builder = builder.GraphBuilder() + grad_x = attrs['grad_x'] if 'grad_x' in attrs else True + grad_y = attrs['grad_y'] if 'grad_y' in attrs else True - # generate a graph. + graph_builder = builder.GraphBuilder() with graph_builder.graph_scope('main') as graph_scope: # create tensor input. input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format']) diff --git a/mindspore/_extends/graph_kernel/expanders/minimum_grad.py b/mindspore/_extends/graph_kernel/expanders/minimum_grad.py index 365ca478f5..8b77aa03c3 100644 --- a/mindspore/_extends/graph_kernel/expanders/minimum_grad.py +++ b/mindspore/_extends/graph_kernel/expanders/minimum_grad.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,15 +23,10 @@ def expand_minimumgrad(expand_info): input_desc_1 = expand_info['input_desc'][1] input_desc_2 = expand_info['input_desc'][2] attrs = expand_info['attr'] - grad_x = None - grad_y = None - for item in attrs: - if 'grad_x' in item: - grad_x = item['grad_x'] - if 'grad_y' in item: - grad_y = item['grad_y'] + grad_x = attrs['grad_x'] if 'grad_x' in attrs else True + grad_y = attrs['grad_y'] if 'grad_y' in attrs else True + graph_builder = builder.GraphBuilder() - # generate a graph. with graph_builder.graph_scope('main') as graph_scope: # create tensor input. input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format']) diff --git a/mindspore/_extends/graph_kernel/expanders/reduce_mean.py b/mindspore/_extends/graph_kernel/expanders/reduce_mean.py index 3e9fd6b5e3..14aa6577df 100644 --- a/mindspore/_extends/graph_kernel/expanders/reduce_mean.py +++ b/mindspore/_extends/graph_kernel/expanders/reduce_mean.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,20 +18,13 @@ from mindspore._extends.graph_kernel.model import model_builder as builder def expand_reducemean(expand_info): """ReduceMean expander""" - # get op info. input_desc = expand_info['input_desc'][0] attrs = expand_info['attr'] - axis = None - keep_dims = None - for item in attrs: - if 'axis' in item: - axis = item['axis'] - if 'keep_dims' in item: - keep_dims = item['keep_dims'] - graph_builder = builder.GraphBuilder() + axis = attrs['axis'] + keep_dims = attrs['keep_dims'] - # generate a graph. + graph_builder = builder.GraphBuilder() with graph_builder.graph_scope('main') as graph_scope: # create tensor input. input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format']) @@ -49,7 +42,7 @@ def expand_reducemean(expand_info): for idx in axis: all_shape *= x_shape[idx] - all_shape_value = graph_builder.value(input_x.dtype, all_shape, input_x.data_format) + all_shape_value = graph_builder.value(input_x.dtype, all_shape) if not axis: sum_x = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': real_axis, 'keep_dims': keep_dims}) diff --git a/mindspore/_extends/graph_kernel/expanders/softmax.py b/mindspore/_extends/graph_kernel/expanders/softmax.py index 58a6a2eed6..3696ff8c02 100644 --- a/mindspore/_extends/graph_kernel/expanders/softmax.py +++ b/mindspore/_extends/graph_kernel/expanders/softmax.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,16 +18,10 @@ from mindspore._extends.graph_kernel.model import model_builder as builder def expand_softmax(expand_info): """Softmax expander""" - # get op info. input_desc = expand_info['input_desc'][0] - attrs = expand_info['attr'] - axis = None - for item in attrs: - if 'axis' in item: - axis = item['axis'] - graph_builder = builder.GraphBuilder() + axis = expand_info['attr']['axis'] - # generate a graph. + graph_builder = builder.GraphBuilder() with graph_builder.graph_scope('main') as graph_scope: # create tensor input. input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format']) diff --git a/mindspore/_extends/graph_kernel/expanders/sqrt_grad.py b/mindspore/_extends/graph_kernel/expanders/sqrt_grad.py index 31ef05f494..d2ec123fdc 100644 --- a/mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +++ b/mindspore/_extends/graph_kernel/expanders/sqrt_grad.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -34,7 +34,7 @@ def expand_sqrtgrad(expand_info): graph_scope.set_input(input_x, input_dout) # cal result - const_two = graph_builder.value(input_x.dtype, 2, input_x.data_format) + const_two = graph_builder.value(input_x.dtype, 2) dividend = graph_builder.emit('Mul', [input_x, const_two]) result = graph_builder.emit('RealDiv', [input_dout, dividend]) diff --git a/mindspore/_extends/graph_kernel/expanders/tanh_grad.py b/mindspore/_extends/graph_kernel/expanders/tanh_grad.py index 263c6bd767..4e6fc8c326 100644 --- a/mindspore/_extends/graph_kernel/expanders/tanh_grad.py +++ b/mindspore/_extends/graph_kernel/expanders/tanh_grad.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,7 +31,7 @@ def expand_tanhgrad(expand_info): # create tensor input. input_y = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format']) input_dy = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format']) - const_one = graph_builder.value(input_y.dtype, ONE, input_y.data_format) + const_one = graph_builder.value(input_y.dtype, ONE) graph_scope.set_input(input_y, input_dy) # cal result diff --git a/mindspore/_extends/graph_kernel/expanders/tile.py b/mindspore/_extends/graph_kernel/expanders/tile.py index 258f25246d..3587bfd7f3 100644 --- a/mindspore/_extends/graph_kernel/expanders/tile.py +++ b/mindspore/_extends/graph_kernel/expanders/tile.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,18 +18,11 @@ from mindspore._extends.graph_kernel.model import model_builder as builder def expand_tile(expand_info): """Tile expander""" - - # get op info. input_desc = expand_info['input_desc'][0] - attrs = expand_info['attr'] - multiples = None - for item in attrs: - if 'multiples' in item: - multiples = item['multiples'] + multiples = expand_info['attr']['multiples'] output_shape, _, _, shape_compatible = builder.get_tile_output_shape(input_desc['shape'], multiples) - graph_builder = builder.GraphBuilder() - # generate a graph. + graph_builder = builder.GraphBuilder() with graph_builder.graph_scope('main') as graph_scope: # create tensor input. input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format']) diff --git a/mindspore/_extends/graph_kernel/model/model_builder.py b/mindspore/_extends/graph_kernel/model/model_builder.py index 0b8c283bbc..cdcd4dc92e 100644 --- a/mindspore/_extends/graph_kernel/model/model_builder.py +++ b/mindspore/_extends/graph_kernel/model/model_builder.py @@ -15,7 +15,7 @@ """GraphKernel model builder""" import copy -from .model import PrimLib, Tensor, Value, Operator, Graph, AlignShape, AddControlBuddy +from .model import PrimLib, Tensor, Value, Operator, Graph, AlignShape, AddControlBuddy, DataFormat def get_tile_output_shape(shape, multiples): @@ -70,7 +70,7 @@ class OpInfer: real_shape = [] for i, _ in enumerate(shape): - if i not in attrs['reduce_axis']: + if i not in attrs['reduce_axis'] and i - len(shape) not in attrs['reduce_axis']: real_shape.append(shape[i]) return real_shape @@ -106,7 +106,15 @@ class OpInfer: @staticmethod def default_infer_format_func(inputs, attrs): """Infer format""" - return inputs[0].data_format + result = inputs[0].data_format + # default_format and other_format results in other_format + for input_tensor in inputs[1:]: + data_format = input_tensor.data_format + if data_format != DataFormat.DEFAULT: + if result not in [DataFormat.DEFAULT, data_format]: + raise RuntimeError("Incompatible data format %s and %s" % (data_format, result)) + result = data_format + return result infer_shape_func = { # add special infer func here @@ -114,13 +122,20 @@ class OpInfer: 'Reshape': lambda inputs, attrs: attrs["shape"], 'BroadcastTo': lambda inputs, attrs: attrs["shape"], 'Tile': lambda inputs, attrs: get_tile_output_shape(inputs[0].shape, attrs["multiples"])[0], + 'ExpandDims': lambda inputs, attrs: list(inputs[0].shape).insert(attrs["axis"], 1), } infer_dtype_func = { # add special infer func here 'Cast': lambda inputs, attrs: attrs['dst_type'], + 'Less': lambda inputs, attrs: "bool", + 'LessEqual': lambda inputs, attrs: "bool", + 'Equal': lambda inputs, attrs: "bool", + 'Greater': lambda inputs, attrs: "bool", + 'GreaterEqual': lambda inputs, attrs: "bool", } infer_format_func = { # add special infer func here + 'Reshape': lambda inputs, attrs: "DefaultFormat", } @classmethod @@ -188,18 +203,12 @@ class GraphBuilder: shape = [1] return Tensor(name, shape, dtype, data_format, para_type=para_type) - def value(self, dtype, value, data_format, name=None): + def value(self, dtype, value, name=None): """Create a new Value""" if name in (None, ''): name = self._alloc_tensor_name() - if dtype == "float16": - # For float16 value, it will be changed to float32 wrongly. And there is no good solution for now. - # So instead just declare float32 value and then cast it to float16. - v_fp32 = Value(name, "float32", value, data_format) - v = self.emit("Cast", [v_fp32], attrs={"dst_type": "float16"}) - else: - v = Value(name, dtype, value, data_format) + v = Value(name, dtype, value) return v def op(self, prim, output, inputs, attrs=None): diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc index 7562ed6745..f470ca8a1a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc @@ -19,8 +19,7 @@ #include #include #include -#include -#include +#include #include #include "backend/kernel_compiler/akg/akg_kernel_json_generator.h" #include "backend/kernel_compiler/common_utils.h" @@ -46,6 +45,62 @@ namespace { constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput"; constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList"; +class AbstractShapeCreator { + public: + using AbstractShapeTransferFunc = std::function; + /** + * Get an abstract shape. + * For a given device_shape and format, the available abstract_shape is not unique, + * this interface only returns a legal abstract_shape without considering padding + * so that the AnfAlgo's get device shape interface can get the right device_shape. + */ + static ShapeVector GetFakeAbstractShape(const ShapeVector &device_shape, const std::string &format) { + const std::map fmap{ + {kOpFormat_NCHW, NchwAbstractShape}, + {kOpFormat_NHWC, NhwcAbstractShape}, + {kOpFormat_FRAC_NZ, FractalNzAbstractShape}, + }; + if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) { + return device_shape; + } + auto iter = fmap.find(format); + if (iter == fmap.end()) { + MS_LOG(WARNING) << "Unexpected format[" << format << "]"; + return device_shape; + } + return iter->second(device_shape); + } + + private: + static ShapeVector NchwAbstractShape(const ShapeVector &device_shape) { return device_shape; } + static ShapeVector NhwcAbstractShape(const ShapeVector &device_shape) { + if (device_shape.size() != 4) { + MS_LOG(EXCEPTION) << "Shape size of NHWC should be 4, but got " << device_shape.size(); + } + return {device_shape[0], device_shape[3], device_shape[1], device_shape[2]}; + } + static ShapeVector FractalNzAbstractShape(const ShapeVector &device_shape) { + if (device_shape.size() == 1 && (device_shape[0] == 1 || device_shape[0] % kCubeSize == 0)) { + return device_shape; + } + if (device_shape.size() < 4) { + MS_LOG(EXCEPTION) << "Shape size of FRACTAL_NZ should >= 4, but got " << device_shape.size(); + } + ShapeVector shape; + size_t dims = device_shape.size(); + size_t batch = dims - 4; + for (size_t i = 0; i < batch; ++i) { + shape.push_back(device_shape[i]); + } + int64_t m = device_shape[dims - 3] * device_shape[dims - 2]; + int64_t n = device_shape[dims - 4] * device_shape[dims - 1]; + shape.push_back(m); + shape.push_back(n); + + return shape; + } +}; + class CNodeDecoder { public: explicit CNodeDecoder(std::map *nodes_map) : nodes_map_(*nodes_map) {} @@ -66,6 +121,7 @@ class CNodeDecoder { return nullptr; } CreateKernelInfo(processor); + CreateAbstract(); return cnode_; } @@ -117,12 +173,8 @@ class CNodeDecoder { bool DecodeInputDesc(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph) { std::string op_name = cnode_json[kJsonKeyName]; - // new primitive. - auto primitive = GetPrimitive(op_name); - if (primitive == nullptr) { - MS_LOG(ERROR) << "Create primitive failed."; - return false; - } + auto primitive = CreatePrimitiveWithAttrs(op_name); + MS_EXCEPTION_IF_NULL(primitive); // collect inputs. auto primitive_v = NewValueNode(primitive); @@ -142,6 +194,7 @@ class CNodeDecoder { } input_formats_.push_back(input_desc[kJsonKeyFormat]); input_types_.push_back(DtypeToTypeId(input_desc[kJsonKeyDataType])); + input_shapes_.push_back(input_desc[kJsonKeyShape]); } // new cnode. cnode_ = func_graph->NewCNode(inputs); @@ -160,6 +213,7 @@ class CNodeDecoder { nlohmann::json output_desc = output_descs[0]; output_formats_.push_back(output_desc[kJsonKeyFormat]); output_types_.push_back(DtypeToTypeId(output_desc[kJsonKeyDataType])); + output_shapes_.push_back(output_desc[kJsonKeyShape]); nodes_map_[output_desc[kJsonKeyTensorName]] = cnode_; } else { // multi outputs. @@ -167,6 +221,7 @@ class CNodeDecoder { nlohmann::json output_desc = output_descs[j]; output_formats_.push_back(output_desc[kJsonKeyFormat]); output_types_.push_back(DtypeToTypeId(output_desc[kJsonKeyDataType])); + output_shapes_.push_back(output_desc[kJsonKeyShape]); auto get_item = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode_, NewValueNode(SizeToLong(j))}); func_graph->AddNode(get_item); @@ -219,72 +274,29 @@ class CNodeDecoder { AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), cnode_.get()); } - ValuePtr CreatOpInstance(const std::string &op_name, const std::vector &attrs) { - // python utils. - constexpr auto kGetPythonOpFunc = "_get_python_op"; - constexpr auto kParallelUtilsModule = "mindspore.parallel._utils"; - // almost all ops are defined in this path. - constexpr auto kOperationsModule = "mindspore.ops.operations"; - py::module mod = py::module::import(kOperationsModule); - if (!py::hasattr(mod, op_name.c_str())) { - MS_LOG(ERROR) << kOperationsModule << " don't have attr: " << op_name; - return nullptr; - } - std::vector arg_list; - (void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list), - [](const ValuePtr &attr) { return ValuePtrToPyData(attr); }); - py::object obj = parse::python_adapter::CallPyFn(kParallelUtilsModule, kGetPythonOpFunc, op_name, kOperationsModule, - op_name, arg_list); - ValuePtr op_instance = nullptr; - bool succ = parse::ConvertData(obj, &op_instance); - if (!succ) { - MS_LOG(ERROR) << "Get python op " << op_name << " from " << kOperationsModule << " failed."; - return nullptr; - } - return op_instance; + void CreateAbstract() { + auto shape = AbstractShapeCreator::GetFakeAbstractShape(output_shapes_[0], output_formats_[0]); + auto abstract = std::make_shared(TypeIdToType(output_types_[0]), shape); + cnode_->set_abstract(abstract); } - const std::map> op_attrs_map_ = { - {kReduceSumOpName, std::vector{kAttrKeepDims}}, - {kReduceMaxOpName, std::vector{kAttrKeepDims}}, - {kReduceMinOpName, std::vector{kAttrKeepDims}}, - {kBroadcastToOpName, std::vector{kAttrShape}}, - }; - - PrimitivePtr GetPrimitive(const std::string &op_name) { - PrimitivePtr primitive{nullptr}; - if (op_attrs_map_.count(op_name) == 0) { - // no attrs for op instance. - primitive = CreatOpInstance(op_name, std::vector{})->cast(); - } else { - // make attrs for op instance. - std::vector op_attrs; - const auto &attr_names = op_attrs_map_.at(op_name); - for (const auto &attr_name : attr_names) { - if (cnode_attrs_.count(attr_name) == 0) { - MS_LOG(ERROR) << "Attr: " << attr_name << " for: " << op_name << " not found."; - return nullptr; - } - op_attrs.push_back(cnode_attrs_.at(attr_name)); - } - primitive = CreatOpInstance(op_name, op_attrs)->cast(); - } - if (primitive != nullptr) { - for (const auto &attr : cnode_attrs_) { - primitive->AddAttr(attr.first, attr.second); - } + PrimitivePtr CreatePrimitiveWithAttrs(const std::string &op_name) { + auto primitive = std::make_shared(op_name); + for (const auto &attr : cnode_attrs_) { + primitive->AddAttr(attr.first, attr.second); } return primitive; } - ScalarPtr DecodeScalar(const nlohmann::json &scalar_json) { + tensor::TensorPtr DecodeScalar(const nlohmann::json &scalar_json) { auto type_id = DtypeToTypeId(scalar_json[kJsonKeyDataType]); switch (type_id) { case kNumberTypeFloat16: + return std::make_shared(static_cast(scalar_json[kJsonKeyValue]), kFloat16); case kNumberTypeFloat32: - return std::make_shared(scalar_json[kJsonKeyValue]); + return std::make_shared(static_cast(scalar_json[kJsonKeyValue]), kFloat32); case kNumberTypeInt32: - return std::make_shared(scalar_json[kJsonKeyValue]); + return std::make_shared(static_cast(scalar_json[kJsonKeyValue]), kInt32); default: MS_LOG(ERROR) << "Unknown type: " << scalar_json[kJsonKeyDataType]; break; @@ -294,9 +306,8 @@ class CNodeDecoder { ValueNodePtr DecodeValueNode(const nlohmann::json &value_json, const FuncGraphPtr &func_graph) { MS_LOG(DEBUG) << "start decode value node, " << value_json; - auto scalar = DecodeScalar(value_json); - auto tensor = ScalarToTensor(scalar); - + auto tensor = DecodeScalar(value_json); + MS_EXCEPTION_IF_NULL(tensor); auto value_node = std::make_shared(tensor); value_node->set_abstract(tensor->ToAbstract()); // create kernel_info fo new value node. @@ -319,6 +330,8 @@ class CNodeDecoder { std::vector output_formats_; std::vector input_types_; std::vector output_types_; + std::vector input_shapes_; + std::vector output_shapes_; CNodePtr cnode_{nullptr}; }; } // namespace @@ -329,11 +342,16 @@ ParameterPtr AkgKernelJsonDecoder::DecodeParameter(const nlohmann::json ¶met ParameterPtr new_parameter = func_graph->add_parameter(); std::string name = parameter_json[kJsonKeyTensorName]; new_parameter->set_name(name); + std::string format = parameter_json[kJsonKeyFormat]; + TypeId dtype = DtypeToTypeId(parameter_json[kJsonKeyDataType]); + ShapeVector shape = AbstractShapeCreator::GetFakeAbstractShape(parameter_json[kJsonKeyShape], format); + auto abstract = std::make_shared(TypeIdToType(dtype), shape); + new_parameter->set_abstract(abstract); auto kernel_info = std::make_shared(); new_parameter->set_kernel_info(kernel_info); auto builder = std::make_shared(); - builder->SetOutputsFormat(std::vector{parameter_json[kJsonKeyFormat]}); - builder->SetOutputsDeviceType(std::vector{DtypeToTypeId(parameter_json[kJsonKeyDataType])}); + builder->SetOutputsFormat(std::vector{format}); + builder->SetOutputsDeviceType(std::vector{dtype}); AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), new_parameter.get()); nodes_map_[name] = new_parameter; return new_parameter; @@ -349,6 +367,7 @@ CNodePtr AkgKernelJsonDecoder::DecodeCNode(const nlohmann::json &cnode_json, con AnfNodePtr AkgKernelJsonDecoder::DecodeOutput(const std::vector &output_descs, const FuncGraphPtr &func_graph) { std::vector outputs{NewValueNode(prim::kPrimMakeTuple)}; + AbstractBasePtrList output_abstract_list; for (const auto &output_desc : output_descs) { std::string name = output_desc[kJsonKeyTensorName]; if (nodes_map_.count(name) == 0) { @@ -356,11 +375,13 @@ AnfNodePtr AkgKernelJsonDecoder::DecodeOutput(const std::vector return nullptr; } outputs.push_back(nodes_map_[name]); + output_abstract_list.push_back(outputs.back()->abstract()); } if (outputs.size() == 2) { func_graph->set_output(outputs[1]); } else { auto output = func_graph->NewCNode(outputs); + output->set_abstract(std::make_shared(output_abstract_list)); func_graph->AddNode(output); func_graph->set_output(output); } diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc index 65f0c7d7f0..619802248a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc @@ -80,10 +80,12 @@ class OpInfoExtractor { } void ExtractOutputs(const OpInfoPtr &op_info) { - // only support single output in op desc. - auto io_info = std::make_shared(); - io_info->set_name("output"); - op_info->add_outputs_ptr(io_info); + size_t output_tensor_num = AnfAlgo::GetOutputTensorNum(cnode_); + for (size_t i = 0; i < output_tensor_num; i++) { + auto io_info = std::make_shared(); + io_info->set_name("output_" + std::to_string(i)); + op_info->add_outputs_ptr(io_info); + } } bool ExcludeAttr(const std::string &name) { @@ -204,8 +206,7 @@ bool AkgKernelJsonGenerator::CreateInputDescJson(const AnfNodePtr &anf_node, con input_desc_json[kJsonKeyName] = input_ptr->name(); input_desc_json[kJsonKeyTensorName] = "input_" + std::to_string(GetInputTensorIdxInc(anf_node, real_input_index)); auto input_shape = this->GetInputShape(anf_node, real_input_index); - if (dump_option_.extract_opinfo_from_anfnode && - GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) { + if (!is_basic_op_ && GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) { MS_LOG(DEBUG) << "Take input[" << real_input_index << "] of [" << anf_node->DebugString(2) << "] as const tensor, shape: [" << Vector2Str(input_shape) << "], value: " << input_desc_json[kJsonKeyValue]; @@ -529,9 +530,9 @@ bool AkgKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node, nlohmann::j MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(kernel_json); std::string op_name = AnfAlgo::GetCNodeName(anf_node); - MS_LOG(INFO) << "Akg start generate kernel json desc, full scope name is : " << anf_node->fullname_with_scope(); + MS_LOG(DEBUG) << "Akg start generate kernel json desc, full scope name is : " << anf_node->fullname_with_scope(); SetAkgKernelAttrs(anf_node); - dump_option_.extract_opinfo_from_anfnode = false; + is_basic_op_ = true; if (!GenerateSingleKernelJson(anf_node, kernel_json)) { MS_LOG(ERROR) << "Op[" << anf_node->fullname_with_scope() << "] create single kernel json failed."; return false; @@ -551,8 +552,8 @@ bool AkgKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node, nlohmann::j return false; } - MS_LOG(INFO) << "Akg create kernel json desc success, full scope name is : " << anf_node->fullname_with_scope() - << ", json info name is : " << kernel_name_; + MS_LOG(DEBUG) << "Akg create kernel json desc success, full scope name is : " << anf_node->fullname_with_scope() + << ", json info name is : " << kernel_name_; return true; } @@ -613,10 +614,11 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector &anf << "]."; return false; } - MS_LOG(INFO) << "Fusion nodes: [" << output_list.size() << "], input_list: [" << anf_nodes.size() - << "], output_list: [" << input_list.size() << "]."; + MS_LOG(DEBUG) << "Fusion nodes: [" << output_list.size() << "], input_list: [" << anf_nodes.size() + << "], output_list: [" << input_list.size() << "]."; std::map node_json_map; - dump_option_.extract_opinfo_from_anfnode = true; + is_basic_op_ = false; + dump_option_.extract_opinfo_from_anfnode = true; // always extract from anfnode for composite ops. if (!GenSingleJsons(anf_nodes, &node_json_map)) return false; UpdateTensorName(anf_nodes, &node_json_map); diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h index 0efbc84393..148a54ad9f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h @@ -144,6 +144,7 @@ class AkgKernelJsonGenerator { std::map address_node_map_; std::map> sub_graphs_; std::map dim_infos_; + bool is_basic_op_{false}; }; } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean_gpu.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean_gpu.cc index 3e4fcc7c52..d81a4120e1 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean_gpu.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean_gpu.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -393,6 +393,14 @@ CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphP broadcast_input_node}; auto broadcast_to_node_inner = CreateCNode( atomic_clean_inputs, new_sub_graph, {.format = format, .shape = dst_shape_vec, .type = GetType(atomic_add_node_)}); + + auto device_shape = AnfAlgo::GetOutputDeviceShape(atomic_add_node_, 0); + dst_shape_vec.clear(); + if (device_shape.empty()) { + dst_shape_vec.push_back(1); + } else { + std::transform(device_shape.begin(), device_shape.end(), std::back_inserter(dst_shape_vec), SizeToLong); + } SetNodeAttrSafely("shape", MakeValue(dst_shape_vec), broadcast_to_node_inner); // Makeup sub-graph. diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc index a899e1749e..26e1237dd9 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,88 +36,40 @@ namespace mindspore { namespace opt { namespace { -constexpr auto kJsonKeyExpandInfo = "expand_info"; - -#define GET_VALUE_FOR_JSON(JSON, VALUE, VALUE_ELEM, TYPE_NAME, TYPE) \ - if (VALUE_ELEM->isa()) { \ - JSON = GetValue(VALUE); \ - } - -nlohmann::json ExpandAttrJsonInfo(const CNodePtr &cnode) { - nlohmann::json attrs_json; - if (auto prim = GetCNodePrimitive(cnode); prim != nullptr) { - auto attrs = prim->attrs(); - for (const auto &[k, v] : attrs) { - nlohmann::json attr_json; - MS_LOG(DEBUG) << "attr key is : " << k << " and value type is : " << v->type_name(); - GET_VALUE_FOR_JSON(attr_json[k], v, v, Int32Imm, int); - GET_VALUE_FOR_JSON(attr_json[k], v, v, Int64Imm, int64_t); - GET_VALUE_FOR_JSON(attr_json[k], v, v, UInt32Imm, uint32_t); - GET_VALUE_FOR_JSON(attr_json[k], v, v, UInt64Imm, uint64_t); - GET_VALUE_FOR_JSON(attr_json[k], v, v, FP32Imm, float); - GET_VALUE_FOR_JSON(attr_json[k], v, v, FP64Imm, double); - GET_VALUE_FOR_JSON(attr_json[k], v, v, BoolImm, bool); - GET_VALUE_FOR_JSON(attr_json[k], v, v, StringImm, std::string); - - if (v->isa() || v->isa()) { - auto vec = v->isa() ? v->cast()->value() : v->cast()->value(); - if (!vec.empty()) { - MS_LOG(DEBUG) << "value type is : " << vec[0]->type_name(); - GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], Int32Imm, std::vector); - GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], Int64Imm, std::vector); - GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], UInt32Imm, std::vector); - GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], UInt64Imm, std::vector); - GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], FP32Imm, std::vector); - GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], FP64Imm, std::vector); - GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], StringImm, std::vector); - } - } - if (!attr_json.empty()) { - attrs_json.push_back(attr_json); - } - } - } - return attrs_json; +std::unordered_set GetExpandOps() { + std::unordered_set expand_ops = { + prim::kPrimSquare, + prim::kPrimGeLUGrad, +#if ENABLE_D + prim::kPrimTile, + prim::kPrimSqrtGrad, + prim::kPrimClipByNormNoDivSum, +#elif ENABLE_GPU + prim::kPrimBiasAdd, + prim::kPrimBiasAddGrad, + prim::kPrimGeLU, + prim::kPrimFusedAdam, + prim::kPrimFusedAdamWeightDecay, + prim::kPrimReduceMean, + prim::kPrimMaximumGrad, + prim::kPrimMinimumGrad, + prim::kPrimGkDropout, + prim::kPrimDropoutGrad, + prim::kPrimSoftmax, + prim::kPrimLayerNorm, + prim::kPrimLayerNormGrad, +#endif + }; + return expand_ops; } +} // namespace -bool ExpandJsonInfo(const CNodePtr &cnode, nlohmann::json *kernel_json) { - MS_EXCEPTION_IF_NULL(kernel_json); - if (kernel_json->find(kJsonKeyExpandInfo) != kernel_json->end()) { - return false; - } - - nlohmann::json expand_info; - expand_info[kernel::kJsonKeyAttr] = ExpandAttrJsonInfo(cnode); - expand_info[kernel::kJsonKeyName] = AnfAlgo::GetCNodeName(cnode); - expand_info[kernel::kJsonKeyProcess] = kernel::GetProcessorStr(cnode); - std::vector inputs_info; - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); ++i) { - nlohmann::json input_info; - input_info[kernel::kJsonKeyFormat] = AnfAlgo::GetInputFormat(cnode, i); - input_info[kernel::kJsonKeyInferShape] = AnfAlgo::GetPrevNodeOutputInferShape(cnode, i); - input_info[kernel::kJsonKeyShape] = AnfAlgo::GetInputDeviceShape(cnode, i); - input_info[kernel::kJsonKeyInferDataType] = - kernel::TypeId2String(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, i)); - input_info[kernel::kJsonKeyDataType] = kernel::TypeId2String(AnfAlgo::GetInputDeviceDataType(cnode, i)); - inputs_info.push_back(input_info); - } - expand_info[kernel::kJsonKeyInputDesc] = inputs_info; - - std::vector outputs_info; - for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(cnode); ++i) { - nlohmann::json output_info; - output_info[kernel::kJsonKeyFormat] = AnfAlgo::GetOutputFormat(cnode, i); - output_info[kernel::kJsonKeyInferShape] = AnfAlgo::GetOutputInferShape(cnode, i); - output_info[kernel::kJsonKeyShape] = AnfAlgo::GetOutputDeviceShape(cnode, i); - output_info[kernel::kJsonKeyInferDataType] = kernel::TypeId2String(AnfAlgo::GetOutputInferDataType(cnode, i)); - output_info[kernel::kJsonKeyDataType] = kernel::TypeId2String(AnfAlgo::GetOutputDeviceDataType(cnode, i)); - outputs_info.push_back(output_info); - } - expand_info[kernel::kJsonKeyOutputDesc] = outputs_info; - (*kernel_json)[kJsonKeyExpandInfo] = expand_info; - return true; +bool GraphKernelExpander::ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json) { + DumpOption dump_option; + dump_option.extract_opinfo_from_anfnode = true; + kernel::AkgKernelJsonGenerator json_generator(dump_option); + return json_generator.CollectJson(node, kernel_json); } -} // namespace FuncGraphPtr GraphKernelExpander::CreateExpandFuncGraph(const CNodePtr &node) { nlohmann::json kernel_json; @@ -213,33 +165,11 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) { // replace origin node. (void)mng->Replace(node, graph_kernel_node); - - ToPrimitive(AnfAlgo::GetCNodeFuncGraphPtr(graph_kernel_node)); changed = true; } return changed; } -void GraphKernelExpander::ToPrimitive(const FuncGraphPtr &func_graph) const { - auto todos = TopoSort(func_graph->get_return()); - std::reverse(todos.begin(), todos.end()); - auto mng = func_graph->manager(); - MS_EXCEPTION_IF_NULL(mng); - for (const auto &n : todos) { - auto cnode = n->cast(); - if (cnode == nullptr) { - continue; - } - - auto origin_prim = AnfAlgo::GetCNodePrimitive(cnode); - MS_EXCEPTION_IF_NULL(origin_prim); - if (!origin_prim->isa()) { - continue; - } - cnode->set_input(0, std::make_shared(std::make_shared(*origin_prim))); - } -} - bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) { expand_ops_ = GetExpandOps(); MS_EXCEPTION_IF_NULL(func_graph); diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.h index eda740f056..ab153fafdb 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_EXPANDER_H_ #include #include +#include #include "backend/optimizer/common/pass.h" #include "ir/func_graph.h" @@ -31,7 +32,6 @@ class GraphKernelExpander : public Pass { private: FuncGraphPtr CreateExpandFuncGraph(const CNodePtr &node); bool DoExpand(const FuncGraphPtr &func_graph); - void ToPrimitive(const FuncGraphPtr &func_graph) const; void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs); AnfNodePtr CreateExpandGraphKernel(const FuncGraphPtr &func_graph, const FuncGraphPtr &new_func_graph, const CNodePtr &node); @@ -39,6 +39,7 @@ class GraphKernelExpander : public Pass { return std::any_of(expand_ops_.begin(), expand_ops_.end(), [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); } + bool ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json); private: std::unordered_set expand_ops_; diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc index fe5377f3b3..ba90ef88e7 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc @@ -39,27 +39,6 @@ namespace mindspore { namespace opt { namespace { -void DebugDump(const FuncGraphPtr &graph, std::stringstream *buf) { - (*buf) << "Parameters: \n"; - const auto ¶meters = graph->parameters(); - (*buf) << "size: " << parameters.size() << "\n"; - for (const auto &p : parameters) { - (*buf) << "\t" << p->DebugString(2) << "\n"; - } - (*buf) << "ValueNodes: \n"; - const auto &value_nodes = graph->value_nodes(); - (*buf) << "size: " << value_nodes.size() << "\n"; - for (const auto &v : value_nodes) { - (*buf) << "\t" << v.first->DebugString(2) << "\n"; - } - (*buf) << "CNodes: \n"; - const auto &all_nodes = graph->nodes(); - (*buf) << "size: " << all_nodes.size() << "\n"; - for (const auto &n : all_nodes) { - (*buf) << "\t" << n->DebugString(2) << "\n"; - } -} - bool IsMakeTupleOut(const AnfNodePtr &out, AnfNodePtrList *real_outs) { MS_EXCEPTION_IF_NULL(real_outs); if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) { @@ -91,132 +70,6 @@ AbstractBasePtr GetOutputAbstract(const AnfNodePtr &node, size_t output_idx) { return out_spec; } -ValueNodePtr ProcessAttrsForCast(const CNodePtr &cnode, const std::string &attr_name) { - auto dst_type = AnfAlgo::GetNodeAttr(cnode, attr_name); - auto type = TypeIdToType(kernel::DtypeToTypeId(dst_type)); - auto type_val_node = NewValueNode(type); - return type_val_node; -} - -const std::map> - attrs_process_map = { - {kCastOpName, ProcessAttrsForCast}, -}; - -ValueNodePtr ProcessAttrValue(const CNodePtr &cnode, const std::string &attr_name) { - auto op_name = AnfAlgo::GetCNodeName(cnode); - if (attrs_process_map.count(op_name) != 0) { - return attrs_process_map.at(op_name)(cnode, attr_name); - } - - auto attr_val = AnfAlgo::GetNodeAttr(cnode, attr_name); - auto attr_val_node = NewValueNode(attr_val); - return attr_val_node; -} - -AnfNodePtr ConstAttrToInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, - const std::unordered_set &input_attrs) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(cnode); - MS_LOG(DEBUG) << "process node: " << cnode->DebugString(2); - if (input_attrs.empty()) { - return nullptr; - } - - auto input_names = AnfAlgo::GetNodeAttr>(cnode, kAttrInputNames); - MS_LOG(DEBUG) << "ori_input_names: " << kernel::Vector2Str(input_names); - std::vector new_inputs; - std::vector new_input_names; - const auto &inputs = cnode->inputs(); - for (size_t i = 0; i < inputs.size() - 1; ++i) { - new_input_names.push_back(input_names[i]); - } - - (void)new_inputs.insert(new_inputs.end(), inputs.begin(), inputs.end()); - bool need_update = false; - for (size_t i = inputs.size() - 1; i < input_names.size(); ++i) { - auto attr_name = input_names[i]; - if (input_attrs.find(i) == input_attrs.end()) { - MS_LOG(WARNING) << "Other type input between tensors and attrs, name: " << attr_name - << ", node: " << cnode->DebugString(2); - new_input_names.push_back(attr_name); - continue; - } - if (!AnfAlgo::HasNodeAttr(attr_name, cnode)) { - MS_LOG(EXCEPTION) << "Attr: " << attr_name << " not found in node: " << cnode->DebugString(2); - } - - // Hardcode. It should convert attrs value according to format, like op ReduceSum. - auto attr_val_node = ProcessAttrValue(cnode, attr_name); - new_inputs.push_back(attr_val_node); - new_input_names.push_back(attr_name); - need_update = true; - MS_LOG(DEBUG) << "convert attr: " << attr_name << " to input, value: " << attr_val_node; - } - MS_LOG(DEBUG) << "new_input_names: " << kernel::Vector2Str(new_input_names); - - if (!need_update) { - return nullptr; - } - - auto new_cnode = func_graph->NewCNode(new_inputs); - // we do not modify abstract and kernel info. - new_cnode->set_abstract(cnode->abstract()); - new_cnode->set_kernel_info(cnode->kernel_info_ptr()); - AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(new_input_names), new_cnode); - return new_cnode; -} - -AnfNodePtr DeleteAttrInInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, - const std::unordered_set &input_attrs) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(cnode); - MS_LOG(DEBUG) << "process node: " << cnode->DebugString(2); - if (input_attrs.empty()) { - return nullptr; - } - - auto input_names = AnfAlgo::GetNodeAttr>(cnode, kAttrInputNames); - MS_LOG(DEBUG) << "ori_input_names: " << kernel::Vector2Str(input_names); - std::vector new_inputs; - std::vector new_input_names; - - const auto &inputs = cnode->inputs(); - new_inputs.push_back(inputs[0]); - bool need_update = false; - for (size_t i = 0; i < inputs.size() - 1; ++i) { - auto input_node = inputs[i + 1]; - MS_EXCEPTION_IF_NULL(input_node); - // The attrs counts from 0 - if (input_attrs.find(i) != input_attrs.end() && input_node->isa()) { - auto value_node = input_node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - MS_LOG(DEBUG) << "delete attr input: " << i << " of node: " << cnode->DebugString(2); - if (i >= input_names.size()) { - MS_LOG(EXCEPTION) << "Index " << i << " is larger than input names size: " << input_names.size(); - } - need_update = true; - } else { - new_inputs.push_back(input_node); - if (i < input_names.size()) { - new_input_names.push_back(input_names[i]); - } - } - } - MS_LOG(DEBUG) << "new_input_names: " << kernel::Vector2Str(new_input_names); - - if (!need_update) { - return nullptr; - } - - auto new_cnode = func_graph->NewCNode(new_inputs); - // we do not modify abstract and kernel info. - new_cnode->set_abstract(cnode->abstract()); - new_cnode->set_kernel_info(cnode->kernel_info_ptr()); - AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(new_input_names), new_cnode); - return new_cnode; -} - AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr &fg, const FuncGraphManagerPtr &mng) { AnfNodePtrList outs; auto out_node = fg->output(); @@ -396,59 +249,6 @@ void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, new_node.get()); } -void ConstAttrToInput(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - auto mng = func_graph->manager(); - MS_EXCEPTION_IF_NULL(mng); - std::vector todos; - kernel::GetValidKernelNodes(func_graph, &todos); - for (const auto &node : todos) { - ConstInputToAttrInfoRegister reg; - if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(node), ®)) { - continue; - } - auto new_node = ConstAttrToInput(func_graph, node->cast(), reg.GetConstInputAttrInfo()); - if (new_node != nullptr && new_node != node) { - mng->Replace(node, new_node); - } - } -} - -void DeleteAttrInInput(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - auto mng = func_graph->manager(); - MS_EXCEPTION_IF_NULL(mng); - std::vector todos; - kernel::GetValidKernelNodes(func_graph, &todos); - for (const auto &node : todos) { - ConstInputToAttrInfoRegister reg; - if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(node), ®)) { - continue; - } - auto new_node = DeleteAttrInInput(func_graph, node->cast(), reg.GetConstInputAttrInfo()); - if (new_node != nullptr && new_node != node) { - mng->Replace(node, new_node); - } - } -} - -AnfNodePtrList GetExpandOuts(const AnfNodePtrList &outs) { - AnfNodePtrList res; - if (outs.size() <= 1) { - return outs; - } - - for (auto out : outs) { - AnfNodePtrList real_outs; - if (IsMakeTupleOut(out, &real_outs)) { - res.insert(res.end(), real_outs.begin(), real_outs.end()); - continue; - } - res.push_back(out); - } - return res; -} - AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &func_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs, const AnfNodePtrList &outputs) { auto func_node = NewValueNode(fg); @@ -661,68 +461,7 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector(); - auto mng = resource->manager(); - MS_EXCEPTION_IF_NULL(mng); - mng->AddFuncGraph(fg); - ConstAttrToInput(fg); - std::stringstream buf; - buf << "===================== graph after ConstAttrToInput " << fg->ToString() << " =====================\n"; - DebugDump(fg, &buf); - MS_LOG(DEBUG) << buf.str(); - - // Do infer and specialize. - AbstractBasePtrList args_spec_list; - std::for_each(inputs.begin(), inputs.end(), - [&args_spec_list](const AnfNodePtr &node) { args_spec_list.push_back(node->abstract()); }); - auto infer_fg = pipeline::Renormalize(resource, fg, args_spec_list); - if (infer_fg == nullptr) { - MS_LOG(ERROR) << "Infer decoded graph failed."; - return nullptr; - } - buf.str(""); - buf << "===================== graph after Renormalize " << infer_fg->ToString() << " =====================\n"; - DebugDump(infer_fg, &buf); - MS_LOG(DEBUG) << buf.str(); - - // delete no use inputs(attrs), like op ReduceSum(axis). - DeleteAttrInInput(infer_fg); - buf.str(""); - buf << "===================== graph after DeleteAttrInInput " << infer_fg->ToString() << " =====================\n"; - DebugDump(infer_fg, &buf); - MS_LOG(DEBUG) << buf.str(); - - // clone a new graph. - auto new_fg = TransformableClone(infer_fg, std::make_shared("akg_decode")); - return new_fg; -} - -std::unordered_set GetExpandOps() { - std::unordered_set expand_ops = { - prim::kPrimSquare, - prim::kPrimGeLUGrad, -#if ENABLE_D - prim::kPrimTile, - prim::kPrimSqrtGrad, - prim::kPrimClipByNormNoDivSum, -#elif ENABLE_GPU - prim::kPrimBiasAdd, - prim::kPrimBiasAddGrad, - prim::kPrimGeLU, - prim::kPrimFusedAdam, - prim::kPrimFusedAdamWeightDecay, - prim::kPrimReduceMean, - prim::kPrimMaximumGrad, - prim::kPrimMinimumGrad, - prim::kPrimGkDropout, - prim::kPrimDropoutGrad, - prim::kPrimSoftmax, - prim::kPrimLayerNorm, - prim::kPrimLayerNormGrad, -#endif - }; - return expand_ops; + return fg; } std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix, const string &postfix) { diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h index cbaf475613..b36460971c 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h @@ -61,7 +61,6 @@ std::tuple MixedNodesTransToGraph( AnfNodePtrList *src_outputs = nullptr); void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs, const AnfNodePtrList &outputs, kernel::Processor processor); -AnfNodePtrList GetExpandOuts(const AnfNodePtrList &outs); AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &kernel_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs, const AnfNodePtrList &outputs); void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new_fuse_cnode, @@ -74,7 +73,6 @@ bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, n std::map *address_node_map); bool AnfToJsonDesc(const std::vector &graphs, const DumpOption &dump_option, nlohmann::json *op_desc); FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector &inputs); -std::unordered_set GetExpandOps(); std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix = "", const string &postfix = ""); std::vector GetFusibleOpList(); bool IsBasicFuseOp(const AnfNodePtr &node);