From: @dayschan
Reviewed-by: @gaoxiong1, @ckey_dou
Signed-off-by: @ckey_dou
Tag: tags/v1.2.0-rc1
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,18 +20,31 @@ from mindspore import log as logger
 import mindspore._extends.graph_kernel.expanders as expanders

+def extract_expand_info(kernel_info):
+    """Convert the json into a more friendly format"""
+    input_desc = []
+    if 'input_desc' in kernel_info and kernel_info['input_desc']:
+        for desc in kernel_info['input_desc']:
+            input_desc += desc
+    attrs = {}
+    if 'attr' in kernel_info and kernel_info['attr']:
+        for attr in kernel_info["attr"]:
+            attrs[attr["name"]] = attr["value"]
+    expand_info = {
+        "name": kernel_info["name"],
+        "input_desc": input_desc,
+        "output_desc": kernel_info["output_desc"],
+        "attr": attrs,
+        "process": kernel_info["process"],
+    }
+    return expand_info
+
 def get_op_expander(json_str: str):
     """get op expander by json info"""
     try:
         kernel_info = json.loads(json_str)
-        expand_info = kernel_info['expand_info']
-        if 'name' not in expand_info:
-            logger.error("expand info have no op name")
-            return None
-        if 'process' not in expand_info:
-            logger.error("expand info have no processor info")
-            return None
+        expand_info = extract_expand_info(kernel_info)
         processor = expand_info['process']
         op_name = str(expand_info['name']).lower()
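Note: the new extract_expand_info helper flattens the nested input_desc lists and turns the attr list of {"name", "value"} entries into a plain dict, which is why the expanders below can read attributes by key. An illustrative sketch of the conversion (tensor names, shapes and the attr value are invented for the example, not taken from the diff):

    kernel_info = {
        "name": "DropoutGrad",
        "process": "aicore",
        "input_desc": [[{"tensor_name": "input_0", "shape": [16, 16], "data_type": "float32", "format": "DefaultFormat"}],
                       [{"tensor_name": "input_1", "shape": [16, 16], "data_type": "float32", "format": "DefaultFormat"}]],
        "output_desc": [{"tensor_name": "output_0", "shape": [16, 16], "data_type": "float32", "format": "DefaultFormat"}],
        "attr": [{"name": "keep_prob", "value": 0.9}],
    }
    expand_info = extract_expand_info(kernel_info)
    # expand_info["input_desc"] is now a flat list of two descriptors, and
    # expand_info["attr"] == {"keep_prob": 0.9}, so expanders can write expand_info['attr']['keep_prob'].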
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,20 +21,15 @@ def expand_dropoutgrad(expand_info):
     # get op info.
     dy_desc = expand_info['input_desc'][0]
     mask_desc = expand_info['input_desc'][1]
-    keep_prob = None
-    for attr in expand_info['attr']:
-        if 'keep_prob' in attr:
-            keep_prob = attr['keep_prob']
-    if keep_prob is None:
-        raise RuntimeError("keep_prob does not exist in attrs.")
-    # generate a graph.
+    keep_prob = expand_info['attr']['keep_prob']
     graph_builder = builder.GraphBuilder()
     with graph_builder.graph_scope('main') as graph_scope:
         # create tensor input.
         input_dy = graph_builder.tensor(dy_desc['shape'], dy_desc['data_type'], dy_desc['format'])
         input_mask = graph_builder.tensor(mask_desc['shape'], mask_desc['data_type'], mask_desc['format'])
         graph_scope.set_input(input_dy, input_mask)
-        r_keep_prob = graph_builder.value(input_dy.dtype, 1.0 / keep_prob, "DefaultFormat")
+        r_keep_prob = graph_builder.value(input_dy.dtype, 1.0 / keep_prob)
         # create op.
         result = graph_builder.emit('Mul', [input_dy, r_keep_prob])
         result = graph_builder.emit('Mul', [result, input_mask])
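Note (illustrative, not part of the diff): the expanded DropoutGrad graph computes dx = dy * (1 / keep_prob) * mask. A plain-Python restatement of the element-wise rule, assuming list inputs:

    def dropout_grad_reference(dy, mask, keep_prob):
        # Scale the incoming gradient by 1/keep_prob and zero it where the mask is 0.
        return [d * (1.0 / keep_prob) * m for d, m in zip(dy, mask)]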
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -40,17 +40,16 @@ def expand_gelu(expand_info):
         # cal y
         mul_0 = graph_builder.emit('Mul', [input_x, input_x])
         pow_0 = graph_builder.emit('Mul', [mul_0, input_x])
-        const_csvalue = graph_builder.value(pow_0.dtype, CSVALUE, input_desc['format'])
+        const_csvalue = graph_builder.value(pow_0.dtype, CSVALUE)
         mul_1 = graph_builder.emit('Mul', [pow_0, const_csvalue])
         tanh_res = graph_builder.emit('Add', [input_x, mul_1])
-        const_csvalue_sqrt_two_div_pi = graph_builder.value(
-            tanh_res.dtype, CSVALUE_SQRT_TWO_DIV_PI, input_desc['format'])
+        const_csvalue_sqrt_two_div_pi = graph_builder.value(tanh_res.dtype, CSVALUE_SQRT_TWO_DIV_PI)
         y = graph_builder.emit('Mul', [tanh_res, const_csvalue_sqrt_two_div_pi])
         # cal gelu(x)
         tanh_y = graph_builder.emit('Tanh', [y])
-        const_one = graph_builder.value(tanh_y.dtype, ONE, input_desc['format'])
-        const_half = graph_builder.value(tanh_y.dtype, HALF, input_desc['format'])
+        const_one = graph_builder.value(tanh_y.dtype, ONE)
+        const_half = graph_builder.value(tanh_y.dtype, HALF)
         tanh_y_add_one = graph_builder.emit('Add', [tanh_y, const_one])
         mul_x = graph_builder.emit('Mul', [input_x, tanh_y_add_one])
         result = graph_builder.emit('Mul', [const_half, mul_x])
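Note (illustrative, not part of the diff): the emitted ops implement the usual tanh approximation of GeLU, gelu(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + CSVALUE * x^3))). CSVALUE and CSVALUE_SQRT_TWO_DIV_PI are module constants not shown in this hunk; the standard values 0.044715 and sqrt(2/pi) are assumed in this sketch:

    import math

    def gelu_reference(x, csvalue=0.044715):
        # Mirrors the graph: x^3 * CSVALUE, add x, scale by sqrt(2/pi), tanh, then 0.5 * x * (1 + tanh).
        inner = math.sqrt(2.0 / math.pi) * (x + csvalue * x ** 3)
        return 0.5 * x * (1.0 + math.tanh(inner))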
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -45,12 +45,11 @@ def expand_gelugrad(expand_info):
         graph_scope.set_input(input_dy, input_x, input_y)
         # create some const var
-        const_csvalue = graph_builder.value(input_dy.dtype, CSVALUE, input_desc_0['format'])
-        const_csvalue_sqrt_two_div_pi = graph_builder.value(
-            input_dy.dtype, CSVALUE_SQRT_TWO_DIV_PI, input_desc_0['format'])
-        const_csvalue_tri = graph_builder.value(input_dy.dtype, CSVALUE_TRI, input_desc_0['format'])
-        const_one = graph_builder.value(input_dy.dtype, ONE, input_desc_0['format'])
-        const_half = graph_builder.value(input_dy.dtype, HALF, input_desc_0['format'])
+        const_csvalue = graph_builder.value(input_dy.dtype, CSVALUE)
+        const_csvalue_sqrt_two_div_pi = graph_builder.value(input_dy.dtype, CSVALUE_SQRT_TWO_DIV_PI)
+        const_csvalue_tri = graph_builder.value(input_dy.dtype, CSVALUE_TRI)
+        const_one = graph_builder.value(input_dy.dtype, ONE)
+        const_half = graph_builder.value(input_dy.dtype, HALF)
         # cal mul_right
         mul_double = graph_builder.emit('Mul', [input_x, input_x])

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,25 +21,20 @@ def expand_gkdropout(expand_info):
     # get op info.
     input_desc = expand_info['input_desc'][0]
     maks_desc = expand_info['input_desc'][1]
-    keep_prob = None
-    for attr in expand_info['attr']:
-        if 'keep_prob' in attr:
-            keep_prob = attr['keep_prob']
-    if keep_prob is None:
-        raise RuntimeError("keep_prob does not exist in attrs.")
-    # generate a graph.
+    keep_prob = expand_info['attr']['keep_prob']
     graph_builder = builder.GraphBuilder()
     with graph_builder.graph_scope('main') as graph_scope:
         # create tensor input.
         input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
         input_mask = graph_builder.tensor(maks_desc['shape'], maks_desc['data_type'], maks_desc['format'])
         graph_scope.set_input(input_x, input_mask)
-        keep_prob_v = graph_builder.value(input_x.dtype, keep_prob, "DefaultFormat")
-        r_keep_prob = graph_builder.value(input_x.dtype, 1.0 / keep_prob, "DefaultFormat")
+        keep_prob_v = graph_builder.value(input_x.dtype, keep_prob)
+        r_keep_prob = graph_builder.value(input_x.dtype, 1.0 / keep_prob)
         if input_mask.dtype != input_x.dtype:
             input_mask = graph_builder.emit('Cast', [input_mask], attrs={'dst_type': input_x.dtype})
-        mask = graph_builder.emit('LessEqual', [input_mask, keep_prob_v]) # output is bool type
+        mask = graph_builder.emit('LessEqual', [input_mask, keep_prob_v])  # output is bool type
         mask = graph_builder.emit('Cast', [mask], attrs={'dst_type': input_x.dtype})
         # compute result
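Note (illustrative, not part of the diff): GkDropout builds the keep mask by comparing the uniform-random mask input against keep_prob; the rest of the hunk is not shown, but standard inverted dropout would multiply x by that mask and by 1/keep_prob. A hedged per-element sketch:

    def gk_dropout_reference(x, uniform_random, keep_prob):
        # Keep an element when its random draw is <= keep_prob, then rescale the kept elements.
        mask = [1.0 if u <= keep_prob else 0.0 for u in uniform_random]
        return [xi * m * (1.0 / keep_prob) for xi, m in zip(x, mask)], mask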
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -23,16 +23,10 @@ def expand_layernorm(expand_info):
     input_desc_1 = expand_info['input_desc'][1]
     input_desc_2 = expand_info['input_desc'][2]
     attrs = expand_info['attr']
-    begin_norm_axis = None
-    epsilon = None
-    for item in attrs:
-        if 'begin_norm_axis' in item:
-            begin_norm_axis = item['begin_norm_axis']
-        if 'epsilon' in item:
-            epsilon = item['epsilon']
-    graph_builder = builder.GraphBuilder()
+    begin_norm_axis = attrs['begin_norm_axis']
+    epsilon = attrs['epsilon']
+    # generate a graph.
+    graph_builder = builder.GraphBuilder()
     with graph_builder.graph_scope('main') as graph_scope:
         # create tensor input.
         input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
@@ -52,7 +46,7 @@ def expand_layernorm(expand_info):
         for i in reduce_axis:
             reduce_elts *= shape_x[i]
         mean_cof = 1.0 / reduce_elts
-        mean_cof_v = graph_builder.value(input_x.dtype, mean_cof, input_x.data_format)
+        mean_cof_v = graph_builder.value(input_x.dtype, mean_cof)
         # Calculate mean
         mean_red = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
@@ -67,7 +61,7 @@ def expand_layernorm(expand_info):
         # Calculate normalize
         normalize_sub = graph_builder.emit('Sub', [input_x, mean])
-        epsilon_v = graph_builder.value(input_x.dtype, epsilon, input_x.data_format)
+        epsilon_v = graph_builder.value(input_x.dtype, epsilon)
         normalize_add = graph_builder.emit('Add', [variance, epsilon_v])
         normlize_rsqrt = graph_builder.emit('Rsqrt', [normalize_add])
         normalize_mul = graph_builder.emit('Mul', [normalize_sub, normlize_rsqrt])
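Note (illustrative, not part of the diff): the visible part of the LayerNorm expansion is the normalization core: the mean over the reduce axes via ReduceSum scaled by 1/reduce_elts, then (x - mean) * rsqrt(variance + epsilon); the gamma/beta affine step lies outside the shown hunk. A compact reference of the normalization step:

    def layernorm_normalize_reference(x, mean, variance, epsilon):
        # Element-wise (x - mean) / sqrt(var + eps), built above with Sub, Add, Rsqrt and Mul.
        return [(xi - mi) * (vi + epsilon) ** -0.5 for xi, mi, vi in zip(x, mean, variance)]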
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -24,16 +24,10 @@ def expand_layernormgrad(expand_info):
     var_desc = expand_info['input_desc'][2]
     mean_desc = expand_info['input_desc'][3]
     gamma_desc = expand_info['input_desc'][4]
-    begin_norm_axis = None
-    begin_params_axis = None
-    epsilon = 1e-11
-    for item in expand_info['attr']:
-        if 'begin_norm_axis' in item:
-            begin_norm_axis = item['begin_norm_axis']
-        if 'begin_params_axis' in item:
-            begin_params_axis = item['begin_params_axis']
-        if 'epsilon' in item:
-            epsilon = item['epsilon']
+    attrs = expand_info['attr']
+    begin_norm_axis = attrs['begin_norm_axis']
+    begin_params_axis = attrs['begin_params_axis']
+    epsilon = attrs['epsilon'] if 'epsilon' in attrs else 1e-11
     shape_x = x_desc['shape']
     if begin_norm_axis < 0:
@@ -57,13 +51,13 @@ def expand_layernormgrad(expand_info):
         graph_scope.set_input(x, dy, variance, mean, gamma)
         # set some constant val.
-        eps = graph_builder.value(x.dtype, epsilon, x.data_format)
-        const_one = graph_builder.value(x.dtype, 1.0, x.data_format)
-        const_neg_half = graph_builder.value(x.dtype, -0.5, x.data_format)
-        const_neg_two = graph_builder.value(x.dtype, -2.0, x.data_format)
-        const_two = graph_builder.value(x.dtype, 2.0, x.data_format)
-        const_neg_one = graph_builder.value(x.dtype, -1.0, x.data_format)
-        mean_cof = graph_builder.value(x.dtype, (1.0 / reduce_size), x.data_format)
+        eps = graph_builder.value(x.dtype, epsilon)
+        const_one = graph_builder.value(x.dtype, 1.0)
+        const_neg_half = graph_builder.value(x.dtype, -0.5)
+        const_neg_two = graph_builder.value(x.dtype, -2.0)
+        const_two = graph_builder.value(x.dtype, 2.0)
+        const_neg_one = graph_builder.value(x.dtype, -1.0)
+        mean_cof = graph_builder.value(x.dtype, (1.0 / reduce_size))
         # cal dg db
         var_eps = graph_builder.emit('Add', [variance, eps])
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,11 +20,7 @@ def expand_logsoftmax(expand_info):
     """LogSoftmax expander"""
     # get op info.
     input_desc = expand_info['input_desc'][0]
-    attrs = expand_info['attr']
-    axis = None
-    for item in attrs:
-        if 'axis' in item:
-            axis = item['axis']
+    axis = expand_info['attr']['axis']
     graph_builder = builder.GraphBuilder()
     if isinstance(axis, int):
         axis = (axis,)

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,11 +21,7 @@ def expand_logsoftmaxgrad(expand_info):
     # get op info.
     input_desc_0 = expand_info['input_desc'][0]
     input_desc_1 = expand_info['input_desc'][1]
-    attrs = expand_info['attr']
-    axis = None
-    for item in attrs:
-        if 'axis' in item:
-            axis = item['axis']
+    axis = expand_info['attr']['axis']
     graph_builder = builder.GraphBuilder()
     if isinstance(axis, int):
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -23,16 +23,10 @@ def expand_maximumgrad(expand_info):
     input_desc_1 = expand_info['input_desc'][1]
     input_desc_2 = expand_info['input_desc'][2]
     attrs = expand_info['attr']
-    grad_x = None
-    grad_y = None
-    for item in attrs:
-        if 'grad_x' in item:
-            grad_x = item['grad_x']
-        if 'grad_y' in item:
-            grad_y = item['grad_y']
-    graph_builder = builder.GraphBuilder()
+    grad_x = attrs['grad_x'] if 'grad_x' in attrs else True
+    grad_y = attrs['grad_y'] if 'grad_y' in attrs else True
+    # generate a graph.
+    graph_builder = builder.GraphBuilder()
     with graph_builder.graph_scope('main') as graph_scope:
         # create tensor input.
         input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -23,15 +23,10 @@ def expand_minimumgrad(expand_info):
     input_desc_1 = expand_info['input_desc'][1]
     input_desc_2 = expand_info['input_desc'][2]
     attrs = expand_info['attr']
-    grad_x = None
-    grad_y = None
-    for item in attrs:
-        if 'grad_x' in item:
-            grad_x = item['grad_x']
-        if 'grad_y' in item:
-            grad_y = item['grad_y']
+    grad_x = attrs['grad_x'] if 'grad_x' in attrs else True
+    grad_y = attrs['grad_y'] if 'grad_y' in attrs else True
     graph_builder = builder.GraphBuilder()
+    # generate a graph.
     with graph_builder.graph_scope('main') as graph_scope:
         # create tensor input.
         input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,20 +18,13 @@ from mindspore._extends.graph_kernel.model import model_builder as builder
 def expand_reducemean(expand_info):
     """ReduceMean expander"""
     # get op info.
     input_desc = expand_info['input_desc'][0]
     attrs = expand_info['attr']
-    axis = None
-    keep_dims = None
-    for item in attrs:
-        if 'axis' in item:
-            axis = item['axis']
-        if 'keep_dims' in item:
-            keep_dims = item['keep_dims']
-    graph_builder = builder.GraphBuilder()
+    axis = attrs['axis']
+    keep_dims = attrs['keep_dims']
+    # generate a graph.
+    graph_builder = builder.GraphBuilder()
     with graph_builder.graph_scope('main') as graph_scope:
         # create tensor input.
         input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
@@ -49,7 +42,7 @@ def expand_reducemean(expand_info):
         for idx in axis:
             all_shape *= x_shape[idx]
-        all_shape_value = graph_builder.value(input_x.dtype, all_shape, input_x.data_format)
+        all_shape_value = graph_builder.value(input_x.dtype, all_shape)
         if not axis:
             sum_x = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': real_axis, 'keep_dims': keep_dims})
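Note (illustrative, not part of the diff): ReduceMean is expanded as a ReduceSum followed by a multiplication with 1/all_shape, where all_shape is the product of the reduced dimensions. For example, with an input of shape [8, 32, 64] and axis (1, 2):

    all_shape = 32 * 64          # 2048 reduced elements
    # mean = ReduceSum(x, axis=(1, 2)) * (1.0 / all_shape)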
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,16 +18,10 @@ from mindspore._extends.graph_kernel.model import model_builder as builder
 def expand_softmax(expand_info):
     """Softmax expander"""
-    # get op info.
     input_desc = expand_info['input_desc'][0]
-    attrs = expand_info['attr']
-    axis = None
-    for item in attrs:
-        if 'axis' in item:
-            axis = item['axis']
-    graph_builder = builder.GraphBuilder()
+    axis = expand_info['attr']['axis']
+    # generate a graph.
+    graph_builder = builder.GraphBuilder()
     with graph_builder.graph_scope('main') as graph_scope:
         # create tensor input.
         input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -34,7 +34,7 @@ def expand_sqrtgrad(expand_info):
         graph_scope.set_input(input_x, input_dout)
         # cal result
-        const_two = graph_builder.value(input_x.dtype, 2, input_x.data_format)
+        const_two = graph_builder.value(input_x.dtype, 2)
         dividend = graph_builder.emit('Mul', [input_x, const_two])
         result = graph_builder.emit('RealDiv', [input_dout, dividend])
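Note (illustrative, not part of the diff): SqrtGrad's expansion reads dx = dout / (2 * y), where input_x here appears to be the forward output y = sqrt(x); for example, with input_x = 3.0 and input_dout = 1.2 the result is 1.2 / 6.0 = 0.2.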
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,7 +31,7 @@ def expand_tanhgrad(expand_info):
         # create tensor input.
         input_y = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
         input_dy = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
-        const_one = graph_builder.value(input_y.dtype, ONE, input_y.data_format)
+        const_one = graph_builder.value(input_y.dtype, ONE)
         graph_scope.set_input(input_y, input_dy)
         # cal result
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,18 +18,11 @@ from mindspore._extends.graph_kernel.model import model_builder as builder
 def expand_tile(expand_info):
     """Tile expander"""
-    # get op info.
     input_desc = expand_info['input_desc'][0]
-    attrs = expand_info['attr']
-    multiples = None
-    for item in attrs:
-        if 'multiples' in item:
-            multiples = item['multiples']
+    multiples = expand_info['attr']['multiples']
     output_shape, _, _, shape_compatible = builder.get_tile_output_shape(input_desc['shape'], multiples)
-    graph_builder = builder.GraphBuilder()
+    # generate a graph.
+    graph_builder = builder.GraphBuilder()
     with graph_builder.graph_scope('main') as graph_scope:
         # create tensor input.
         input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
@@ -15,7 +15,7 @@
 """GraphKernel model builder"""
 import copy
-from .model import PrimLib, Tensor, Value, Operator, Graph, AlignShape, AddControlBuddy
+from .model import PrimLib, Tensor, Value, Operator, Graph, AlignShape, AddControlBuddy, DataFormat

 def get_tile_output_shape(shape, multiples):
@@ -70,7 +70,7 @@ class OpInfer:
         real_shape = []
         for i, _ in enumerate(shape):
-            if i not in attrs['reduce_axis']:
+            if i not in attrs['reduce_axis'] and i - len(shape) not in attrs['reduce_axis']:
                 real_shape.append(shape[i])
         return real_shape
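Note (illustrative, not part of the diff): the added condition makes reduce-shape inference also honour negative axis indices. A standalone check of the same rule:

    def reduced_shape(shape, reduce_axis):
        # Keep a dim only if neither its positive nor its negative index is in reduce_axis.
        return [d for i, d in enumerate(shape)
                if i not in reduce_axis and i - len(shape) not in reduce_axis]

    # reduced_shape([4, 5, 6], [-1]) -> [4, 5]; before the fix, axis -1 was not matched and the 6 was kept.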
@@ -106,7 +106,15 @@ class OpInfer:
     @staticmethod
     def default_infer_format_func(inputs, attrs):
         """Infer format"""
-        return inputs[0].data_format
+        result = inputs[0].data_format
+        # default_format and other_format results in other_format
+        for input_tensor in inputs[1:]:
+            data_format = input_tensor.data_format
+            if data_format != DataFormat.DEFAULT:
+                if result not in [DataFormat.DEFAULT, data_format]:
+                    raise RuntimeError("Incompatible data format %s and %s" % (data_format, result))
+                result = data_format
+        return result

     infer_shape_func = {
         # add special infer func here
@@ -114,13 +122,20 @@ class OpInfer:
         'Reshape': lambda inputs, attrs: attrs["shape"],
         'BroadcastTo': lambda inputs, attrs: attrs["shape"],
         'Tile': lambda inputs, attrs: get_tile_output_shape(inputs[0].shape, attrs["multiples"])[0],
+        'ExpandDims': lambda inputs, attrs: list(inputs[0].shape).insert(attrs["axis"], 1),
     }
     infer_dtype_func = {
         # add special infer func here
         'Cast': lambda inputs, attrs: attrs['dst_type'],
+        'Less': lambda inputs, attrs: "bool",
+        'LessEqual': lambda inputs, attrs: "bool",
+        'Equal': lambda inputs, attrs: "bool",
+        'Greater': lambda inputs, attrs: "bool",
+        'GreaterEqual': lambda inputs, attrs: "bool",
     }
     infer_format_func = {
         # add special infer func here
+        'Reshape': lambda inputs, attrs: "DefaultFormat",
     }

     @classmethod
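Note (illustrative, not part of the diff): the new default format inference treats DefaultFormat as neutral, lets any other format win, and rejects two different non-default formats. A standalone sketch of the same rule over plain strings (the format names are assumed for illustration):

    def infer_format(formats, default="DefaultFormat"):
        # Mirrors default_infer_format_func, but over format strings instead of tensors.
        result = formats[0]
        for fmt in formats[1:]:
            if fmt != default:
                if result not in (default, fmt):
                    raise RuntimeError("Incompatible data format %s and %s" % (fmt, result))
                result = fmt
        return result

    # infer_format(["DefaultFormat", "FRAC_NZ"]) -> "FRAC_NZ"
    # infer_format(["NCHW", "NHWC"]) raises RuntimeError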
| @classmethod | @classmethod | ||||
| @@ -188,18 +203,12 @@ class GraphBuilder: | |||||
| shape = [1] | shape = [1] | ||||
| return Tensor(name, shape, dtype, data_format, para_type=para_type) | return Tensor(name, shape, dtype, data_format, para_type=para_type) | ||||
| def value(self, dtype, value, data_format, name=None): | |||||
| def value(self, dtype, value, name=None): | |||||
| """Create a new Value""" | """Create a new Value""" | ||||
| if name in (None, ''): | if name in (None, ''): | ||||
| name = self._alloc_tensor_name() | name = self._alloc_tensor_name() | ||||
| if dtype == "float16": | |||||
| # For float16 value, it will be changed to float32 wrongly. And there is no good solution for now. | |||||
| # So instead just declare float32 value and then cast it to float16. | |||||
| v_fp32 = Value(name, "float32", value, data_format) | |||||
| v = self.emit("Cast", [v_fp32], attrs={"dst_type": "float16"}) | |||||
| else: | |||||
| v = Value(name, dtype, value, data_format) | |||||
| v = Value(name, dtype, value) | |||||
| return v | return v | ||||
| def op(self, prim, output, inputs, attrs=None): | def op(self, prim, output, inputs, attrs=None): | ||||
@@ -19,8 +19,7 @@
 #include <memory>
 #include <sstream>
 #include <string>
-#include <unordered_map>
-#include <unordered_set>
+#include <map>
 #include <vector>
 #include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
 #include "backend/kernel_compiler/common_utils.h"
@@ -46,6 +45,62 @@ namespace {
 constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
 constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";

+class AbstractShapeCreator {
+ public:
+  using AbstractShapeTransferFunc = std::function<ShapeVector(const ShapeVector &)>;
+
+  /**
+   * Get an abstract shape.
+   * For a given device_shape and format, the available abstract_shape is not unique,
+   * this interface only returns a legal abstract_shape without considering padding
+   * so that the AnfAlgo's get device shape interface can get the right device_shape.
+   */
+  static ShapeVector GetFakeAbstractShape(const ShapeVector &device_shape, const std::string &format) {
+    const std::map<std::string, AbstractShapeTransferFunc> fmap{
+      {kOpFormat_NCHW, NchwAbstractShape},
+      {kOpFormat_NHWC, NhwcAbstractShape},
+      {kOpFormat_FRAC_NZ, FractalNzAbstractShape},
+    };
+    if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
+      return device_shape;
+    }
+    auto iter = fmap.find(format);
+    if (iter == fmap.end()) {
+      MS_LOG(WARNING) << "Unexpected format[" << format << "]";
+      return device_shape;
+    }
+    return iter->second(device_shape);
+  }
+
+ private:
+  static ShapeVector NchwAbstractShape(const ShapeVector &device_shape) { return device_shape; }
+  static ShapeVector NhwcAbstractShape(const ShapeVector &device_shape) {
+    if (device_shape.size() != 4) {
+      MS_LOG(EXCEPTION) << "Shape size of NHWC should be 4, but got " << device_shape.size();
+    }
+    return {device_shape[0], device_shape[3], device_shape[1], device_shape[2]};
+  }
+  static ShapeVector FractalNzAbstractShape(const ShapeVector &device_shape) {
+    if (device_shape.size() == 1 && (device_shape[0] == 1 || device_shape[0] % kCubeSize == 0)) {
+      return device_shape;
+    }
+    if (device_shape.size() < 4) {
+      MS_LOG(EXCEPTION) << "Shape size of FRACTAL_NZ should >= 4, but got " << device_shape.size();
+    }
+    ShapeVector shape;
+    size_t dims = device_shape.size();
+    size_t batch = dims - 4;
+    for (size_t i = 0; i < batch; ++i) {
+      shape.push_back(device_shape[i]);
+    }
+    int64_t m = device_shape[dims - 3] * device_shape[dims - 2];
+    int64_t n = device_shape[dims - 4] * device_shape[dims - 1];
+    shape.push_back(m);
+    shape.push_back(n);
+    return shape;
+  }
+};
+
 class CNodeDecoder {
  public:
  explicit CNodeDecoder(std::map<std::string, AnfNodePtr> *nodes_map) : nodes_map_(*nodes_map) {}
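Note (illustrative, not part of the diff): a Python restatement of the FRAC_NZ branch of GetFakeAbstractShape, under the usual reading that FRACTAL_NZ stores a trailing (M, N) matrix as (N1, M1, M0, N0) blocks, so the fake abstract shape recombines M = M1*M0 and N = N1*N0 (the 1-D special case is omitted here):

    def fractal_nz_abstract_shape(device_shape):
        # Keep leading batch dims, then fold the last four device dims back into (M1*M0, N1*N0).
        if len(device_shape) < 4:
            raise ValueError("FRACTAL_NZ device shape needs at least 4 dims")
        batch = list(device_shape[:-4])
        n1, m1, m0, n0 = device_shape[-4:]
        return batch + [m1 * m0, n1 * n0]

    # fractal_nz_abstract_shape([32, 16, 16, 16]) -> [256, 512]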
@@ -66,6 +121,7 @@ class CNodeDecoder {
       return nullptr;
     }
     CreateKernelInfo(processor);
+    CreateAbstract();
     return cnode_;
   }
@@ -117,12 +173,8 @@ class CNodeDecoder {
   bool DecodeInputDesc(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph) {
     std::string op_name = cnode_json[kJsonKeyName];
-    // new primitive.
-    auto primitive = GetPrimitive(op_name);
-    if (primitive == nullptr) {
-      MS_LOG(ERROR) << "Create primitive failed.";
-      return false;
-    }
+    auto primitive = CreatePrimitiveWithAttrs(op_name);
+    MS_EXCEPTION_IF_NULL(primitive);
     // collect inputs.
     auto primitive_v = NewValueNode(primitive);
@@ -142,6 +194,7 @@ class CNodeDecoder {
       }
       input_formats_.push_back(input_desc[kJsonKeyFormat]);
       input_types_.push_back(DtypeToTypeId(input_desc[kJsonKeyDataType]));
+      input_shapes_.push_back(input_desc[kJsonKeyShape]);
     }
     // new cnode.
     cnode_ = func_graph->NewCNode(inputs);
@@ -160,6 +213,7 @@ class CNodeDecoder {
       nlohmann::json output_desc = output_descs[0];
       output_formats_.push_back(output_desc[kJsonKeyFormat]);
       output_types_.push_back(DtypeToTypeId(output_desc[kJsonKeyDataType]));
+      output_shapes_.push_back(output_desc[kJsonKeyShape]);
       nodes_map_[output_desc[kJsonKeyTensorName]] = cnode_;
     } else {
       // multi outputs.
@@ -167,6 +221,7 @@ class CNodeDecoder {
         nlohmann::json output_desc = output_descs[j];
         output_formats_.push_back(output_desc[kJsonKeyFormat]);
         output_types_.push_back(DtypeToTypeId(output_desc[kJsonKeyDataType]));
+        output_shapes_.push_back(output_desc[kJsonKeyShape]);
         auto get_item =
           func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode_, NewValueNode(SizeToLong(j))});
         func_graph->AddNode(get_item);
@@ -219,72 +274,29 @@ class CNodeDecoder {
     AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), cnode_.get());
   }

-  ValuePtr CreatOpInstance(const std::string &op_name, const std::vector<ValuePtr> &attrs) {
-    // python utils.
-    constexpr auto kGetPythonOpFunc = "_get_python_op";
-    constexpr auto kParallelUtilsModule = "mindspore.parallel._utils";
-    // almost all ops are defined in this path.
-    constexpr auto kOperationsModule = "mindspore.ops.operations";
-    py::module mod = py::module::import(kOperationsModule);
-    if (!py::hasattr(mod, op_name.c_str())) {
-      MS_LOG(ERROR) << kOperationsModule << " don't have attr: " << op_name;
-      return nullptr;
-    }
-    std::vector<py::object> arg_list;
-    (void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list),
-                         [](const ValuePtr &attr) { return ValuePtrToPyData(attr); });
-    py::object obj = parse::python_adapter::CallPyFn(kParallelUtilsModule, kGetPythonOpFunc, op_name, kOperationsModule,
-                                                     op_name, arg_list);
-    ValuePtr op_instance = nullptr;
-    bool succ = parse::ConvertData(obj, &op_instance);
-    if (!succ) {
-      MS_LOG(ERROR) << "Get python op " << op_name << " from " << kOperationsModule << " failed.";
-      return nullptr;
-    }
-    return op_instance;
+  void CreateAbstract() {
+    auto shape = AbstractShapeCreator::GetFakeAbstractShape(output_shapes_[0], output_formats_[0]);
+    auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(output_types_[0]), shape);
+    cnode_->set_abstract(abstract);
   }

-  const std::map<std::string, std::vector<std::string>> op_attrs_map_ = {
-    {kReduceSumOpName, std::vector<std::string>{kAttrKeepDims}},
-    {kReduceMaxOpName, std::vector<std::string>{kAttrKeepDims}},
-    {kReduceMinOpName, std::vector<std::string>{kAttrKeepDims}},
-    {kBroadcastToOpName, std::vector<std::string>{kAttrShape}},
-  };
-
-  PrimitivePtr GetPrimitive(const std::string &op_name) {
-    PrimitivePtr primitive{nullptr};
-    if (op_attrs_map_.count(op_name) == 0) {
-      // no attrs for op instance.
-      primitive = CreatOpInstance(op_name, std::vector<ValuePtr>{})->cast<PrimitivePtr>();
-    } else {
-      // make attrs for op instance.
-      std::vector<ValuePtr> op_attrs;
-      const auto &attr_names = op_attrs_map_.at(op_name);
-      for (const auto &attr_name : attr_names) {
-        if (cnode_attrs_.count(attr_name) == 0) {
-          MS_LOG(ERROR) << "Attr: " << attr_name << " for: " << op_name << " not found.";
-          return nullptr;
-        }
-        op_attrs.push_back(cnode_attrs_.at(attr_name));
-      }
-      primitive = CreatOpInstance(op_name, op_attrs)->cast<PrimitivePtr>();
-    }
-    if (primitive != nullptr) {
-      for (const auto &attr : cnode_attrs_) {
-        primitive->AddAttr(attr.first, attr.second);
-      }
+  PrimitivePtr CreatePrimitiveWithAttrs(const std::string &op_name) {
+    auto primitive = std::make_shared<Primitive>(op_name);
+    for (const auto &attr : cnode_attrs_) {
+      primitive->AddAttr(attr.first, attr.second);
     }
     return primitive;
   }

-  ScalarPtr DecodeScalar(const nlohmann::json &scalar_json) {
+  tensor::TensorPtr DecodeScalar(const nlohmann::json &scalar_json) {
     auto type_id = DtypeToTypeId(scalar_json[kJsonKeyDataType]);
     switch (type_id) {
       case kNumberTypeFloat16:
+        return std::make_shared<tensor::Tensor>(static_cast<float>(scalar_json[kJsonKeyValue]), kFloat16);
       case kNumberTypeFloat32:
-        return std::make_shared<FP32Imm>(scalar_json[kJsonKeyValue]);
+        return std::make_shared<tensor::Tensor>(static_cast<float>(scalar_json[kJsonKeyValue]), kFloat32);
       case kNumberTypeInt32:
-        return std::make_shared<Int32Imm>(scalar_json[kJsonKeyValue]);
+        return std::make_shared<tensor::Tensor>(static_cast<int64_t>(scalar_json[kJsonKeyValue]), kInt32);
       default:
         MS_LOG(ERROR) << "Unknown type: " << scalar_json[kJsonKeyDataType];
         break;
@@ -294,9 +306,8 @@ class CNodeDecoder {
   ValueNodePtr DecodeValueNode(const nlohmann::json &value_json, const FuncGraphPtr &func_graph) {
     MS_LOG(DEBUG) << "start decode value node, " << value_json;
-    auto scalar = DecodeScalar(value_json);
-    auto tensor = ScalarToTensor(scalar);
+    auto tensor = DecodeScalar(value_json);
+    MS_EXCEPTION_IF_NULL(tensor);
     auto value_node = std::make_shared<ValueNode>(tensor);
     value_node->set_abstract(tensor->ToAbstract());
     // create kernel_info fo new value node.
@@ -319,6 +330,8 @@ class CNodeDecoder {
   std::vector<std::string> output_formats_;
   std::vector<TypeId> input_types_;
   std::vector<TypeId> output_types_;
+  std::vector<ShapeVector> input_shapes_;
+  std::vector<ShapeVector> output_shapes_;
   CNodePtr cnode_{nullptr};
 };
 }  // namespace
@@ -329,11 +342,16 @@ ParameterPtr AkgKernelJsonDecoder::DecodeParameter(const nlohmann::json &paramet
   ParameterPtr new_parameter = func_graph->add_parameter();
   std::string name = parameter_json[kJsonKeyTensorName];
   new_parameter->set_name(name);
+  std::string format = parameter_json[kJsonKeyFormat];
+  TypeId dtype = DtypeToTypeId(parameter_json[kJsonKeyDataType]);
+  ShapeVector shape = AbstractShapeCreator::GetFakeAbstractShape(parameter_json[kJsonKeyShape], format);
+  auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(dtype), shape);
+  new_parameter->set_abstract(abstract);
   auto kernel_info = std::make_shared<device::KernelInfo>();
   new_parameter->set_kernel_info(kernel_info);
   auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
-  builder->SetOutputsFormat(std::vector<std::string>{parameter_json[kJsonKeyFormat]});
-  builder->SetOutputsDeviceType(std::vector<TypeId>{DtypeToTypeId(parameter_json[kJsonKeyDataType])});
+  builder->SetOutputsFormat(std::vector<std::string>{format});
+  builder->SetOutputsDeviceType(std::vector<TypeId>{dtype});
   AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), new_parameter.get());
   nodes_map_[name] = new_parameter;
   return new_parameter;
@@ -349,6 +367,7 @@ CNodePtr AkgKernelJsonDecoder::DecodeCNode(const nlohmann::json &cnode_json, con
 AnfNodePtr AkgKernelJsonDecoder::DecodeOutput(const std::vector<nlohmann::json> &output_descs,
                                               const FuncGraphPtr &func_graph) {
   std::vector<AnfNodePtr> outputs{NewValueNode(prim::kPrimMakeTuple)};
+  AbstractBasePtrList output_abstract_list;
   for (const auto &output_desc : output_descs) {
     std::string name = output_desc[kJsonKeyTensorName];
     if (nodes_map_.count(name) == 0) {
@@ -356,11 +375,13 @@ AnfNodePtr AkgKernelJsonDecoder::DecodeOutput(const std::vector<nlohmann::json>
       return nullptr;
     }
     outputs.push_back(nodes_map_[name]);
+    output_abstract_list.push_back(outputs.back()->abstract());
   }
   if (outputs.size() == 2) {
     func_graph->set_output(outputs[1]);
   } else {
     auto output = func_graph->NewCNode(outputs);
+    output->set_abstract(std::make_shared<abstract::AbstractTuple>(output_abstract_list));
     func_graph->AddNode(output);
     func_graph->set_output(output);
   }
@@ -80,10 +80,12 @@ class OpInfoExtractor {
   }

   void ExtractOutputs(const OpInfoPtr &op_info) {
-    // only support single output in op desc.
-    auto io_info = std::make_shared<OpIOInfo>();
-    io_info->set_name("output");
-    op_info->add_outputs_ptr(io_info);
+    size_t output_tensor_num = AnfAlgo::GetOutputTensorNum(cnode_);
+    for (size_t i = 0; i < output_tensor_num; i++) {
+      auto io_info = std::make_shared<OpIOInfo>();
+      io_info->set_name("output_" + std::to_string(i));
+      op_info->add_outputs_ptr(io_info);
+    }
   }

   bool ExcludeAttr(const std::string &name) {
@@ -204,8 +206,7 @@ bool AkgKernelJsonGenerator::CreateInputDescJson(const AnfNodePtr &anf_node, con
     input_desc_json[kJsonKeyName] = input_ptr->name();
     input_desc_json[kJsonKeyTensorName] = "input_" + std::to_string(GetInputTensorIdxInc(anf_node, real_input_index));
     auto input_shape = this->GetInputShape(anf_node, real_input_index);
-    if (dump_option_.extract_opinfo_from_anfnode &&
-        GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) {
+    if (!is_basic_op_ && GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) {
      MS_LOG(DEBUG) << "Take input[" << real_input_index << "] of [" << anf_node->DebugString(2)
                    << "] as const tensor, shape: [" << Vector2Str(input_shape)
                    << "], value: " << input_desc_json[kJsonKeyValue];
@@ -529,9 +530,9 @@ bool AkgKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node, nlohmann::j
   MS_EXCEPTION_IF_NULL(anf_node);
   MS_EXCEPTION_IF_NULL(kernel_json);
   std::string op_name = AnfAlgo::GetCNodeName(anf_node);
-  MS_LOG(INFO) << "Akg start generate kernel json desc, full scope name is : " << anf_node->fullname_with_scope();
+  MS_LOG(DEBUG) << "Akg start generate kernel json desc, full scope name is : " << anf_node->fullname_with_scope();
   SetAkgKernelAttrs(anf_node);
-  dump_option_.extract_opinfo_from_anfnode = false;
+  is_basic_op_ = true;
   if (!GenerateSingleKernelJson(anf_node, kernel_json)) {
     MS_LOG(ERROR) << "Op[" << anf_node->fullname_with_scope() << "] create single kernel json failed.";
     return false;
@@ -551,8 +552,8 @@ bool AkgKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node, nlohmann::j
     return false;
   }

-  MS_LOG(INFO) << "Akg create kernel json desc success, full scope name is : " << anf_node->fullname_with_scope()
-               << ", json info name is : " << kernel_name_;
+  MS_LOG(DEBUG) << "Akg create kernel json desc success, full scope name is : " << anf_node->fullname_with_scope()
+                << ", json info name is : " << kernel_name_;
   return true;
 }
@@ -613,10 +614,11 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
                   << "].";
     return false;
   }
-  MS_LOG(INFO) << "Fusion nodes: [" << output_list.size() << "], input_list: [" << anf_nodes.size()
-               << "], output_list: [" << input_list.size() << "].";
+  MS_LOG(DEBUG) << "Fusion nodes: [" << output_list.size() << "], input_list: [" << anf_nodes.size()
+                << "], output_list: [" << input_list.size() << "].";
   std::map<AnfNodePtr, nlohmann::json> node_json_map;
-  dump_option_.extract_opinfo_from_anfnode = true;
+  is_basic_op_ = false;
+  dump_option_.extract_opinfo_from_anfnode = true;  // always extract from anfnode for composite ops.
   if (!GenSingleJsons(anf_nodes, &node_json_map)) return false;
   UpdateTensorName(anf_nodes, &node_json_map);
@@ -144,6 +144,7 @@ class AkgKernelJsonGenerator {
   std::map<std::string, AnfNodePtr> address_node_map_;
   std::map<size_t, std::vector<std::string>> sub_graphs_;
   std::map<size_t, size_t> dim_infos_;
+  bool is_basic_op_{false};
 };
 }  // namespace kernel
 }  // namespace mindspore
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -393,6 +393,14 @@ CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphP | |||||
| broadcast_input_node}; | broadcast_input_node}; | ||||
| auto broadcast_to_node_inner = CreateCNode( | auto broadcast_to_node_inner = CreateCNode( | ||||
| atomic_clean_inputs, new_sub_graph, {.format = format, .shape = dst_shape_vec, .type = GetType(atomic_add_node_)}); | atomic_clean_inputs, new_sub_graph, {.format = format, .shape = dst_shape_vec, .type = GetType(atomic_add_node_)}); | ||||
| auto device_shape = AnfAlgo::GetOutputDeviceShape(atomic_add_node_, 0); | |||||
| dst_shape_vec.clear(); | |||||
| if (device_shape.empty()) { | |||||
| dst_shape_vec.push_back(1); | |||||
| } else { | |||||
| std::transform(device_shape.begin(), device_shape.end(), std::back_inserter(dst_shape_vec), SizeToLong); | |||||
| } | |||||
| SetNodeAttrSafely("shape", MakeValue(dst_shape_vec), broadcast_to_node_inner); | SetNodeAttrSafely("shape", MakeValue(dst_shape_vec), broadcast_to_node_inner); | ||||
| // Makeup sub-graph. | // Makeup sub-graph. | ||||
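For context, the block added above rebuilds the "shape" attribute of the inner BroadcastTo from the device shape of atomic_add_node_, falling back to {1} when the device shape is empty (scalar output). The same logic as a standalone helper, shown only for illustration:

// Hypothetical helper, equivalent to the inserted lines; SizeToLong is the same conversion used above.
std::vector<int64_t> ToBroadcastShape(const std::vector<size_t> &device_shape) {
  std::vector<int64_t> shape_vec;
  if (device_shape.empty()) {
    shape_vec.push_back(1);  // scalar: broadcast to a single element
  } else {
    (void)std::transform(device_shape.begin(), device_shape.end(), std::back_inserter(shape_vec), SizeToLong);
  }
  return shape_vec;
}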
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -36,88 +36,40 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace { | namespace { | ||||
| constexpr auto kJsonKeyExpandInfo = "expand_info"; | |||||
| #define GET_VALUE_FOR_JSON(JSON, VALUE, VALUE_ELEM, TYPE_NAME, TYPE) \ | |||||
| if (VALUE_ELEM->isa<TYPE_NAME>()) { \ | |||||
| JSON = GetValue<TYPE>(VALUE); \ | |||||
| } | |||||
| nlohmann::json ExpandAttrJsonInfo(const CNodePtr &cnode) { | |||||
| nlohmann::json attrs_json; | |||||
| if (auto prim = GetCNodePrimitive(cnode); prim != nullptr) { | |||||
| auto attrs = prim->attrs(); | |||||
| for (const auto &[k, v] : attrs) { | |||||
| nlohmann::json attr_json; | |||||
| MS_LOG(DEBUG) << "attr key is : " << k << " and value type is : " << v->type_name(); | |||||
| GET_VALUE_FOR_JSON(attr_json[k], v, v, Int32Imm, int); | |||||
| GET_VALUE_FOR_JSON(attr_json[k], v, v, Int64Imm, int64_t); | |||||
| GET_VALUE_FOR_JSON(attr_json[k], v, v, UInt32Imm, uint32_t); | |||||
| GET_VALUE_FOR_JSON(attr_json[k], v, v, UInt64Imm, uint64_t); | |||||
| GET_VALUE_FOR_JSON(attr_json[k], v, v, FP32Imm, float); | |||||
| GET_VALUE_FOR_JSON(attr_json[k], v, v, FP64Imm, double); | |||||
| GET_VALUE_FOR_JSON(attr_json[k], v, v, BoolImm, bool); | |||||
| GET_VALUE_FOR_JSON(attr_json[k], v, v, StringImm, std::string); | |||||
| if (v->isa<ValueList>() || v->isa<ValueTuple>()) { | |||||
| auto vec = v->isa<ValueList>() ? v->cast<ValueListPtr>()->value() : v->cast<ValueTuplePtr>()->value(); | |||||
| if (!vec.empty()) { | |||||
| MS_LOG(DEBUG) << "value type is : " << vec[0]->type_name(); | |||||
| GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], Int32Imm, std::vector<int>); | |||||
| GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], Int64Imm, std::vector<int64_t>); | |||||
| GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], UInt32Imm, std::vector<uint32_t>); | |||||
| GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], UInt64Imm, std::vector<uint64_t>); | |||||
| GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], FP32Imm, std::vector<float>); | |||||
| GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], FP64Imm, std::vector<double>); | |||||
| GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], StringImm, std::vector<std::string>); | |||||
| } | |||||
| } | |||||
| if (!attr_json.empty()) { | |||||
| attrs_json.push_back(attr_json); | |||||
| } | |||||
| } | |||||
| } | |||||
| return attrs_json; | |||||
| std::unordered_set<PrimitivePtr> GetExpandOps() { | |||||
| std::unordered_set<PrimitivePtr> expand_ops = { | |||||
| prim::kPrimSquare, | |||||
| prim::kPrimGeLUGrad, | |||||
| #if ENABLE_D | |||||
| prim::kPrimTile, | |||||
| prim::kPrimSqrtGrad, | |||||
| prim::kPrimClipByNormNoDivSum, | |||||
| #elif ENABLE_GPU | |||||
| prim::kPrimBiasAdd, | |||||
| prim::kPrimBiasAddGrad, | |||||
| prim::kPrimGeLU, | |||||
| prim::kPrimFusedAdam, | |||||
| prim::kPrimFusedAdamWeightDecay, | |||||
| prim::kPrimReduceMean, | |||||
| prim::kPrimMaximumGrad, | |||||
| prim::kPrimMinimumGrad, | |||||
| prim::kPrimGkDropout, | |||||
| prim::kPrimDropoutGrad, | |||||
| prim::kPrimSoftmax, | |||||
| prim::kPrimLayerNorm, | |||||
| prim::kPrimLayerNormGrad, | |||||
| #endif | |||||
| }; | |||||
| return expand_ops; | |||||
| } | } | ||||
| } // namespace | |||||
| bool ExpandJsonInfo(const CNodePtr &cnode, nlohmann::json *kernel_json) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_json); | |||||
| if (kernel_json->find(kJsonKeyExpandInfo) != kernel_json->end()) { | |||||
| return false; | |||||
| } | |||||
| nlohmann::json expand_info; | |||||
| expand_info[kernel::kJsonKeyAttr] = ExpandAttrJsonInfo(cnode); | |||||
| expand_info[kernel::kJsonKeyName] = AnfAlgo::GetCNodeName(cnode); | |||||
| expand_info[kernel::kJsonKeyProcess] = kernel::GetProcessorStr(cnode); | |||||
| std::vector<nlohmann::json> inputs_info; | |||||
| for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); ++i) { | |||||
| nlohmann::json input_info; | |||||
| input_info[kernel::kJsonKeyFormat] = AnfAlgo::GetInputFormat(cnode, i); | |||||
| input_info[kernel::kJsonKeyInferShape] = AnfAlgo::GetPrevNodeOutputInferShape(cnode, i); | |||||
| input_info[kernel::kJsonKeyShape] = AnfAlgo::GetInputDeviceShape(cnode, i); | |||||
| input_info[kernel::kJsonKeyInferDataType] = | |||||
| kernel::TypeId2String(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, i)); | |||||
| input_info[kernel::kJsonKeyDataType] = kernel::TypeId2String(AnfAlgo::GetInputDeviceDataType(cnode, i)); | |||||
| inputs_info.push_back(input_info); | |||||
| } | |||||
| expand_info[kernel::kJsonKeyInputDesc] = inputs_info; | |||||
| std::vector<nlohmann::json> outputs_info; | |||||
| for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(cnode); ++i) { | |||||
| nlohmann::json output_info; | |||||
| output_info[kernel::kJsonKeyFormat] = AnfAlgo::GetOutputFormat(cnode, i); | |||||
| output_info[kernel::kJsonKeyInferShape] = AnfAlgo::GetOutputInferShape(cnode, i); | |||||
| output_info[kernel::kJsonKeyShape] = AnfAlgo::GetOutputDeviceShape(cnode, i); | |||||
| output_info[kernel::kJsonKeyInferDataType] = kernel::TypeId2String(AnfAlgo::GetOutputInferDataType(cnode, i)); | |||||
| output_info[kernel::kJsonKeyDataType] = kernel::TypeId2String(AnfAlgo::GetOutputDeviceDataType(cnode, i)); | |||||
| outputs_info.push_back(output_info); | |||||
| } | |||||
| expand_info[kernel::kJsonKeyOutputDesc] = outputs_info; | |||||
| (*kernel_json)[kJsonKeyExpandInfo] = expand_info; | |||||
| return true; | |||||
| bool GraphKernelExpander::ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json) { | |||||
| DumpOption dump_option; | |||||
| dump_option.extract_opinfo_from_anfnode = true; | |||||
| kernel::AkgKernelJsonGenerator json_generator(dump_option); | |||||
| return json_generator.CollectJson(node, kernel_json); | |||||
| } | } | ||||
| } // namespace | |||||
| FuncGraphPtr GraphKernelExpander::CreateExpandFuncGraph(const CNodePtr &node) { | FuncGraphPtr GraphKernelExpander::CreateExpandFuncGraph(const CNodePtr &node) { | ||||
| nlohmann::json kernel_json; | nlohmann::json kernel_json; | ||||
| @@ -213,33 +165,11 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) { | |||||
| // replace origin node. | // replace origin node. | ||||
| (void)mng->Replace(node, graph_kernel_node); | (void)mng->Replace(node, graph_kernel_node); | ||||
| ToPrimitive(AnfAlgo::GetCNodeFuncGraphPtr(graph_kernel_node)); | |||||
| changed = true; | changed = true; | ||||
| } | } | ||||
| return changed; | return changed; | ||||
| } | } | ||||
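For context: ExpandJsonInfo now delegates json collection to kernel::AkgKernelJsonGenerator with extract_opinfo_from_anfnode enabled, GetExpandOps moves into this file, and the ToPrimitive post-processing of the expanded sub-graph (removed below) is no longer invoked from DoExpand. A rough sketch of how CreateExpandFuncGraph is expected to use the new helper; its real body is not shown in this hunk, so the surrounding lines are assumptions:

// Sketch only (assumed context inside CreateExpandFuncGraph).
nlohmann::json kernel_json;
if (!ExpandJsonInfo(node, &kernel_json)) {
  MS_LOG(ERROR) << "Expand json info failed, node: " << node->fullname_with_scope();
  return nullptr;
}
std::string json_desc = kernel_json.dump();  // serialized kernel description for the op expander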
| void GraphKernelExpander::ToPrimitive(const FuncGraphPtr &func_graph) const { | |||||
| auto todos = TopoSort(func_graph->get_return()); | |||||
| std::reverse(todos.begin(), todos.end()); | |||||
| auto mng = func_graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(mng); | |||||
| for (const auto &n : todos) { | |||||
| auto cnode = n->cast<CNodePtr>(); | |||||
| if (cnode == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto origin_prim = AnfAlgo::GetCNodePrimitive(cnode); | |||||
| MS_EXCEPTION_IF_NULL(origin_prim); | |||||
| if (!origin_prim->isa<PrimitivePy>()) { | |||||
| continue; | |||||
| } | |||||
| cnode->set_input(0, std::make_shared<ValueNode>(std::make_shared<Primitive>(*origin_prim))); | |||||
| } | |||||
| } | |||||
| bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) { | bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) { | ||||
| expand_ops_ = GetExpandOps(); | expand_ops_ = GetExpandOps(); | ||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -17,6 +17,7 @@ | |||||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_EXPANDER_H_ | #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_EXPANDER_H_ | ||||
| #include <memory> | #include <memory> | ||||
| #include <unordered_set> | #include <unordered_set> | ||||
| #include <nlohmann/json.hpp> | |||||
| #include "backend/optimizer/common/pass.h" | #include "backend/optimizer/common/pass.h" | ||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| @@ -31,7 +32,6 @@ class GraphKernelExpander : public Pass { | |||||
| private: | private: | ||||
| FuncGraphPtr CreateExpandFuncGraph(const CNodePtr &node); | FuncGraphPtr CreateExpandFuncGraph(const CNodePtr &node); | ||||
| bool DoExpand(const FuncGraphPtr &func_graph); | bool DoExpand(const FuncGraphPtr &func_graph); | ||||
| void ToPrimitive(const FuncGraphPtr &func_graph) const; | |||||
| void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs); | void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs); | ||||
| AnfNodePtr CreateExpandGraphKernel(const FuncGraphPtr &func_graph, const FuncGraphPtr &new_func_graph, | AnfNodePtr CreateExpandGraphKernel(const FuncGraphPtr &func_graph, const FuncGraphPtr &new_func_graph, | ||||
| const CNodePtr &node); | const CNodePtr &node); | ||||
| @@ -39,6 +39,7 @@ class GraphKernelExpander : public Pass { | |||||
| return std::any_of(expand_ops_.begin(), expand_ops_.end(), | return std::any_of(expand_ops_.begin(), expand_ops_.end(), | ||||
| [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); | [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); | ||||
| } | } | ||||
| bool ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json); | |||||
| private: | private: | ||||
| std::unordered_set<PrimitivePtr> expand_ops_; | std::unordered_set<PrimitivePtr> expand_ops_; | ||||
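For context, the header now declares the private ExpandJsonInfo helper and includes nlohmann/json.hpp for it, while the ToPrimitive declaration is dropped. Invoking the pass itself is unchanged; roughly (sketch only, pass-manager wiring assumed and not shown in this patch):

// Sketch only.
auto expander = std::make_shared<GraphKernelExpander>();
bool changed = expander->Run(func_graph);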
| @@ -39,27 +39,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace { | namespace { | ||||
| void DebugDump(const FuncGraphPtr &graph, std::stringstream *buf) { | |||||
| (*buf) << "Parameters: \n"; | |||||
| const auto ¶meters = graph->parameters(); | |||||
| (*buf) << "size: " << parameters.size() << "\n"; | |||||
| for (const auto &p : parameters) { | |||||
| (*buf) << "\t" << p->DebugString(2) << "\n"; | |||||
| } | |||||
| (*buf) << "ValueNodes: \n"; | |||||
| const auto &value_nodes = graph->value_nodes(); | |||||
| (*buf) << "size: " << value_nodes.size() << "\n"; | |||||
| for (const auto &v : value_nodes) { | |||||
| (*buf) << "\t" << v.first->DebugString(2) << "\n"; | |||||
| } | |||||
| (*buf) << "CNodes: \n"; | |||||
| const auto &all_nodes = graph->nodes(); | |||||
| (*buf) << "size: " << all_nodes.size() << "\n"; | |||||
| for (const auto &n : all_nodes) { | |||||
| (*buf) << "\t" << n->DebugString(2) << "\n"; | |||||
| } | |||||
| } | |||||
| bool IsMakeTupleOut(const AnfNodePtr &out, AnfNodePtrList *real_outs) { | bool IsMakeTupleOut(const AnfNodePtr &out, AnfNodePtrList *real_outs) { | ||||
| MS_EXCEPTION_IF_NULL(real_outs); | MS_EXCEPTION_IF_NULL(real_outs); | ||||
| if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) { | if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) { | ||||
| @@ -91,132 +70,6 @@ AbstractBasePtr GetOutputAbstract(const AnfNodePtr &node, size_t output_idx) { | |||||
| return out_spec; | return out_spec; | ||||
| } | } | ||||
| ValueNodePtr ProcessAttrsForCast(const CNodePtr &cnode, const std::string &attr_name) { | |||||
| auto dst_type = AnfAlgo::GetNodeAttr<std::string>(cnode, attr_name); | |||||
| auto type = TypeIdToType(kernel::DtypeToTypeId(dst_type)); | |||||
| auto type_val_node = NewValueNode(type); | |||||
| return type_val_node; | |||||
| } | |||||
| const std::map<std::string, std::function<ValueNodePtr(const CNodePtr &cnode, const std::string &attr_name)>> | |||||
| attrs_process_map = { | |||||
| {kCastOpName, ProcessAttrsForCast}, | |||||
| }; | |||||
| ValueNodePtr ProcessAttrValue(const CNodePtr &cnode, const std::string &attr_name) { | |||||
| auto op_name = AnfAlgo::GetCNodeName(cnode); | |||||
| if (attrs_process_map.count(op_name) != 0) { | |||||
| return attrs_process_map.at(op_name)(cnode, attr_name); | |||||
| } | |||||
| auto attr_val = AnfAlgo::GetNodeAttr<ValuePtr>(cnode, attr_name); | |||||
| auto attr_val_node = NewValueNode(attr_val); | |||||
| return attr_val_node; | |||||
| } | |||||
| AnfNodePtr ConstAttrToInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||||
| const std::unordered_set<size_t> &input_attrs) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| MS_LOG(DEBUG) << "process node: " << cnode->DebugString(2); | |||||
| if (input_attrs.empty()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto input_names = AnfAlgo::GetNodeAttr<std::vector<std::string>>(cnode, kAttrInputNames); | |||||
| MS_LOG(DEBUG) << "ori_input_names: " << kernel::Vector2Str(input_names); | |||||
| std::vector<AnfNodePtr> new_inputs; | |||||
| std::vector<std::string> new_input_names; | |||||
| const auto &inputs = cnode->inputs(); | |||||
| for (size_t i = 0; i < inputs.size() - 1; ++i) { | |||||
| new_input_names.push_back(input_names[i]); | |||||
| } | |||||
| (void)new_inputs.insert(new_inputs.end(), inputs.begin(), inputs.end()); | |||||
| bool need_update = false; | |||||
| for (size_t i = inputs.size() - 1; i < input_names.size(); ++i) { | |||||
| auto attr_name = input_names[i]; | |||||
| if (input_attrs.find(i) == input_attrs.end()) { | |||||
| MS_LOG(WARNING) << "Other type input between tensors and attrs, name: " << attr_name | |||||
| << ", node: " << cnode->DebugString(2); | |||||
| new_input_names.push_back(attr_name); | |||||
| continue; | |||||
| } | |||||
| if (!AnfAlgo::HasNodeAttr(attr_name, cnode)) { | |||||
| MS_LOG(EXCEPTION) << "Attr: " << attr_name << " not found in node: " << cnode->DebugString(2); | |||||
| } | |||||
| // Hardcode. It should convert attrs value according to format, like op ReduceSum. | |||||
| auto attr_val_node = ProcessAttrValue(cnode, attr_name); | |||||
| new_inputs.push_back(attr_val_node); | |||||
| new_input_names.push_back(attr_name); | |||||
| need_update = true; | |||||
| MS_LOG(DEBUG) << "convert attr: " << attr_name << " to input, value: " << attr_val_node; | |||||
| } | |||||
| MS_LOG(DEBUG) << "new_input_names: " << kernel::Vector2Str(new_input_names); | |||||
| if (!need_update) { | |||||
| return nullptr; | |||||
| } | |||||
| auto new_cnode = func_graph->NewCNode(new_inputs); | |||||
| // we do not modify abstract and kernel info. | |||||
| new_cnode->set_abstract(cnode->abstract()); | |||||
| new_cnode->set_kernel_info(cnode->kernel_info_ptr()); | |||||
| AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(new_input_names), new_cnode); | |||||
| return new_cnode; | |||||
| } | |||||
| AnfNodePtr DeleteAttrInInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||||
| const std::unordered_set<size_t> &input_attrs) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| MS_LOG(DEBUG) << "process node: " << cnode->DebugString(2); | |||||
| if (input_attrs.empty()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto input_names = AnfAlgo::GetNodeAttr<std::vector<std::string>>(cnode, kAttrInputNames); | |||||
| MS_LOG(DEBUG) << "ori_input_names: " << kernel::Vector2Str(input_names); | |||||
| std::vector<AnfNodePtr> new_inputs; | |||||
| std::vector<std::string> new_input_names; | |||||
| const auto &inputs = cnode->inputs(); | |||||
| new_inputs.push_back(inputs[0]); | |||||
| bool need_update = false; | |||||
| for (size_t i = 0; i < inputs.size() - 1; ++i) { | |||||
| auto input_node = inputs[i + 1]; | |||||
| MS_EXCEPTION_IF_NULL(input_node); | |||||
| // The attrs counts from 0 | |||||
| if (input_attrs.find(i) != input_attrs.end() && input_node->isa<ValueNode>()) { | |||||
| auto value_node = input_node->cast<ValueNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(value_node); | |||||
| MS_LOG(DEBUG) << "delete attr input: " << i << " of node: " << cnode->DebugString(2); | |||||
| if (i >= input_names.size()) { | |||||
| MS_LOG(EXCEPTION) << "Index " << i << " is larger than input names size: " << input_names.size(); | |||||
| } | |||||
| need_update = true; | |||||
| } else { | |||||
| new_inputs.push_back(input_node); | |||||
| if (i < input_names.size()) { | |||||
| new_input_names.push_back(input_names[i]); | |||||
| } | |||||
| } | |||||
| } | |||||
| MS_LOG(DEBUG) << "new_input_names: " << kernel::Vector2Str(new_input_names); | |||||
| if (!need_update) { | |||||
| return nullptr; | |||||
| } | |||||
| auto new_cnode = func_graph->NewCNode(new_inputs); | |||||
| // we do not modify abstract and kernel info. | |||||
| new_cnode->set_abstract(cnode->abstract()); | |||||
| new_cnode->set_kernel_info(cnode->kernel_info_ptr()); | |||||
| AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(new_input_names), new_cnode); | |||||
| return new_cnode; | |||||
| } | |||||
| AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr &fg, const FuncGraphManagerPtr &mng) { | AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr &fg, const FuncGraphManagerPtr &mng) { | ||||
| AnfNodePtrList outs; | AnfNodePtrList outs; | ||||
| auto out_node = fg->output(); | auto out_node = fg->output(); | ||||
| @@ -396,59 +249,6 @@ void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const | |||||
| AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, new_node.get()); | AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, new_node.get()); | ||||
| } | } | ||||
| void ConstAttrToInput(const FuncGraphPtr &func_graph) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| auto mng = func_graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(mng); | |||||
| std::vector<AnfNodePtr> todos; | |||||
| kernel::GetValidKernelNodes(func_graph, &todos); | |||||
| for (const auto &node : todos) { | |||||
| ConstInputToAttrInfoRegister reg; | |||||
| if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(node), ®)) { | |||||
| continue; | |||||
| } | |||||
| auto new_node = ConstAttrToInput(func_graph, node->cast<CNodePtr>(), reg.GetConstInputAttrInfo()); | |||||
| if (new_node != nullptr && new_node != node) { | |||||
| mng->Replace(node, new_node); | |||||
| } | |||||
| } | |||||
| } | |||||
| void DeleteAttrInInput(const FuncGraphPtr &func_graph) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| auto mng = func_graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(mng); | |||||
| std::vector<AnfNodePtr> todos; | |||||
| kernel::GetValidKernelNodes(func_graph, &todos); | |||||
| for (const auto &node : todos) { | |||||
| ConstInputToAttrInfoRegister reg; | |||||
| if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(node), ®)) { | |||||
| continue; | |||||
| } | |||||
| auto new_node = DeleteAttrInInput(func_graph, node->cast<CNodePtr>(), reg.GetConstInputAttrInfo()); | |||||
| if (new_node != nullptr && new_node != node) { | |||||
| mng->Replace(node, new_node); | |||||
| } | |||||
| } | |||||
| } | |||||
| AnfNodePtrList GetExpandOuts(const AnfNodePtrList &outs) { | |||||
| AnfNodePtrList res; | |||||
| if (outs.size() <= 1) { | |||||
| return outs; | |||||
| } | |||||
| for (auto out : outs) { | |||||
| AnfNodePtrList real_outs; | |||||
| if (IsMakeTupleOut(out, &real_outs)) { | |||||
| res.insert(res.end(), real_outs.begin(), real_outs.end()); | |||||
| continue; | |||||
| } | |||||
| res.push_back(out); | |||||
| } | |||||
| return res; | |||||
| } | |||||
| AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &func_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs, | AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &func_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs, | ||||
| const AnfNodePtrList &outputs) { | const AnfNodePtrList &outputs) { | ||||
| auto func_node = NewValueNode(fg); | auto func_node = NewValueNode(fg); | ||||
| @@ -661,68 +461,7 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNo | |||||
| MS_LOG(ERROR) << "Akg decode json to graph failed."; | MS_LOG(ERROR) << "Akg decode json to graph failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>(); | |||||
| auto mng = resource->manager(); | |||||
| MS_EXCEPTION_IF_NULL(mng); | |||||
| mng->AddFuncGraph(fg); | |||||
| ConstAttrToInput(fg); | |||||
| std::stringstream buf; | |||||
| buf << "===================== graph after ConstAttrToInput " << fg->ToString() << " =====================\n"; | |||||
| DebugDump(fg, &buf); | |||||
| MS_LOG(DEBUG) << buf.str(); | |||||
| // Do infer and specialize. | |||||
| AbstractBasePtrList args_spec_list; | |||||
| std::for_each(inputs.begin(), inputs.end(), | |||||
| [&args_spec_list](const AnfNodePtr &node) { args_spec_list.push_back(node->abstract()); }); | |||||
| auto infer_fg = pipeline::Renormalize(resource, fg, args_spec_list); | |||||
| if (infer_fg == nullptr) { | |||||
| MS_LOG(ERROR) << "Infer decoded graph failed."; | |||||
| return nullptr; | |||||
| } | |||||
| buf.str(""); | |||||
| buf << "===================== graph after Renormalize " << infer_fg->ToString() << " =====================\n"; | |||||
| DebugDump(infer_fg, &buf); | |||||
| MS_LOG(DEBUG) << buf.str(); | |||||
| // delete no use inputs(attrs), like op ReduceSum(axis). | |||||
| DeleteAttrInInput(infer_fg); | |||||
| buf.str(""); | |||||
| buf << "===================== graph after DeleteAttrInInput " << infer_fg->ToString() << " =====================\n"; | |||||
| DebugDump(infer_fg, &buf); | |||||
| MS_LOG(DEBUG) << buf.str(); | |||||
| // clone a new graph. | |||||
| auto new_fg = TransformableClone(infer_fg, std::make_shared<TraceTransform>("akg_decode")); | |||||
| return new_fg; | |||||
| } | |||||
| std::unordered_set<PrimitivePtr> GetExpandOps() { | |||||
| std::unordered_set<PrimitivePtr> expand_ops = { | |||||
| prim::kPrimSquare, | |||||
| prim::kPrimGeLUGrad, | |||||
| #if ENABLE_D | |||||
| prim::kPrimTile, | |||||
| prim::kPrimSqrtGrad, | |||||
| prim::kPrimClipByNormNoDivSum, | |||||
| #elif ENABLE_GPU | |||||
| prim::kPrimBiasAdd, | |||||
| prim::kPrimBiasAddGrad, | |||||
| prim::kPrimGeLU, | |||||
| prim::kPrimFusedAdam, | |||||
| prim::kPrimFusedAdamWeightDecay, | |||||
| prim::kPrimReduceMean, | |||||
| prim::kPrimMaximumGrad, | |||||
| prim::kPrimMinimumGrad, | |||||
| prim::kPrimGkDropout, | |||||
| prim::kPrimDropoutGrad, | |||||
| prim::kPrimSoftmax, | |||||
| prim::kPrimLayerNorm, | |||||
| prim::kPrimLayerNormGrad, | |||||
| #endif | |||||
| }; | |||||
| return expand_ops; | |||||
| return fg; | |||||
| } | } | ||||
| std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix, const string &postfix) { | std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix, const string &postfix) { | ||||
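For context: JsonDescToAnf now returns the decoded graph directly, dropping the ConstAttrToInput / Renormalize / DeleteAttrInInput round trip, and GetExpandOps moves out of this file into graph_kernel_expander.cc. The expected call pattern, per the declaration kept in the helper header (sketch only; json_desc and inputs come from the caller):

// Sketch only.
FuncGraphPtr fg = JsonDescToAnf(json_desc, inputs);
if (fg == nullptr) {
  MS_LOG(ERROR) << "Akg decode json to graph failed.";
}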
| @@ -61,7 +61,6 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> MixedNodesTransToGraph( | |||||
| AnfNodePtrList *src_outputs = nullptr); | AnfNodePtrList *src_outputs = nullptr); | ||||
| void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs, | void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs, | ||||
| const AnfNodePtrList &outputs, kernel::Processor processor); | const AnfNodePtrList &outputs, kernel::Processor processor); | ||||
| AnfNodePtrList GetExpandOuts(const AnfNodePtrList &outs); | |||||
| AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &kernel_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs, | AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &kernel_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs, | ||||
| const AnfNodePtrList &outputs); | const AnfNodePtrList &outputs); | ||||
| void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new_fuse_cnode, | void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new_fuse_cnode, | ||||
| @@ -74,7 +73,6 @@ bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, n | |||||
| std::map<std::string, AnfNodePtr> *address_node_map); | std::map<std::string, AnfNodePtr> *address_node_map); | ||||
| bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, const DumpOption &dump_option, nlohmann::json *op_desc); | bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, const DumpOption &dump_option, nlohmann::json *op_desc); | ||||
| FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNodePtr> &inputs); | FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNodePtr> &inputs); | ||||
| std::unordered_set<PrimitivePtr> GetExpandOps(); | |||||
| std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix = "", const string &postfix = ""); | std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix = "", const string &postfix = ""); | ||||
| std::vector<PrimitivePtr> GetFusibleOpList(); | std::vector<PrimitivePtr> GetFusibleOpList(); | ||||
| bool IsBasicFuseOp(const AnfNodePtr &node); | bool IsBasicFuseOp(const AnfNodePtr &node); | ||||