From: @dayschan
Reviewed-by: @gaoxiong1, @ckey_dou
Signed-off-by: @ckey_dou
Tag: v1.2.0-rc1
@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,18 +20,31 @@ from mindspore import log as logger
import mindspore._extends.graph_kernel.expanders as expanders

def extract_expand_info(kernel_info):
    """Convert the json into a more friendly format"""
    input_desc = []
    if 'input_desc' in kernel_info and kernel_info['input_desc']:
        for desc in kernel_info['input_desc']:
            input_desc += desc
    attrs = {}
    if 'attr' in kernel_info and kernel_info['attr']:
        for attr in kernel_info["attr"]:
            attrs[attr["name"]] = attr["value"]
    expand_info = {
        "name": kernel_info["name"],
        "input_desc": input_desc,
        "output_desc": kernel_info["output_desc"],
        "attr": attrs,
        "process": kernel_info["process"],
    }
    return expand_info
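For orientation, a hedged sketch of the conversion this helper performs; the op name and field values below are hypothetical, not taken from this PR:

kernel_info = {
    "name": "Softmax",
    "process": "aicore",
    "input_desc": [[{"shape": [32, 100], "data_type": "float32", "format": "DefaultFormat"}]],
    "output_desc": [{"shape": [32, 100], "data_type": "float32", "format": "DefaultFormat"}],
    "attr": [{"name": "axis", "value": -1}],
}
expand_info = extract_expand_info(kernel_info)
assert expand_info["attr"] == {"axis": -1}                 # list of {name, value} pairs becomes a dict
assert expand_info["input_desc"][0]["shape"] == [32, 100]  # nested per-input lists are concatenated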
def get_op_expander(json_str: str):
    """get op expander by json info"""
    try:
        kernel_info = json.loads(json_str)
        expand_info = kernel_info['expand_info']
        if 'name' not in expand_info:
            logger.error("expand info has no op name")
            return None
        if 'process' not in expand_info:
            logger.error("expand info has no processor info")
            return None
        expand_info = extract_expand_info(kernel_info)
        processor = expand_info['process']
        op_name = str(expand_info['name']).lower()
@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,20 +21,15 @@ def expand_dropoutgrad(expand_info):
    # get op info.
    dy_desc = expand_info['input_desc'][0]
    mask_desc = expand_info['input_desc'][1]
    keep_prob = None
    for attr in expand_info['attr']:
        if 'keep_prob' in attr:
            keep_prob = attr['keep_prob']
    if keep_prob is None:
        raise RuntimeError("keep_prob does not exist in attrs.")
    # generate a graph.
    keep_prob = expand_info['attr']['keep_prob']
    graph_builder = builder.GraphBuilder()
    with graph_builder.graph_scope('main') as graph_scope:
        # create tensor input.
        input_dy = graph_builder.tensor(dy_desc['shape'], dy_desc['data_type'], dy_desc['format'])
        input_mask = graph_builder.tensor(mask_desc['shape'], mask_desc['data_type'], mask_desc['format'])
        graph_scope.set_input(input_dy, input_mask)
        r_keep_prob = graph_builder.value(input_dy.dtype, 1.0 / keep_prob, "DefaultFormat")
        r_keep_prob = graph_builder.value(input_dy.dtype, 1.0 / keep_prob)
        # create op.
        result = graph_builder.emit('Mul', [input_dy, r_keep_prob])
        result = graph_builder.emit('Mul', [result, input_mask])
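The expansion above is plain inverted-dropout backprop. A minimal NumPy sketch of the same math, for reference only (not part of the PR):

import numpy as np

def dropoutgrad_reference(dy, mask, keep_prob):
    # mirrors the emitted ops: Mul(dy, 1/keep_prob), then Mul(., mask)
    return dy * (1.0 / keep_prob) * mask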
@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -40,17 +40,16 @@ def expand_gelu(expand_info):
        # cal y
        mul_0 = graph_builder.emit('Mul', [input_x, input_x])
        pow_0 = graph_builder.emit('Mul', [mul_0, input_x])
        const_csvalue = graph_builder.value(pow_0.dtype, CSVALUE, input_desc['format'])
        const_csvalue = graph_builder.value(pow_0.dtype, CSVALUE)
        mul_1 = graph_builder.emit('Mul', [pow_0, const_csvalue])
        tanh_res = graph_builder.emit('Add', [input_x, mul_1])
        const_csvalue_sqrt_two_div_pi = graph_builder.value(
            tanh_res.dtype, CSVALUE_SQRT_TWO_DIV_PI, input_desc['format'])
        const_csvalue_sqrt_two_div_pi = graph_builder.value(tanh_res.dtype, CSVALUE_SQRT_TWO_DIV_PI)
        y = graph_builder.emit('Mul', [tanh_res, const_csvalue_sqrt_two_div_pi])
        # cal gelu(x)
        tanh_y = graph_builder.emit('Tanh', [y])
        const_one = graph_builder.value(tanh_y.dtype, ONE, input_desc['format'])
        const_half = graph_builder.value(tanh_y.dtype, HALF, input_desc['format'])
        const_one = graph_builder.value(tanh_y.dtype, ONE)
        const_half = graph_builder.value(tanh_y.dtype, HALF)
        tanh_y_add_one = graph_builder.emit('Add', [tanh_y, const_one])
        mul_x = graph_builder.emit('Mul', [input_x, tanh_y_add_one])
        result = graph_builder.emit('Mul', [const_half, mul_x])
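The emitted graph is the usual tanh approximation of GeLU. A NumPy sketch, assuming CSVALUE is 0.044715 and CSVALUE_SQRT_TWO_DIV_PI is sqrt(2/pi) (their definitions sit outside this hunk):

import numpy as np

CSVALUE = 0.044715                              # assumed value
CSVALUE_SQRT_TWO_DIV_PI = np.sqrt(2.0 / np.pi)  # assumed value

def gelu_reference(x):
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    y = CSVALUE_SQRT_TWO_DIV_PI * (x + CSVALUE * x * x * x)
    return 0.5 * x * (1.0 + np.tanh(y))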
@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -45,12 +45,11 @@ def expand_gelugrad(expand_info):
        graph_scope.set_input(input_dy, input_x, input_y)
        # create some const var
        const_csvalue = graph_builder.value(input_dy.dtype, CSVALUE, input_desc_0['format'])
        const_csvalue_sqrt_two_div_pi = graph_builder.value(
            input_dy.dtype, CSVALUE_SQRT_TWO_DIV_PI, input_desc_0['format'])
        const_csvalue_tri = graph_builder.value(input_dy.dtype, CSVALUE_TRI, input_desc_0['format'])
        const_one = graph_builder.value(input_dy.dtype, ONE, input_desc_0['format'])
        const_half = graph_builder.value(input_dy.dtype, HALF, input_desc_0['format'])
        const_csvalue = graph_builder.value(input_dy.dtype, CSVALUE)
        const_csvalue_sqrt_two_div_pi = graph_builder.value(input_dy.dtype, CSVALUE_SQRT_TWO_DIV_PI)
        const_csvalue_tri = graph_builder.value(input_dy.dtype, CSVALUE_TRI)
        const_one = graph_builder.value(input_dy.dtype, ONE)
        const_half = graph_builder.value(input_dy.dtype, HALF)
        # cal mul_right
        mul_double = graph_builder.emit('Mul', [input_x, input_x])
@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,25 +21,20 @@ def expand_gkdropout(expand_info):
    # get op info.
    input_desc = expand_info['input_desc'][0]
    mask_desc = expand_info['input_desc'][1]
    keep_prob = None
    for attr in expand_info['attr']:
        if 'keep_prob' in attr:
            keep_prob = attr['keep_prob']
    if keep_prob is None:
        raise RuntimeError("keep_prob does not exist in attrs.")
    # generate a graph.
    keep_prob = expand_info['attr']['keep_prob']
    graph_builder = builder.GraphBuilder()
    with graph_builder.graph_scope('main') as graph_scope:
        # create tensor input.
        input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
        input_mask = graph_builder.tensor(mask_desc['shape'], mask_desc['data_type'], mask_desc['format'])
        graph_scope.set_input(input_x, input_mask)
        keep_prob_v = graph_builder.value(input_x.dtype, keep_prob, "DefaultFormat")
        r_keep_prob = graph_builder.value(input_x.dtype, 1.0 / keep_prob, "DefaultFormat")
        keep_prob_v = graph_builder.value(input_x.dtype, keep_prob)
        r_keep_prob = graph_builder.value(input_x.dtype, 1.0 / keep_prob)
        if input_mask.dtype != input_x.dtype:
            input_mask = graph_builder.emit('Cast', [input_mask], attrs={'dst_type': input_x.dtype})
        mask = graph_builder.emit('LessEqual', [input_mask, keep_prob_v]) # output is bool type
        mask = graph_builder.emit('LessEqual', [input_mask, keep_prob_v])  # output is bool type
        mask = graph_builder.emit('Cast', [mask], attrs={'dst_type': input_x.dtype})
        # compute result
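The hunk ends before the final Mul chain, so the tail below is an assumption based on standard inverted dropout; a NumPy sketch of the whole expansion:

import numpy as np

def gkdropout_reference(x, uniform_mask, keep_prob):
    mask = (uniform_mask <= keep_prob).astype(x.dtype)  # LessEqual + Cast, as above
    return x * mask * (1.0 / keep_prob), mask           # assumed final Mul chain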
@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -23,16 +23,10 @@ def expand_layernorm(expand_info):
    input_desc_1 = expand_info['input_desc'][1]
    input_desc_2 = expand_info['input_desc'][2]
    attrs = expand_info['attr']
    begin_norm_axis = None
    epsilon = None
    for item in attrs:
        if 'begin_norm_axis' in item:
            begin_norm_axis = item['begin_norm_axis']
        if 'epsilon' in item:
            epsilon = item['epsilon']
    graph_builder = builder.GraphBuilder()
    begin_norm_axis = attrs['begin_norm_axis']
    epsilon = attrs['epsilon']
    # generate a graph.
    graph_builder = builder.GraphBuilder()
    with graph_builder.graph_scope('main') as graph_scope:
        # create tensor input.
        input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
@@ -52,7 +46,7 @@ def expand_layernorm(expand_info):
        for i in reduce_axis:
            reduce_elts *= shape_x[i]
        mean_cof = 1.0 / reduce_elts
        mean_cof_v = graph_builder.value(input_x.dtype, mean_cof, input_x.data_format)
        mean_cof_v = graph_builder.value(input_x.dtype, mean_cof)
        # Calculate mean
        mean_red = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
@@ -67,7 +61,7 @@ def expand_layernorm(expand_info):
        # Calculate normalize
        normalize_sub = graph_builder.emit('Sub', [input_x, mean])
        epsilon_v = graph_builder.value(input_x.dtype, epsilon, input_x.data_format)
        epsilon_v = graph_builder.value(input_x.dtype, epsilon)
        normalize_add = graph_builder.emit('Add', [variance, epsilon_v])
        normalize_rsqrt = graph_builder.emit('Rsqrt', [normalize_add])
        normalize_mul = graph_builder.emit('Mul', [normalize_sub, normalize_rsqrt])
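For reference, the math this expansion implements, as a hedged NumPy sketch; the gamma/beta scale-and-shift step is assumed from the usual LayerNorm definition, since the hunk only shows the mean, variance, and normalize parts:

import numpy as np

def layernorm_reference(x, gamma, beta, begin_norm_axis, epsilon):
    axes = tuple(range(begin_norm_axis, x.ndim))
    mean = x.mean(axis=axes, keepdims=True)
    variance = x.var(axis=axes, keepdims=True)
    # (x - mean) * rsqrt(variance + epsilon), then assumed scale and shift
    return gamma * (x - mean) / np.sqrt(variance + epsilon) + beta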
@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,16 +24,10 @@ def expand_layernormgrad(expand_info):
    var_desc = expand_info['input_desc'][2]
    mean_desc = expand_info['input_desc'][3]
    gamma_desc = expand_info['input_desc'][4]
    begin_norm_axis = None
    begin_params_axis = None
    epsilon = 1e-11
    for item in expand_info['attr']:
        if 'begin_norm_axis' in item:
            begin_norm_axis = item['begin_norm_axis']
        if 'begin_params_axis' in item:
            begin_params_axis = item['begin_params_axis']
        if 'epsilon' in item:
            epsilon = item['epsilon']
    attrs = expand_info['attr']
    begin_norm_axis = attrs['begin_norm_axis']
    begin_params_axis = attrs['begin_params_axis']
    epsilon = attrs['epsilon'] if 'epsilon' in attrs else 1e-11
    shape_x = x_desc['shape']
    if begin_norm_axis < 0:
@@ -57,13 +51,13 @@ def expand_layernormgrad(expand_info):
        graph_scope.set_input(x, dy, variance, mean, gamma)
        # set some constant val.
        eps = graph_builder.value(x.dtype, epsilon, x.data_format)
        const_one = graph_builder.value(x.dtype, 1.0, x.data_format)
        const_neg_half = graph_builder.value(x.dtype, -0.5, x.data_format)
        const_neg_two = graph_builder.value(x.dtype, -2.0, x.data_format)
        const_two = graph_builder.value(x.dtype, 2.0, x.data_format)
        const_neg_one = graph_builder.value(x.dtype, -1.0, x.data_format)
        mean_cof = graph_builder.value(x.dtype, (1.0 / reduce_size), x.data_format)
        eps = graph_builder.value(x.dtype, epsilon)
        const_one = graph_builder.value(x.dtype, 1.0)
        const_neg_half = graph_builder.value(x.dtype, -0.5)
        const_neg_two = graph_builder.value(x.dtype, -2.0)
        const_two = graph_builder.value(x.dtype, 2.0)
        const_neg_one = graph_builder.value(x.dtype, -1.0)
        mean_cof = graph_builder.value(x.dtype, (1.0 / reduce_size))
        # cal dg db
        var_eps = graph_builder.emit('Add', [variance, eps])
@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,11 +20,7 @@ def expand_logsoftmax(expand_info):
    """LogSoftmax expander"""
    # get op info.
    input_desc = expand_info['input_desc'][0]
    attrs = expand_info['attr']
    axis = None
    for item in attrs:
        if 'axis' in item:
            axis = item['axis']
    axis = expand_info['attr']['axis']
    graph_builder = builder.GraphBuilder()
    if isinstance(axis, int):
        axis = (axis,)
@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,11 +21,7 @@ def expand_logsoftmaxgrad(expand_info):
    # get op info.
    input_desc_0 = expand_info['input_desc'][0]
    input_desc_1 = expand_info['input_desc'][1]
    attrs = expand_info['attr']
    axis = None
    for item in attrs:
        if 'axis' in item:
            axis = item['axis']
    axis = expand_info['attr']['axis']
    graph_builder = builder.GraphBuilder()
    if isinstance(axis, int):
@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -23,16 +23,10 @@ def expand_maximumgrad(expand_info):
    input_desc_1 = expand_info['input_desc'][1]
    input_desc_2 = expand_info['input_desc'][2]
    attrs = expand_info['attr']
    grad_x = None
    grad_y = None
    for item in attrs:
        if 'grad_x' in item:
            grad_x = item['grad_x']
        if 'grad_y' in item:
            grad_y = item['grad_y']
    graph_builder = builder.GraphBuilder()
    grad_x = attrs['grad_x'] if 'grad_x' in attrs else True
    grad_y = attrs['grad_y'] if 'grad_y' in attrs else True
    # generate a graph.
    graph_builder = builder.GraphBuilder()
    with graph_builder.graph_scope('main') as graph_scope:
        # create tensor input.
        input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -23,15 +23,10 @@ def expand_minimumgrad(expand_info):
    input_desc_1 = expand_info['input_desc'][1]
    input_desc_2 = expand_info['input_desc'][2]
    attrs = expand_info['attr']
    grad_x = None
    grad_y = None
    for item in attrs:
        if 'grad_x' in item:
            grad_x = item['grad_x']
        if 'grad_y' in item:
            grad_y = item['grad_y']
    grad_x = attrs['grad_x'] if 'grad_x' in attrs else True
    grad_y = attrs['grad_y'] if 'grad_y' in attrs else True
    graph_builder = builder.GraphBuilder()
    # generate a graph.
    with graph_builder.graph_scope('main') as graph_scope:
        # create tensor input.
        input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,20 +18,13 @@ from mindspore._extends.graph_kernel.model import model_builder as builder
def expand_reducemean(expand_info):
    """ReduceMean expander"""
    # get op info.
    input_desc = expand_info['input_desc'][0]
    attrs = expand_info['attr']
    axis = None
    keep_dims = None
    for item in attrs:
        if 'axis' in item:
            axis = item['axis']
        if 'keep_dims' in item:
            keep_dims = item['keep_dims']
    graph_builder = builder.GraphBuilder()
    axis = attrs['axis']
    keep_dims = attrs['keep_dims']
    # generate a graph.
    graph_builder = builder.GraphBuilder()
    with graph_builder.graph_scope('main') as graph_scope:
        # create tensor input.
        input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
@@ -49,7 +42,7 @@ def expand_reducemean(expand_info):
        for idx in axis:
            all_shape *= x_shape[idx]
        all_shape_value = graph_builder.value(input_x.dtype, all_shape, input_x.data_format)
        all_shape_value = graph_builder.value(input_x.dtype, all_shape)
        if not axis:
            sum_x = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': real_axis, 'keep_dims': keep_dims})
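A hedged NumPy sketch of the semantics: ReduceSum followed by division by the number of reduced elements (the empty-axis branch mirrors the `if not axis` handling above):

import numpy as np

def reducemean_reference(x, axis, keep_dims):
    axes = tuple(axis) if axis else None  # empty axis means reduce over all dims
    count = np.prod([x.shape[i] for i in axis]) if axis else x.size
    return np.sum(x, axis=axes, keepdims=keep_dims) / count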
@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,16 +18,10 @@ from mindspore._extends.graph_kernel.model import model_builder as builder
def expand_softmax(expand_info):
    """Softmax expander"""
    # get op info.
    input_desc = expand_info['input_desc'][0]
    attrs = expand_info['attr']
    axis = None
    for item in attrs:
        if 'axis' in item:
            axis = item['axis']
    graph_builder = builder.GraphBuilder()
    axis = expand_info['attr']['axis']
    # generate a graph.
    graph_builder = builder.GraphBuilder()
    with graph_builder.graph_scope('main') as graph_scope:
        # create tensor input.
        input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -34,7 +34,7 @@ def expand_sqrtgrad(expand_info):
        graph_scope.set_input(input_x, input_dout)
        # cal result
        const_two = graph_builder.value(input_x.dtype, 2, input_x.data_format)
        const_two = graph_builder.value(input_x.dtype, 2)
        dividend = graph_builder.emit('Mul', [input_x, const_two])
        result = graph_builder.emit('RealDiv', [input_dout, dividend])
@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -31,7 +31,7 @@ def expand_tanhgrad(expand_info):
        # create tensor input.
        input_y = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
        input_dy = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
        const_one = graph_builder.value(input_y.dtype, ONE, input_y.data_format)
        const_one = graph_builder.value(input_y.dtype, ONE)
        graph_scope.set_input(input_y, input_dy)
        # cal result
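The computation after `# cal result` is cut off in this hunk; the standard TanhGrad formula the expander presumably emits is dx = dy * (1 - y^2). A one-line sketch under that assumption:

def tanhgrad_reference(y, dy):
    return dy * (1.0 - y * y)  # d/dx tanh(x) = 1 - tanh(x)^2, with y = tanh(x)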
@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,18 +18,11 @@ from mindspore._extends.graph_kernel.model import model_builder as builder
def expand_tile(expand_info):
    """Tile expander"""
    # get op info.
    input_desc = expand_info['input_desc'][0]
    attrs = expand_info['attr']
    multiples = None
    for item in attrs:
        if 'multiples' in item:
            multiples = item['multiples']
    multiples = expand_info['attr']['multiples']
    output_shape, _, _, shape_compatible = builder.get_tile_output_shape(input_desc['shape'], multiples)
    graph_builder = builder.GraphBuilder()
    # generate a graph.
    graph_builder = builder.GraphBuilder()
    with graph_builder.graph_scope('main') as graph_scope:
        # create tensor input.
        input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
@@ -15,7 +15,7 @@
"""GraphKernel model builder"""
import copy
from .model import PrimLib, Tensor, Value, Operator, Graph, AlignShape, AddControlBuddy
from .model import PrimLib, Tensor, Value, Operator, Graph, AlignShape, AddControlBuddy, DataFormat

def get_tile_output_shape(shape, multiples):
@@ -70,7 +70,7 @@ class OpInfer:
        real_shape = []
        for i, _ in enumerate(shape):
            if i not in attrs['reduce_axis']:
            if i not in attrs['reduce_axis'] and i - len(shape) not in attrs['reduce_axis']:
                real_shape.append(shape[i])
        return real_shape
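Why the extra `i - len(shape)` check matters: reduce_axis may hold negative indices. A small sketch with hypothetical values:

shape, reduce_axis = [4, 5, 6], [-1]
real_shape = [d for i, d in enumerate(shape)
              if i not in reduce_axis and i - len(shape) not in reduce_axis]
assert real_shape == [4, 5]  # the old `i not in reduce_axis` test alone would keep all three dims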
@@ -106,7 +106,15 @@ class OpInfer:
    @staticmethod
    def default_infer_format_func(inputs, attrs):
        """Infer format"""
        return inputs[0].data_format
        result = inputs[0].data_format
        # default_format and other_format results in other_format
        for input_tensor in inputs[1:]:
            data_format = input_tensor.data_format
            if data_format != DataFormat.DEFAULT:
                if result not in [DataFormat.DEFAULT, data_format]:
                    raise RuntimeError("Incompatible data format %s and %s" % (data_format, result))
                result = data_format
        return result
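The rule encoded above: DefaultFormat is transparent, any concrete format wins, and two different concrete formats are an error. A standalone sketch of the same logic, using plain strings in place of the DataFormat enum:

def infer_format(formats, default="DefaultFormat"):
    result = formats[0]
    for fmt in formats[1:]:
        if fmt != default:
            if result not in (default, fmt):
                raise RuntimeError("Incompatible data format %s and %s" % (fmt, result))
            result = fmt
    return result

assert infer_format(["DefaultFormat", "NHWC"]) == "NHWC"  # default + other -> other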
    infer_shape_func = {
        # add special infer func here
@@ -114,13 +122,20 @@ class OpInfer:
        'Reshape': lambda inputs, attrs: attrs["shape"],
        'BroadcastTo': lambda inputs, attrs: attrs["shape"],
        'Tile': lambda inputs, attrs: get_tile_output_shape(inputs[0].shape, attrs["multiples"])[0],
        # note: list.insert() returns None, so build the expanded shape explicitly instead
        'ExpandDims': lambda inputs, attrs: list(inputs[0].shape)[:attrs["axis"]] + [1] + list(inputs[0].shape)[attrs["axis"]:],
    }
    infer_dtype_func = {
        # add special infer func here
        'Cast': lambda inputs, attrs: attrs['dst_type'],
        'Less': lambda inputs, attrs: "bool",
        'LessEqual': lambda inputs, attrs: "bool",
        'Equal': lambda inputs, attrs: "bool",
        'Greater': lambda inputs, attrs: "bool",
        'GreaterEqual': lambda inputs, attrs: "bool",
    }
    infer_format_func = {
        # add special infer func here
        'Reshape': lambda inputs, attrs: "DefaultFormat",
    }

    @classmethod
@@ -188,18 +203,12 @@ class GraphBuilder:
            shape = [1]
        return Tensor(name, shape, dtype, data_format, para_type=para_type)

    def value(self, dtype, value, data_format, name=None):
    def value(self, dtype, value, name=None):
        """Create a new Value"""
        if name in (None, ''):
            name = self._alloc_tensor_name()
        if dtype == "float16":
            # For float16 value, it will be changed to float32 wrongly. And there is no good solution for now.
            # So instead just declare float32 value and then cast it to float16.
            v_fp32 = Value(name, "float32", value, data_format)
            v = self.emit("Cast", [v_fp32], attrs={"dst_type": "float16"})
        else:
            v = Value(name, dtype, value, data_format)
        v = Value(name, dtype, value)
        return v
    def op(self, prim, output, inputs, attrs=None):
@@ -19,8 +19,7 @@
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <map>
#include <vector>
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
#include "backend/kernel_compiler/common_utils.h"
@@ -46,6 +45,62 @@ namespace {
constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";

class AbstractShapeCreator {
 public:
  using AbstractShapeTransferFunc = std::function<ShapeVector(const ShapeVector &)>;
  /**
   * Get an abstract shape.
   * For a given device_shape and format, the available abstract_shape is not unique;
   * this interface only returns a legal abstract_shape without considering padding,
   * so that AnfAlgo's get-device-shape interface can get the right device_shape.
   */
  static ShapeVector GetFakeAbstractShape(const ShapeVector &device_shape, const std::string &format) {
    const std::map<std::string, AbstractShapeTransferFunc> fmap{
      {kOpFormat_NCHW, NchwAbstractShape},
      {kOpFormat_NHWC, NhwcAbstractShape},
      {kOpFormat_FRAC_NZ, FractalNzAbstractShape},
    };
    if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
      return device_shape;
    }
    auto iter = fmap.find(format);
    if (iter == fmap.end()) {
      MS_LOG(WARNING) << "Unexpected format[" << format << "]";
      return device_shape;
    }
    return iter->second(device_shape);
  }

 private:
  static ShapeVector NchwAbstractShape(const ShapeVector &device_shape) { return device_shape; }
  static ShapeVector NhwcAbstractShape(const ShapeVector &device_shape) {
    if (device_shape.size() != 4) {
      MS_LOG(EXCEPTION) << "Shape size of NHWC should be 4, but got " << device_shape.size();
    }
    return {device_shape[0], device_shape[3], device_shape[1], device_shape[2]};
  }
  static ShapeVector FractalNzAbstractShape(const ShapeVector &device_shape) {
    if (device_shape.size() == 1 && (device_shape[0] == 1 || device_shape[0] % kCubeSize == 0)) {
      return device_shape;
    }
    if (device_shape.size() < 4) {
      MS_LOG(EXCEPTION) << "Shape size of FRACTAL_NZ should >= 4, but got " << device_shape.size();
    }
    ShapeVector shape;
    size_t dims = device_shape.size();
    size_t batch = dims - 4;
    for (size_t i = 0; i < batch; ++i) {
      shape.push_back(device_shape[i]);
    }
    int64_t m = device_shape[dims - 3] * device_shape[dims - 2];
    int64_t n = device_shape[dims - 4] * device_shape[dims - 1];
    shape.push_back(m);
    shape.push_back(n);
    return shape;
  }
};
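To illustrate FractalNzAbstractShape: the trailing four device dims (n1, m1, m0, n0) collapse to (m1*m0, n1*n0) while leading batch dims pass through. A Python sketch of the same arithmetic, assuming kCubeSize is 16:

def fractal_nz_abstract_shape(device_shape, cube_size=16):
    if len(device_shape) == 1 and (device_shape[0] == 1 or device_shape[0] % cube_size == 0):
        return device_shape
    *batch, n1, m1, m0, n0 = device_shape  # the real code raises for fewer than 4 dims
    return batch + [m1 * m0, n1 * n0]

assert fractal_nz_abstract_shape([2, 3, 4, 16, 16]) == [2, 64, 48]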
class CNodeDecoder {
 public:
  explicit CNodeDecoder(std::map<std::string, AnfNodePtr> *nodes_map) : nodes_map_(*nodes_map) {}
@@ -66,6 +121,7 @@ class CNodeDecoder {
      return nullptr;
    }
    CreateKernelInfo(processor);
    CreateAbstract();
    return cnode_;
  }
@@ -117,12 +173,8 @@ class CNodeDecoder {
  bool DecodeInputDesc(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph) {
    std::string op_name = cnode_json[kJsonKeyName];
    // new primitive.
    auto primitive = GetPrimitive(op_name);
    if (primitive == nullptr) {
      MS_LOG(ERROR) << "Create primitive failed.";
      return false;
    }
    auto primitive = CreatePrimitiveWithAttrs(op_name);
    MS_EXCEPTION_IF_NULL(primitive);
    // collect inputs.
    auto primitive_v = NewValueNode(primitive);
@@ -142,6 +194,7 @@ class CNodeDecoder {
      }
      input_formats_.push_back(input_desc[kJsonKeyFormat]);
      input_types_.push_back(DtypeToTypeId(input_desc[kJsonKeyDataType]));
      input_shapes_.push_back(input_desc[kJsonKeyShape]);
    }
    // new cnode.
    cnode_ = func_graph->NewCNode(inputs);
@@ -160,6 +213,7 @@ class CNodeDecoder {
      nlohmann::json output_desc = output_descs[0];
      output_formats_.push_back(output_desc[kJsonKeyFormat]);
      output_types_.push_back(DtypeToTypeId(output_desc[kJsonKeyDataType]));
      output_shapes_.push_back(output_desc[kJsonKeyShape]);
      nodes_map_[output_desc[kJsonKeyTensorName]] = cnode_;
    } else {
      // multi outputs.
@@ -167,6 +221,7 @@ class CNodeDecoder {
        nlohmann::json output_desc = output_descs[j];
        output_formats_.push_back(output_desc[kJsonKeyFormat]);
        output_types_.push_back(DtypeToTypeId(output_desc[kJsonKeyDataType]));
        output_shapes_.push_back(output_desc[kJsonKeyShape]);
        auto get_item =
          func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode_, NewValueNode(SizeToLong(j))});
        func_graph->AddNode(get_item);
@@ -219,72 +274,29 @@ class CNodeDecoder {
    AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), cnode_.get());
  }

  ValuePtr CreatOpInstance(const std::string &op_name, const std::vector<ValuePtr> &attrs) {
    // python utils.
    constexpr auto kGetPythonOpFunc = "_get_python_op";
    constexpr auto kParallelUtilsModule = "mindspore.parallel._utils";
    // almost all ops are defined in this path.
    constexpr auto kOperationsModule = "mindspore.ops.operations";
    py::module mod = py::module::import(kOperationsModule);
    if (!py::hasattr(mod, op_name.c_str())) {
      MS_LOG(ERROR) << kOperationsModule << " don't have attr: " << op_name;
      return nullptr;
    }
    std::vector<py::object> arg_list;
    (void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list),
                         [](const ValuePtr &attr) { return ValuePtrToPyData(attr); });
    py::object obj = parse::python_adapter::CallPyFn(kParallelUtilsModule, kGetPythonOpFunc, op_name, kOperationsModule,
                                                     op_name, arg_list);
    ValuePtr op_instance = nullptr;
    bool succ = parse::ConvertData(obj, &op_instance);
    if (!succ) {
      MS_LOG(ERROR) << "Get python op " << op_name << " from " << kOperationsModule << " failed.";
      return nullptr;
    }
    return op_instance;
  void CreateAbstract() {
    auto shape = AbstractShapeCreator::GetFakeAbstractShape(output_shapes_[0], output_formats_[0]);
    auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(output_types_[0]), shape);
    cnode_->set_abstract(abstract);
  }

  const std::map<std::string, std::vector<std::string>> op_attrs_map_ = {
    {kReduceSumOpName, std::vector<std::string>{kAttrKeepDims}},
    {kReduceMaxOpName, std::vector<std::string>{kAttrKeepDims}},
    {kReduceMinOpName, std::vector<std::string>{kAttrKeepDims}},
    {kBroadcastToOpName, std::vector<std::string>{kAttrShape}},
  };

  PrimitivePtr GetPrimitive(const std::string &op_name) {
    PrimitivePtr primitive{nullptr};
    if (op_attrs_map_.count(op_name) == 0) {
      // no attrs for op instance.
      primitive = CreatOpInstance(op_name, std::vector<ValuePtr>{})->cast<PrimitivePtr>();
    } else {
      // make attrs for op instance.
      std::vector<ValuePtr> op_attrs;
      const auto &attr_names = op_attrs_map_.at(op_name);
      for (const auto &attr_name : attr_names) {
        if (cnode_attrs_.count(attr_name) == 0) {
          MS_LOG(ERROR) << "Attr: " << attr_name << " for: " << op_name << " not found.";
          return nullptr;
        }
        op_attrs.push_back(cnode_attrs_.at(attr_name));
      }
      primitive = CreatOpInstance(op_name, op_attrs)->cast<PrimitivePtr>();
    }
    if (primitive != nullptr) {
      for (const auto &attr : cnode_attrs_) {
        primitive->AddAttr(attr.first, attr.second);
      }
  PrimitivePtr CreatePrimitiveWithAttrs(const std::string &op_name) {
    auto primitive = std::make_shared<Primitive>(op_name);
    for (const auto &attr : cnode_attrs_) {
      primitive->AddAttr(attr.first, attr.second);
    }
    return primitive;
  }
  ScalarPtr DecodeScalar(const nlohmann::json &scalar_json) {
  tensor::TensorPtr DecodeScalar(const nlohmann::json &scalar_json) {
    auto type_id = DtypeToTypeId(scalar_json[kJsonKeyDataType]);
    switch (type_id) {
      case kNumberTypeFloat16:
        return std::make_shared<tensor::Tensor>(static_cast<float>(scalar_json[kJsonKeyValue]), kFloat16);
      case kNumberTypeFloat32:
        return std::make_shared<FP32Imm>(scalar_json[kJsonKeyValue]);
        return std::make_shared<tensor::Tensor>(static_cast<float>(scalar_json[kJsonKeyValue]), kFloat32);
      case kNumberTypeInt32:
        return std::make_shared<Int32Imm>(scalar_json[kJsonKeyValue]);
        return std::make_shared<tensor::Tensor>(static_cast<int64_t>(scalar_json[kJsonKeyValue]), kInt32);
      default:
        MS_LOG(ERROR) << "Unknown type: " << scalar_json[kJsonKeyDataType];
        break;
@@ -294,9 +306,8 @@ class CNodeDecoder {
  ValueNodePtr DecodeValueNode(const nlohmann::json &value_json, const FuncGraphPtr &func_graph) {
    MS_LOG(DEBUG) << "start decode value node, " << value_json;
    auto scalar = DecodeScalar(value_json);
    auto tensor = ScalarToTensor(scalar);
    auto tensor = DecodeScalar(value_json);
    MS_EXCEPTION_IF_NULL(tensor);
    auto value_node = std::make_shared<ValueNode>(tensor);
    value_node->set_abstract(tensor->ToAbstract());
    // create kernel_info for the new value node.
@@ -319,6 +330,8 @@ class CNodeDecoder {
  std::vector<std::string> output_formats_;
  std::vector<TypeId> input_types_;
  std::vector<TypeId> output_types_;
  std::vector<ShapeVector> input_shapes_;
  std::vector<ShapeVector> output_shapes_;
  CNodePtr cnode_{nullptr};
};
}  // namespace
@@ -329,11 +342,16 @@ ParameterPtr AkgKernelJsonDecoder::DecodeParameter(const nlohmann::json &paramet
  ParameterPtr new_parameter = func_graph->add_parameter();
  std::string name = parameter_json[kJsonKeyTensorName];
  new_parameter->set_name(name);
  std::string format = parameter_json[kJsonKeyFormat];
  TypeId dtype = DtypeToTypeId(parameter_json[kJsonKeyDataType]);
  ShapeVector shape = AbstractShapeCreator::GetFakeAbstractShape(parameter_json[kJsonKeyShape], format);
  auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(dtype), shape);
  new_parameter->set_abstract(abstract);
  auto kernel_info = std::make_shared<device::KernelInfo>();
  new_parameter->set_kernel_info(kernel_info);
  auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  builder->SetOutputsFormat(std::vector<std::string>{parameter_json[kJsonKeyFormat]});
  builder->SetOutputsDeviceType(std::vector<TypeId>{DtypeToTypeId(parameter_json[kJsonKeyDataType])});
  builder->SetOutputsFormat(std::vector<std::string>{format});
  builder->SetOutputsDeviceType(std::vector<TypeId>{dtype});
  AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), new_parameter.get());
  nodes_map_[name] = new_parameter;
  return new_parameter;
@@ -349,6 +367,7 @@ CNodePtr AkgKernelJsonDecoder::DecodeCNode(const nlohmann::json &cnode_json, con
AnfNodePtr AkgKernelJsonDecoder::DecodeOutput(const std::vector<nlohmann::json> &output_descs,
                                              const FuncGraphPtr &func_graph) {
  std::vector<AnfNodePtr> outputs{NewValueNode(prim::kPrimMakeTuple)};
  AbstractBasePtrList output_abstract_list;
  for (const auto &output_desc : output_descs) {
    std::string name = output_desc[kJsonKeyTensorName];
    if (nodes_map_.count(name) == 0) {
@@ -356,11 +375,13 @@ AnfNodePtr AkgKernelJsonDecoder::DecodeOutput(const std::vector<nlohmann::json>
      return nullptr;
    }
    outputs.push_back(nodes_map_[name]);
    output_abstract_list.push_back(outputs.back()->abstract());
  }
  if (outputs.size() == 2) {
    func_graph->set_output(outputs[1]);
  } else {
    auto output = func_graph->NewCNode(outputs);
    output->set_abstract(std::make_shared<abstract::AbstractTuple>(output_abstract_list));
    func_graph->AddNode(output);
    func_graph->set_output(output);
  }
@@ -80,10 +80,12 @@ class OpInfoExtractor {
  }

  void ExtractOutputs(const OpInfoPtr &op_info) {
    // only support single output in op desc.
    auto io_info = std::make_shared<OpIOInfo>();
    io_info->set_name("output");
    op_info->add_outputs_ptr(io_info);
    size_t output_tensor_num = AnfAlgo::GetOutputTensorNum(cnode_);
    for (size_t i = 0; i < output_tensor_num; i++) {
      auto io_info = std::make_shared<OpIOInfo>();
      io_info->set_name("output_" + std::to_string(i));
      op_info->add_outputs_ptr(io_info);
    }
  }

  bool ExcludeAttr(const std::string &name) {
@@ -204,8 +206,7 @@ bool AkgKernelJsonGenerator::CreateInputDescJson(const AnfNodePtr &anf_node, con
    input_desc_json[kJsonKeyName] = input_ptr->name();
    input_desc_json[kJsonKeyTensorName] = "input_" + std::to_string(GetInputTensorIdxInc(anf_node, real_input_index));
    auto input_shape = this->GetInputShape(anf_node, real_input_index);
    if (dump_option_.extract_opinfo_from_anfnode &&
        GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) {
    if (!is_basic_op_ && GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) {
      MS_LOG(DEBUG) << "Take input[" << real_input_index << "] of [" << anf_node->DebugString(2)
                    << "] as const tensor, shape: [" << Vector2Str(input_shape)
                    << "], value: " << input_desc_json[kJsonKeyValue];
@@ -529,9 +530,9 @@ bool AkgKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node, nlohmann::j
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(kernel_json);
  std::string op_name = AnfAlgo::GetCNodeName(anf_node);
  MS_LOG(INFO) << "Akg start generate kernel json desc, full scope name is : " << anf_node->fullname_with_scope();
  MS_LOG(DEBUG) << "Akg start generate kernel json desc, full scope name is : " << anf_node->fullname_with_scope();
  SetAkgKernelAttrs(anf_node);
  dump_option_.extract_opinfo_from_anfnode = false;
  is_basic_op_ = true;
  if (!GenerateSingleKernelJson(anf_node, kernel_json)) {
    MS_LOG(ERROR) << "Op[" << anf_node->fullname_with_scope() << "] create single kernel json failed.";
    return false;
@@ -551,8 +552,8 @@ bool AkgKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node, nlohmann::j
    return false;
  }
  MS_LOG(INFO) << "Akg create kernel json desc success, full scope name is : " << anf_node->fullname_with_scope()
               << ", json info name is : " << kernel_name_;
  MS_LOG(DEBUG) << "Akg create kernel json desc success, full scope name is : " << anf_node->fullname_with_scope()
                << ", json info name is : " << kernel_name_;
  return true;
}

@@ -613,10 +614,11 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
                  << "].";
    return false;
  }
| MS_LOG(INFO) << "Fusion nodes: [" << output_list.size() << "], input_list: [" << anf_nodes.size() | |||
| << "], output_list: [" << input_list.size() << "]."; | |||
| MS_LOG(DEBUG) << "Fusion nodes: [" << output_list.size() << "], input_list: [" << anf_nodes.size() | |||
| << "], output_list: [" << input_list.size() << "]."; | |||
  std::map<AnfNodePtr, nlohmann::json> node_json_map;
  dump_option_.extract_opinfo_from_anfnode = true;
  is_basic_op_ = false;
  dump_option_.extract_opinfo_from_anfnode = true;  // always extract from anfnode for composite ops.
  if (!GenSingleJsons(anf_nodes, &node_json_map)) return false;

  UpdateTensorName(anf_nodes, &node_json_map);
@@ -144,6 +144,7 @@ class AkgKernelJsonGenerator {
  std::map<std::string, AnfNodePtr> address_node_map_;
  std::map<size_t, std::vector<std::string>> sub_graphs_;
  std::map<size_t, size_t> dim_infos_;
  bool is_basic_op_{false};
};
}  // namespace kernel
}  // namespace mindspore
@@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -393,6 +393,14 @@ CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphP
                                      broadcast_input_node};
  auto broadcast_to_node_inner = CreateCNode(
    atomic_clean_inputs, new_sub_graph, {.format = format, .shape = dst_shape_vec, .type = GetType(atomic_add_node_)});
  auto device_shape = AnfAlgo::GetOutputDeviceShape(atomic_add_node_, 0);
  dst_shape_vec.clear();
  if (device_shape.empty()) {
    dst_shape_vec.push_back(1);
  } else {
    std::transform(device_shape.begin(), device_shape.end(), std::back_inserter(dst_shape_vec), SizeToLong);
  }
  SetNodeAttrSafely("shape", MakeValue(dst_shape_vec), broadcast_to_node_inner);

  // Makeup sub-graph.
@@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -36,88 +36,40 @@
namespace mindspore {
namespace opt {
namespace {
constexpr auto kJsonKeyExpandInfo = "expand_info";

#define GET_VALUE_FOR_JSON(JSON, VALUE, VALUE_ELEM, TYPE_NAME, TYPE) \
  if (VALUE_ELEM->isa<TYPE_NAME>()) {                                \
    JSON = GetValue<TYPE>(VALUE);                                    \
  }

nlohmann::json ExpandAttrJsonInfo(const CNodePtr &cnode) {
  nlohmann::json attrs_json;
  if (auto prim = GetCNodePrimitive(cnode); prim != nullptr) {
    auto attrs = prim->attrs();
    for (const auto &[k, v] : attrs) {
      nlohmann::json attr_json;
      MS_LOG(DEBUG) << "attr key is : " << k << " and value type is : " << v->type_name();
      GET_VALUE_FOR_JSON(attr_json[k], v, v, Int32Imm, int);
      GET_VALUE_FOR_JSON(attr_json[k], v, v, Int64Imm, int64_t);
      GET_VALUE_FOR_JSON(attr_json[k], v, v, UInt32Imm, uint32_t);
      GET_VALUE_FOR_JSON(attr_json[k], v, v, UInt64Imm, uint64_t);
      GET_VALUE_FOR_JSON(attr_json[k], v, v, FP32Imm, float);
      GET_VALUE_FOR_JSON(attr_json[k], v, v, FP64Imm, double);
      GET_VALUE_FOR_JSON(attr_json[k], v, v, BoolImm, bool);
      GET_VALUE_FOR_JSON(attr_json[k], v, v, StringImm, std::string);
      if (v->isa<ValueList>() || v->isa<ValueTuple>()) {
        auto vec = v->isa<ValueList>() ? v->cast<ValueListPtr>()->value() : v->cast<ValueTuplePtr>()->value();
        if (!vec.empty()) {
          MS_LOG(DEBUG) << "value type is : " << vec[0]->type_name();
          GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], Int32Imm, std::vector<int>);
          GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], Int64Imm, std::vector<int64_t>);
          GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], UInt32Imm, std::vector<uint32_t>);
          GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], UInt64Imm, std::vector<uint64_t>);
          GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], FP32Imm, std::vector<float>);
          GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], FP64Imm, std::vector<double>);
          GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], StringImm, std::vector<std::string>);
        }
      }
      if (!attr_json.empty()) {
        attrs_json.push_back(attr_json);
      }
    }
  }
  return attrs_json;
std::unordered_set<PrimitivePtr> GetExpandOps() {
  std::unordered_set<PrimitivePtr> expand_ops = {
    prim::kPrimSquare,
    prim::kPrimGeLUGrad,
#if ENABLE_D
    prim::kPrimTile,
    prim::kPrimSqrtGrad,
    prim::kPrimClipByNormNoDivSum,
#elif ENABLE_GPU
    prim::kPrimBiasAdd,
    prim::kPrimBiasAddGrad,
    prim::kPrimGeLU,
    prim::kPrimFusedAdam,
    prim::kPrimFusedAdamWeightDecay,
    prim::kPrimReduceMean,
    prim::kPrimMaximumGrad,
    prim::kPrimMinimumGrad,
    prim::kPrimGkDropout,
    prim::kPrimDropoutGrad,
    prim::kPrimSoftmax,
    prim::kPrimLayerNorm,
    prim::kPrimLayerNormGrad,
#endif
  };
  return expand_ops;
}
}  // namespace

bool ExpandJsonInfo(const CNodePtr &cnode, nlohmann::json *kernel_json) {
  MS_EXCEPTION_IF_NULL(kernel_json);
  if (kernel_json->find(kJsonKeyExpandInfo) != kernel_json->end()) {
    return false;
  }

  nlohmann::json expand_info;
  expand_info[kernel::kJsonKeyAttr] = ExpandAttrJsonInfo(cnode);
  expand_info[kernel::kJsonKeyName] = AnfAlgo::GetCNodeName(cnode);
  expand_info[kernel::kJsonKeyProcess] = kernel::GetProcessorStr(cnode);
  std::vector<nlohmann::json> inputs_info;
  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); ++i) {
    nlohmann::json input_info;
    input_info[kernel::kJsonKeyFormat] = AnfAlgo::GetInputFormat(cnode, i);
    input_info[kernel::kJsonKeyInferShape] = AnfAlgo::GetPrevNodeOutputInferShape(cnode, i);
    input_info[kernel::kJsonKeyShape] = AnfAlgo::GetInputDeviceShape(cnode, i);
    input_info[kernel::kJsonKeyInferDataType] =
      kernel::TypeId2String(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, i));
    input_info[kernel::kJsonKeyDataType] = kernel::TypeId2String(AnfAlgo::GetInputDeviceDataType(cnode, i));
    inputs_info.push_back(input_info);
  }
  expand_info[kernel::kJsonKeyInputDesc] = inputs_info;
  std::vector<nlohmann::json> outputs_info;
  for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(cnode); ++i) {
    nlohmann::json output_info;
    output_info[kernel::kJsonKeyFormat] = AnfAlgo::GetOutputFormat(cnode, i);
    output_info[kernel::kJsonKeyInferShape] = AnfAlgo::GetOutputInferShape(cnode, i);
    output_info[kernel::kJsonKeyShape] = AnfAlgo::GetOutputDeviceShape(cnode, i);
    output_info[kernel::kJsonKeyInferDataType] = kernel::TypeId2String(AnfAlgo::GetOutputInferDataType(cnode, i));
    output_info[kernel::kJsonKeyDataType] = kernel::TypeId2String(AnfAlgo::GetOutputDeviceDataType(cnode, i));
    outputs_info.push_back(output_info);
  }
  expand_info[kernel::kJsonKeyOutputDesc] = outputs_info;
  (*kernel_json)[kJsonKeyExpandInfo] = expand_info;
  return true;
bool GraphKernelExpander::ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json) {
  DumpOption dump_option;
  dump_option.extract_opinfo_from_anfnode = true;
  kernel::AkgKernelJsonGenerator json_generator(dump_option);
  return json_generator.CollectJson(node, kernel_json);
}
}  // namespace
FuncGraphPtr GraphKernelExpander::CreateExpandFuncGraph(const CNodePtr &node) {
  nlohmann::json kernel_json;
@@ -213,33 +165,11 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
    // replace origin node.
    (void)mng->Replace(node, graph_kernel_node);
    ToPrimitive(AnfAlgo::GetCNodeFuncGraphPtr(graph_kernel_node));
    changed = true;
  }
  return changed;
}

void GraphKernelExpander::ToPrimitive(const FuncGraphPtr &func_graph) const {
  auto todos = TopoSort(func_graph->get_return());
  std::reverse(todos.begin(), todos.end());
  auto mng = func_graph->manager();
  MS_EXCEPTION_IF_NULL(mng);
  for (const auto &n : todos) {
    auto cnode = n->cast<CNodePtr>();
    if (cnode == nullptr) {
      continue;
    }
    auto origin_prim = AnfAlgo::GetCNodePrimitive(cnode);
    MS_EXCEPTION_IF_NULL(origin_prim);
    if (!origin_prim->isa<PrimitivePy>()) {
      continue;
    }
    cnode->set_input(0, std::make_shared<ValueNode>(std::make_shared<Primitive>(*origin_prim)));
  }
}

bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) {
  expand_ops_ = GetExpandOps();
  MS_EXCEPTION_IF_NULL(func_graph);
@@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_EXPANDER_H_
#include <memory>
#include <unordered_set>
#include <nlohmann/json.hpp>
#include "backend/optimizer/common/pass.h"
#include "ir/func_graph.h"
@@ -31,7 +32,6 @@ class GraphKernelExpander : public Pass {
 private:
  FuncGraphPtr CreateExpandFuncGraph(const CNodePtr &node);
  bool DoExpand(const FuncGraphPtr &func_graph);
  void ToPrimitive(const FuncGraphPtr &func_graph) const;
  void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs);
  AnfNodePtr CreateExpandGraphKernel(const FuncGraphPtr &func_graph, const FuncGraphPtr &new_func_graph,
                                     const CNodePtr &node);
@@ -39,6 +39,7 @@ class GraphKernelExpander : public Pass {
    return std::any_of(expand_ops_.begin(), expand_ops_.end(),
                       [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
  }
  bool ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json);

 private:
  std::unordered_set<PrimitivePtr> expand_ops_;
@@ -39,27 +39,6 @@
namespace mindspore {
namespace opt {
namespace {
void DebugDump(const FuncGraphPtr &graph, std::stringstream *buf) {
  (*buf) << "Parameters: \n";
  const auto &parameters = graph->parameters();
  (*buf) << "size: " << parameters.size() << "\n";
  for (const auto &p : parameters) {
    (*buf) << "\t" << p->DebugString(2) << "\n";
  }
  (*buf) << "ValueNodes: \n";
  const auto &value_nodes = graph->value_nodes();
  (*buf) << "size: " << value_nodes.size() << "\n";
  for (const auto &v : value_nodes) {
    (*buf) << "\t" << v.first->DebugString(2) << "\n";
  }
  (*buf) << "CNodes: \n";
  const auto &all_nodes = graph->nodes();
  (*buf) << "size: " << all_nodes.size() << "\n";
  for (const auto &n : all_nodes) {
    (*buf) << "\t" << n->DebugString(2) << "\n";
  }
}

bool IsMakeTupleOut(const AnfNodePtr &out, AnfNodePtrList *real_outs) {
  MS_EXCEPTION_IF_NULL(real_outs);
  if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) {
@@ -91,132 +70,6 @@ AbstractBasePtr GetOutputAbstract(const AnfNodePtr &node, size_t output_idx) {
  return out_spec;
}

ValueNodePtr ProcessAttrsForCast(const CNodePtr &cnode, const std::string &attr_name) {
  auto dst_type = AnfAlgo::GetNodeAttr<std::string>(cnode, attr_name);
  auto type = TypeIdToType(kernel::DtypeToTypeId(dst_type));
  auto type_val_node = NewValueNode(type);
  return type_val_node;
}

const std::map<std::string, std::function<ValueNodePtr(const CNodePtr &cnode, const std::string &attr_name)>>
  attrs_process_map = {
    {kCastOpName, ProcessAttrsForCast},
};

ValueNodePtr ProcessAttrValue(const CNodePtr &cnode, const std::string &attr_name) {
  auto op_name = AnfAlgo::GetCNodeName(cnode);
  if (attrs_process_map.count(op_name) != 0) {
    return attrs_process_map.at(op_name)(cnode, attr_name);
  }
  auto attr_val = AnfAlgo::GetNodeAttr<ValuePtr>(cnode, attr_name);
  auto attr_val_node = NewValueNode(attr_val);
  return attr_val_node;
}

AnfNodePtr ConstAttrToInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                            const std::unordered_set<size_t> &input_attrs) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(cnode);
  MS_LOG(DEBUG) << "process node: " << cnode->DebugString(2);
  if (input_attrs.empty()) {
    return nullptr;
  }
  auto input_names = AnfAlgo::GetNodeAttr<std::vector<std::string>>(cnode, kAttrInputNames);
  MS_LOG(DEBUG) << "ori_input_names: " << kernel::Vector2Str(input_names);
  std::vector<AnfNodePtr> new_inputs;
  std::vector<std::string> new_input_names;
  const auto &inputs = cnode->inputs();
  for (size_t i = 0; i < inputs.size() - 1; ++i) {
    new_input_names.push_back(input_names[i]);
  }
  (void)new_inputs.insert(new_inputs.end(), inputs.begin(), inputs.end());
  bool need_update = false;
  for (size_t i = inputs.size() - 1; i < input_names.size(); ++i) {
    auto attr_name = input_names[i];
    if (input_attrs.find(i) == input_attrs.end()) {
      MS_LOG(WARNING) << "Other type input between tensors and attrs, name: " << attr_name
                      << ", node: " << cnode->DebugString(2);
      new_input_names.push_back(attr_name);
      continue;
    }
    if (!AnfAlgo::HasNodeAttr(attr_name, cnode)) {
      MS_LOG(EXCEPTION) << "Attr: " << attr_name << " not found in node: " << cnode->DebugString(2);
    }
    // Hardcode. It should convert attrs value according to format, like op ReduceSum.
    auto attr_val_node = ProcessAttrValue(cnode, attr_name);
    new_inputs.push_back(attr_val_node);
    new_input_names.push_back(attr_name);
    need_update = true;
    MS_LOG(DEBUG) << "convert attr: " << attr_name << " to input, value: " << attr_val_node;
  }
  MS_LOG(DEBUG) << "new_input_names: " << kernel::Vector2Str(new_input_names);
  if (!need_update) {
    return nullptr;
  }
  auto new_cnode = func_graph->NewCNode(new_inputs);
  // we do not modify abstract and kernel info.
  new_cnode->set_abstract(cnode->abstract());
  new_cnode->set_kernel_info(cnode->kernel_info_ptr());
  AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(new_input_names), new_cnode);
  return new_cnode;
}
| AnfNodePtr DeleteAttrInInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| const std::unordered_set<size_t> &input_attrs) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| MS_LOG(DEBUG) << "process node: " << cnode->DebugString(2); | |||
| if (input_attrs.empty()) { | |||
| return nullptr; | |||
| } | |||
| auto input_names = AnfAlgo::GetNodeAttr<std::vector<std::string>>(cnode, kAttrInputNames); | |||
| MS_LOG(DEBUG) << "ori_input_names: " << kernel::Vector2Str(input_names); | |||
| std::vector<AnfNodePtr> new_inputs; | |||
| std::vector<std::string> new_input_names; | |||
| const auto &inputs = cnode->inputs(); | |||
| new_inputs.push_back(inputs[0]); | |||
| bool need_update = false; | |||
| for (size_t i = 0; i < inputs.size() - 1; ++i) { | |||
| auto input_node = inputs[i + 1]; | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| // The attr indices count from 0. | |||
| if (input_attrs.find(i) != input_attrs.end() && input_node->isa<ValueNode>()) { | |||
| auto value_node = input_node->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| MS_LOG(DEBUG) << "delete attr input: " << i << " of node: " << cnode->DebugString(2); | |||
| if (i >= input_names.size()) { | |||
| MS_LOG(EXCEPTION) << "Index " << i << " is larger than input names size: " << input_names.size(); | |||
| } | |||
| need_update = true; | |||
| } else { | |||
| new_inputs.push_back(input_node); | |||
| if (i < input_names.size()) { | |||
| new_input_names.push_back(input_names[i]); | |||
| } | |||
| } | |||
| } | |||
| MS_LOG(DEBUG) << "new_input_names: " << kernel::Vector2Str(new_input_names); | |||
| if (!need_update) { | |||
| return nullptr; | |||
| } | |||
| auto new_cnode = func_graph->NewCNode(new_inputs); | |||
| // We do not modify the abstract or the kernel info. | |||
| new_cnode->set_abstract(cnode->abstract()); | |||
| new_cnode->set_kernel_info(cnode->kernel_info_ptr()); | |||
| AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(new_input_names), new_cnode); | |||
| return new_cnode; | |||
| } | |||
| AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr &fg, const FuncGraphManagerPtr &mng) { | |||
| AnfNodePtrList outs; | |||
| auto out_node = fg->output(); | |||
| @@ -396,59 +249,6 @@ void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const | |||
| AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, new_node.get()); | |||
| } | |||
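| // Graph-level driver: visit every valid kernel node, look up its op in | |||
| // ConstInputToAttrInfoRegistry, and replace the node via the manager when the | |||
| // node-level ConstAttrToInput produced a new node. | |||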
| void ConstAttrToInput(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto mng = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| std::vector<AnfNodePtr> todos; | |||
| kernel::GetValidKernelNodes(func_graph, &todos); | |||
| for (const auto &node : todos) { | |||
| ConstInputToAttrInfoRegister reg; | |||
| if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(node), ®)) { | |||
| continue; | |||
| } | |||
| auto new_node = ConstAttrToInput(func_graph, node->cast<CNodePtr>(), reg.GetConstInputAttrInfo()); | |||
| if (new_node != nullptr && new_node != node) { | |||
| mng->Replace(node, new_node); | |||
| } | |||
| } | |||
| } | |||
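| // Graph-level counterpart for DeleteAttrInInput: same traversal and registry | |||
| // lookup, applying the inverse transformation. | |||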
| void DeleteAttrInInput(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto mng = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| std::vector<AnfNodePtr> todos; | |||
| kernel::GetValidKernelNodes(func_graph, &todos); | |||
| for (const auto &node : todos) { | |||
| ConstInputToAttrInfoRegister reg; | |||
| if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(node), ®)) { | |||
| continue; | |||
| } | |||
| auto new_node = DeleteAttrInInput(func_graph, node->cast<CNodePtr>(), reg.GetConstInputAttrInfo()); | |||
| if (new_node != nullptr && new_node != node) { | |||
| mng->Replace(node, new_node); | |||
| } | |||
| } | |||
| } | |||
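| // GetExpandOuts flattens MakeTuple outputs into their real output nodes; a | |||
| // single output is returned unchanged. | |||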
| AnfNodePtrList GetExpandOuts(const AnfNodePtrList &outs) { | |||
| AnfNodePtrList res; | |||
| if (outs.size() <= 1) { | |||
| return outs; | |||
| } | |||
| for (auto out : outs) { | |||
| AnfNodePtrList real_outs; | |||
| if (IsMakeTupleOut(out, &real_outs)) { | |||
| res.insert(res.end(), real_outs.begin(), real_outs.end()); | |||
| continue; | |||
| } | |||
| res.push_back(out); | |||
| } | |||
| return res; | |||
| } | |||
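| // CreateNewFuseCNode builds a call to the fused sub-graph fg; the first input | |||
| // of the new CNode is a ValueNode holding fg, followed by the real inputs. | |||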
| AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &func_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs, | |||
| const AnfNodePtrList &outputs) { | |||
| auto func_node = NewValueNode(fg); | |||
| @@ -661,68 +461,7 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNo | |||
| MS_LOG(ERROR) << "Akg decode json to graph failed."; | |||
| return nullptr; | |||
| } | |||
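| // The decoded graph is attached to a fresh manager, then run through | |||
| // ConstAttrToInput, Renormalize (inference and specialization) and | |||
| // DeleteAttrInInput before a transformable clone is returned. | |||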
| pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>(); | |||
| auto mng = resource->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| mng->AddFuncGraph(fg); | |||
| ConstAttrToInput(fg); | |||
| std::stringstream buf; | |||
| buf << "===================== graph after ConstAttrToInput " << fg->ToString() << " =====================\n"; | |||
| DebugDump(fg, &buf); | |||
| MS_LOG(DEBUG) << buf.str(); | |||
| // Run inference and specialization. | |||
| AbstractBasePtrList args_spec_list; | |||
| std::for_each(inputs.begin(), inputs.end(), | |||
| [&args_spec_list](const AnfNodePtr &node) { args_spec_list.push_back(node->abstract()); }); | |||
| auto infer_fg = pipeline::Renormalize(resource, fg, args_spec_list); | |||
| if (infer_fg == nullptr) { | |||
| MS_LOG(ERROR) << "Infer decoded graph failed."; | |||
| return nullptr; | |||
| } | |||
| buf.str(""); | |||
| buf << "===================== graph after Renormalize " << infer_fg->ToString() << " =====================\n"; | |||
| DebugDump(infer_fg, &buf); | |||
| MS_LOG(DEBUG) << buf.str(); | |||
| // Delete unused inputs (attrs), e.g. the axis input of op ReduceSum. | |||
| DeleteAttrInInput(infer_fg); | |||
| buf.str(""); | |||
| buf << "===================== graph after DeleteAttrInInput " << infer_fg->ToString() << " =====================\n"; | |||
| DebugDump(infer_fg, &buf); | |||
| MS_LOG(DEBUG) << buf.str(); | |||
| // Clone the inferred graph. | |||
| auto new_fg = TransformableClone(infer_fg, std::make_shared<TraceTransform>("akg_decode")); | |||
| return new_fg; | |||
| } | |||
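| // GetExpandOps lists the primitives eligible for graph-kernel expansion; the | |||
| // set is backend-specific (ENABLE_D for Ascend, ENABLE_GPU for GPU). | |||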
| std::unordered_set<PrimitivePtr> GetExpandOps() { | |||
| std::unordered_set<PrimitivePtr> expand_ops = { | |||
| prim::kPrimSquare, | |||
| prim::kPrimGeLUGrad, | |||
| #if ENABLE_D | |||
| prim::kPrimTile, | |||
| prim::kPrimSqrtGrad, | |||
| prim::kPrimClipByNormNoDivSum, | |||
| #elif ENABLE_GPU | |||
| prim::kPrimBiasAdd, | |||
| prim::kPrimBiasAddGrad, | |||
| prim::kPrimGeLU, | |||
| prim::kPrimFusedAdam, | |||
| prim::kPrimFusedAdamWeightDecay, | |||
| prim::kPrimReduceMean, | |||
| prim::kPrimMaximumGrad, | |||
| prim::kPrimMinimumGrad, | |||
| prim::kPrimGkDropout, | |||
| prim::kPrimDropoutGrad, | |||
| prim::kPrimSoftmax, | |||
| prim::kPrimLayerNorm, | |||
| prim::kPrimLayerNormGrad, | |||
| #endif | |||
| }; | |||
| return expand_ops; | |||
| return fg; | |||
| } | |||
| std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix, const string &postfix) { | |||
| @@ -61,7 +61,6 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> MixedNodesTransToGraph( | |||
| AnfNodePtrList *src_outputs = nullptr); | |||
| void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs, | |||
| const AnfNodePtrList &outputs, kernel::Processor processor); | |||
| AnfNodePtrList GetExpandOuts(const AnfNodePtrList &outs); | |||
| AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &kernel_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs, | |||
| const AnfNodePtrList &outputs); | |||
| void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new_fuse_cnode, | |||
| @@ -74,7 +73,6 @@ bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, n | |||
| std::map<std::string, AnfNodePtr> *address_node_map); | |||
| bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, const DumpOption &dump_option, nlohmann::json *op_desc); | |||
| FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNodePtr> &inputs); | |||
| std::unordered_set<PrimitivePtr> GetExpandOps(); | |||
| std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix = "", const string &postfix = ""); | |||
| std::vector<PrimitivePtr> GetFusibleOpList(); | |||
| bool IsBasicFuseOp(const AnfNodePtr &node); | |||