From: @liangzhibo Reviewed-by: @ginfung,@zh_qh Signed-off-by: @zh_qhtags/v1.2.0-rc1
| @@ -36,24 +36,24 @@ def expand_biasadd(expand_info): | |||||
| 'ExpandDims', [input_y], attrs={'axis': 1}) | 'ExpandDims', [input_y], attrs={'axis': 1}) | ||||
| input_y_expand = graph_builder.emit( | input_y_expand = graph_builder.emit( | ||||
| 'ExpandDims', [input_y_expand], attrs={'axis': 2}) | 'ExpandDims', [input_y_expand], attrs={'axis': 2}) | ||||
| result = graph_builder.emit('TensorAdd', [input_x, input_y_expand]) | |||||
| result = graph_builder.emit('Add', [input_x, input_y_expand]) | |||||
| elif input_x.data_format == "DefaultFormat": | elif input_x.data_format == "DefaultFormat": | ||||
| if len(input_x.shape) == 2: | if len(input_x.shape) == 2: | ||||
| result = graph_builder.emit('TensorAdd', [input_x, input_y]) | |||||
| result = graph_builder.emit('Add', [input_x, input_y]) | |||||
| elif len(input_x.shape) == 3: | elif len(input_x.shape) == 3: | ||||
| input_y_expand = graph_builder.emit( | input_y_expand = graph_builder.emit( | ||||
| 'ExpandDims', [input_y], attrs={'axis': 1}) | 'ExpandDims', [input_y], attrs={'axis': 1}) | ||||
| result = graph_builder.emit( | result = graph_builder.emit( | ||||
| 'TensorAdd', [input_x, input_y_expand]) | |||||
| 'Add', [input_x, input_y_expand]) | |||||
| else: | else: | ||||
| input_y_expand = graph_builder.emit( | input_y_expand = graph_builder.emit( | ||||
| 'ExpandDims', [input_y], attrs={'axis': 1}) | 'ExpandDims', [input_y], attrs={'axis': 1}) | ||||
| input_y_expand = graph_builder.emit( | input_y_expand = graph_builder.emit( | ||||
| 'ExpandDims', [input_y_expand], attrs={'axis': 2}) | 'ExpandDims', [input_y_expand], attrs={'axis': 2}) | ||||
| result = graph_builder.emit( | result = graph_builder.emit( | ||||
| 'TensorAdd', [input_x, input_y_expand]) | |||||
| 'Add', [input_x, input_y_expand]) | |||||
| else: | else: | ||||
| result = graph_builder.emit('TensorAdd', [input_x, input_y]) | |||||
| result = graph_builder.emit('Add', [input_x, input_y]) | |||||
| # set graph output. | # set graph output. | ||||
| graph_scope.set_output(result) | graph_scope.set_output(result) | ||||
| @@ -49,13 +49,13 @@ def expand_fusedadam(expand_info): | |||||
| # compute result | # compute result | ||||
| beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m]) | beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m]) | ||||
| one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient]) | one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient]) | ||||
| next_m = graph_builder.emit('TensorAdd', [beta_1_mul_m, one_sub_beta_1_mul_grad]) | |||||
| next_m = graph_builder.emit('Add', [beta_1_mul_m, one_sub_beta_1_mul_grad]) | |||||
| beta_2_mul_v = graph_builder.emit('Mul', [beta_2, v]) | beta_2_mul_v = graph_builder.emit('Mul', [beta_2, v]) | ||||
| grad_square = graph_builder.emit('Mul', [gradient, gradient]) | grad_square = graph_builder.emit('Mul', [gradient, gradient]) | ||||
| one_sub_beta_2_mul_grad_square = graph_builder.emit('Mul', [one_sub_beta_2, grad_square]) | one_sub_beta_2_mul_grad_square = graph_builder.emit('Mul', [one_sub_beta_2, grad_square]) | ||||
| next_v = graph_builder.emit('TensorAdd', [beta_2_mul_v, one_sub_beta_2_mul_grad_square]) | |||||
| next_v = graph_builder.emit('Add', [beta_2_mul_v, one_sub_beta_2_mul_grad_square]) | |||||
| sqrt_next_v = graph_builder.emit('Sqrt', [next_v]) | sqrt_next_v = graph_builder.emit('Sqrt', [next_v]) | ||||
| sqrt_next_v_add_eps = graph_builder.emit('TensorAdd', [sqrt_next_v, eps]) | |||||
| sqrt_next_v_add_eps = graph_builder.emit('Add', [sqrt_next_v, eps]) | |||||
| update = graph_builder.emit('RealDiv', [next_m, sqrt_next_v_add_eps]) | update = graph_builder.emit('RealDiv', [next_m, sqrt_next_v_add_eps]) | ||||
| update_with_lr = graph_builder.emit('Mul', [lr, update]) | update_with_lr = graph_builder.emit('Mul', [lr, update]) | ||||
| next_para = graph_builder.emit('Sub', [param, update_with_lr]) | next_para = graph_builder.emit('Sub', [param, update_with_lr]) | ||||
| @@ -52,16 +52,16 @@ def expand_fusedadamweightdecay(expand_info): | |||||
| # compute result | # compute result | ||||
| beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m]) | beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m]) | ||||
| one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient]) | one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient]) | ||||
| next_m = graph_builder.emit('TensorAdd', [beta_1_mul_m, one_sub_beta_1_mul_grad]) | |||||
| next_m = graph_builder.emit('Add', [beta_1_mul_m, one_sub_beta_1_mul_grad]) | |||||
| beta_2_mul_v = graph_builder.emit('Mul', [beta_2, v]) | beta_2_mul_v = graph_builder.emit('Mul', [beta_2, v]) | ||||
| grad_square = graph_builder.emit('Mul', [gradient, gradient]) | grad_square = graph_builder.emit('Mul', [gradient, gradient]) | ||||
| one_sub_beta_2_mul_grad_square = graph_builder.emit('Mul', [one_sub_beta_2, grad_square]) | one_sub_beta_2_mul_grad_square = graph_builder.emit('Mul', [one_sub_beta_2, grad_square]) | ||||
| next_v = graph_builder.emit('TensorAdd', [beta_2_mul_v, one_sub_beta_2_mul_grad_square]) | |||||
| next_v = graph_builder.emit('Add', [beta_2_mul_v, one_sub_beta_2_mul_grad_square]) | |||||
| sqrt_next_v = graph_builder.emit('Sqrt', [next_v]) | sqrt_next_v = graph_builder.emit('Sqrt', [next_v]) | ||||
| sqrt_next_v_add_eps = graph_builder.emit('TensorAdd', [sqrt_next_v, eps]) | |||||
| sqrt_next_v_add_eps = graph_builder.emit('Add', [sqrt_next_v, eps]) | |||||
| update = graph_builder.emit('RealDiv', [next_m, sqrt_next_v_add_eps]) | update = graph_builder.emit('RealDiv', [next_m, sqrt_next_v_add_eps]) | ||||
| param_with_weight_decay = graph_builder.emit('Mul', [weight_decay, param]) | param_with_weight_decay = graph_builder.emit('Mul', [weight_decay, param]) | ||||
| update = graph_builder.emit('TensorAdd', [update, param_with_weight_decay]) | |||||
| update = graph_builder.emit('Add', [update, param_with_weight_decay]) | |||||
| update_with_lr = graph_builder.emit('Mul', [lr, update]) | update_with_lr = graph_builder.emit('Mul', [lr, update]) | ||||
| next_para = graph_builder.emit('Sub', [param, update_with_lr]) | next_para = graph_builder.emit('Sub', [param, update_with_lr]) | ||||
| @@ -42,7 +42,7 @@ def expand_gelu(expand_info): | |||||
| pow_0 = graph_builder.emit('Mul', [mul_0, input_x]) | pow_0 = graph_builder.emit('Mul', [mul_0, input_x]) | ||||
| const_csvalue = graph_builder.value(pow_0.dtype, CSVALUE, input_desc['format']) | const_csvalue = graph_builder.value(pow_0.dtype, CSVALUE, input_desc['format']) | ||||
| mul_1 = graph_builder.emit('Mul', [pow_0, const_csvalue]) | mul_1 = graph_builder.emit('Mul', [pow_0, const_csvalue]) | ||||
| tanh_res = graph_builder.emit('TensorAdd', [input_x, mul_1]) | |||||
| tanh_res = graph_builder.emit('Add', [input_x, mul_1]) | |||||
| const_csvalue_sqrt_two_div_pi = graph_builder.value( | const_csvalue_sqrt_two_div_pi = graph_builder.value( | ||||
| tanh_res.dtype, CSVALUE_SQRT_TWO_DIV_PI, input_desc['format']) | tanh_res.dtype, CSVALUE_SQRT_TWO_DIV_PI, input_desc['format']) | ||||
| y = graph_builder.emit('Mul', [tanh_res, const_csvalue_sqrt_two_div_pi]) | y = graph_builder.emit('Mul', [tanh_res, const_csvalue_sqrt_two_div_pi]) | ||||
| @@ -51,7 +51,7 @@ def expand_gelu(expand_info): | |||||
| tanh_y = graph_builder.emit('Tanh', [y]) | tanh_y = graph_builder.emit('Tanh', [y]) | ||||
| const_one = graph_builder.value(tanh_y.dtype, ONE, input_desc['format']) | const_one = graph_builder.value(tanh_y.dtype, ONE, input_desc['format']) | ||||
| const_half = graph_builder.value(tanh_y.dtype, HALF, input_desc['format']) | const_half = graph_builder.value(tanh_y.dtype, HALF, input_desc['format']) | ||||
| tanh_y_add_one = graph_builder.emit('TensorAdd', [tanh_y, const_one]) | |||||
| tanh_y_add_one = graph_builder.emit('Add', [tanh_y, const_one]) | |||||
| mul_x = graph_builder.emit('Mul', [input_x, tanh_y_add_one]) | mul_x = graph_builder.emit('Mul', [input_x, tanh_y_add_one]) | ||||
| result = graph_builder.emit('Mul', [const_half, mul_x]) | result = graph_builder.emit('Mul', [const_half, mul_x]) | ||||
| @@ -55,18 +55,18 @@ def expand_gelugrad(expand_info): | |||||
| # cal mul_right | # cal mul_right | ||||
| mul_double = graph_builder.emit('Mul', [input_x, input_x]) | mul_double = graph_builder.emit('Mul', [input_x, input_x]) | ||||
| mul_double_mul_tri = graph_builder.emit('Mul', [const_csvalue_tri, mul_double]) | mul_double_mul_tri = graph_builder.emit('Mul', [const_csvalue_tri, mul_double]) | ||||
| mul_add_one = graph_builder.emit('TensorAdd', [const_one, mul_double_mul_tri]) | |||||
| mul_add_one = graph_builder.emit('Add', [const_one, mul_double_mul_tri]) | |||||
| mul_right = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_one]) | mul_right = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_one]) | ||||
| # cal tanh_para | # cal tanh_para | ||||
| mul_triple = graph_builder.emit('Mul', [input_x, mul_double]) | mul_triple = graph_builder.emit('Mul', [input_x, mul_double]) | ||||
| mul_triple_mul_csvalue = graph_builder.emit('Mul', [const_csvalue, mul_triple]) | mul_triple_mul_csvalue = graph_builder.emit('Mul', [const_csvalue, mul_triple]) | ||||
| mul_add_x = graph_builder.emit('TensorAdd', [input_x, mul_triple_mul_csvalue]) | |||||
| mul_add_x = graph_builder.emit('Add', [input_x, mul_triple_mul_csvalue]) | |||||
| tanh_para = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_x]) | tanh_para = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_x]) | ||||
| # cal 0.5 * (1.0 + tanh(tahn_para)) | # cal 0.5 * (1.0 + tanh(tahn_para)) | ||||
| tanh_res = graph_builder.emit('Tanh', [tanh_para]) | tanh_res = graph_builder.emit('Tanh', [tanh_para]) | ||||
| tanh_res_add_one = graph_builder.emit('TensorAdd', [const_one, tanh_res]) | |||||
| tanh_res_add_one = graph_builder.emit('Add', [const_one, tanh_res]) | |||||
| half_mul_tanh_res_add_one = graph_builder.emit('Mul', [const_half, tanh_res_add_one]) | half_mul_tanh_res_add_one = graph_builder.emit('Mul', [const_half, tanh_res_add_one]) | ||||
| # cal 0.5 * x * (1.0 - tanh(tanh_para) * tanh(tanh_para)) * mul_right | # cal 0.5 * x * (1.0 - tanh(tanh_para) * tanh(tanh_para)) * mul_right | ||||
| @@ -77,7 +77,7 @@ def expand_gelugrad(expand_info): | |||||
| mul_final = graph_builder.emit('Mul', [mul_tmp, mul_right]) | mul_final = graph_builder.emit('Mul', [mul_tmp, mul_right]) | ||||
| # cal result | # cal result | ||||
| result_tmp = graph_builder.emit('TensorAdd', [half_mul_tanh_res_add_one, mul_final]) | |||||
| result_tmp = graph_builder.emit('Add', [half_mul_tanh_res_add_one, mul_final]) | |||||
| result = graph_builder.emit('Mul', [input_dy, result_tmp]) | result = graph_builder.emit('Mul', [input_dy, result_tmp]) | ||||
| # set graph output. | # set graph output. | ||||
| @@ -68,13 +68,13 @@ def expand_layernorm(expand_info): | |||||
| # Calculate normalize | # Calculate normalize | ||||
| normalize_sub = graph_builder.emit('Sub', [input_x, mean]) | normalize_sub = graph_builder.emit('Sub', [input_x, mean]) | ||||
| epsilon_v = graph_builder.value(input_x.dtype, epsilon, input_x.data_format) | epsilon_v = graph_builder.value(input_x.dtype, epsilon, input_x.data_format) | ||||
| normalize_add = graph_builder.emit('TensorAdd', [variance, epsilon_v]) | |||||
| normalize_add = graph_builder.emit('Add', [variance, epsilon_v]) | |||||
| normlize_rsqrt = graph_builder.emit('Rsqrt', [normalize_add]) | normlize_rsqrt = graph_builder.emit('Rsqrt', [normalize_add]) | ||||
| normalize_mul = graph_builder.emit('Mul', [normalize_sub, normlize_rsqrt]) | normalize_mul = graph_builder.emit('Mul', [normalize_sub, normlize_rsqrt]) | ||||
| # Calculate scale and translate | # Calculate scale and translate | ||||
| scale_mul = graph_builder.emit('Mul', [input_gamma, normalize_mul]) | scale_mul = graph_builder.emit('Mul', [input_gamma, normalize_mul]) | ||||
| res = graph_builder.emit('TensorAdd', [scale_mul, input_beta]) | |||||
| res = graph_builder.emit('Add', [scale_mul, input_beta]) | |||||
| # set graph output. | # set graph output. | ||||
| graph_scope.set_output(res, mean, variance) | graph_scope.set_output(res, mean, variance) | ||||
| @@ -66,7 +66,7 @@ def expand_layernormgrad(expand_info): | |||||
| mean_cof = graph_builder.value(x.dtype, (1.0 / reduce_size), x.data_format) | mean_cof = graph_builder.value(x.dtype, (1.0 / reduce_size), x.data_format) | ||||
| # cal dg db | # cal dg db | ||||
| var_eps = graph_builder.emit('TensorAdd', [variance, eps]) | |||||
| var_eps = graph_builder.emit('Add', [variance, eps]) | |||||
| sqrt_var_eps = graph_builder.emit('Sqrt', [var_eps]) | sqrt_var_eps = graph_builder.emit('Sqrt', [var_eps]) | ||||
| rsqrt_var_eps = graph_builder.emit('RealDiv', [const_one, sqrt_var_eps]) | rsqrt_var_eps = graph_builder.emit('RealDiv', [const_one, sqrt_var_eps]) | ||||
| x_sub_mean = graph_builder.emit('Sub', [x, mean]) | x_sub_mean = graph_builder.emit('Sub', [x, mean]) | ||||
| @@ -100,10 +100,10 @@ def expand_layernormgrad(expand_info): | |||||
| neg_rsqrt_var_eps_mul_sum_2 = graph_builder.emit('Mul', [neg_rsqrt_var_eps, sum_2]) | neg_rsqrt_var_eps_mul_sum_2 = graph_builder.emit('Mul', [neg_rsqrt_var_eps, sum_2]) | ||||
| sum_1_mul_sum_3 = graph_builder.emit('Mul', [sum_1, sum_3]) | sum_1_mul_sum_3 = graph_builder.emit('Mul', [sum_1, sum_3]) | ||||
| mean_cof_mul_sum_1_mul_sum_3 = graph_builder.emit('Mul', [mean_cof, sum_1_mul_sum_3]) | mean_cof_mul_sum_1_mul_sum_3 = graph_builder.emit('Mul', [mean_cof, sum_1_mul_sum_3]) | ||||
| add_tmp = graph_builder.emit('TensorAdd', [neg_rsqrt_var_eps_mul_sum_2, mean_cof_mul_sum_1_mul_sum_3]) | |||||
| add_tmp = graph_builder.emit('Add', [neg_rsqrt_var_eps_mul_sum_2, mean_cof_mul_sum_1_mul_sum_3]) | |||||
| dx_3 = graph_builder.emit('Mul', [add_tmp, mean_cof]) | dx_3 = graph_builder.emit('Mul', [add_tmp, mean_cof]) | ||||
| dx_tmp = graph_builder.emit('TensorAdd', [dx_1, dx_2]) | |||||
| dx = graph_builder.emit('TensorAdd', [dx_tmp, dx_3]) | |||||
| dx_tmp = graph_builder.emit('Add', [dx_1, dx_2]) | |||||
| dx = graph_builder.emit('Add', [dx_tmp, dx_3]) | |||||
| # set graph output. | # set graph output. | ||||
| graph_scope.set_output(dx, dg, db) | graph_scope.set_output(dx, dg, db) | ||||
| @@ -131,7 +131,7 @@ class PrimLib: | |||||
| ] | ] | ||||
| primtives = { | primtives = { | ||||
| 'TensorAdd': Prim(ELEMWISE), | |||||
| 'Add': Prim(ELEMWISE), | |||||
| 'Abs': Prim(ELEMWISE), | 'Abs': Prim(ELEMWISE), | ||||
| 'Neg': Prim(ELEMWISE), | 'Neg': Prim(ELEMWISE), | ||||
| 'Mul': Prim(ELEMWISE), | 'Mul': Prim(ELEMWISE), | ||||
| @@ -238,7 +238,7 @@ void ArithmeticCPUKernel::LessEqual(const T *input1, const T *input2, bool *out, | |||||
| void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) { | void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) { | ||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); | std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); | ||||
| if (kernel_name == prim::kPrimTensorAdd->name()) { | |||||
| if (kernel_name == prim::kPrimAdd->name()) { | |||||
| operate_type_ = ADD; | operate_type_ = ADD; | ||||
| } else if (kernel_name == prim::kPrimSub->name()) { | } else if (kernel_name == prim::kPrimSub->name()) { | ||||
| operate_type_ = SUB; | operate_type_ = SUB; | ||||
| @@ -37,8 +37,7 @@ class TensorAddCPUKernel : public MKLCPUKernel { | |||||
| }; | }; | ||||
| MS_REG_CPU_KERNEL( | MS_REG_CPU_KERNEL( | ||||
| TensorAdd, | |||||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| Add, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| TensorAddCPUKernel); | TensorAddCPUKernel); | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -51,8 +51,7 @@ MS_REG_GPU_KERNEL_ONE( | |||||
| Sub, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | Sub, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | ||||
| BroadcastOpGpuKernel, float) | BroadcastOpGpuKernel, float) | ||||
| MS_REG_GPU_KERNEL_ONE( | MS_REG_GPU_KERNEL_ONE( | ||||
| TensorAdd, | |||||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| Add, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| BroadcastOpGpuKernel, float) | BroadcastOpGpuKernel, float) | ||||
| MS_REG_GPU_KERNEL_ONE( | MS_REG_GPU_KERNEL_ONE( | ||||
| FloorDiv, | FloorDiv, | ||||
| @@ -103,8 +102,7 @@ MS_REG_GPU_KERNEL_ONE( | |||||
| Sub, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | Sub, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | ||||
| BroadcastOpGpuKernel, half) | BroadcastOpGpuKernel, half) | ||||
| MS_REG_GPU_KERNEL_ONE( | MS_REG_GPU_KERNEL_ONE( | ||||
| TensorAdd, | |||||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||||
| Add, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||||
| BroadcastOpGpuKernel, half) | BroadcastOpGpuKernel, half) | ||||
| MS_REG_GPU_KERNEL_ONE( | MS_REG_GPU_KERNEL_ONE( | ||||
| FloorDiv, | FloorDiv, | ||||
| @@ -133,7 +131,7 @@ MS_REG_GPU_KERNEL_ONE( | |||||
| Equal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), | Equal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), | ||||
| BroadcastOpGpuKernel, int) | BroadcastOpGpuKernel, int) | ||||
| MS_REG_GPU_KERNEL_ONE( | MS_REG_GPU_KERNEL_ONE( | ||||
| TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||||
| Add, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||||
| BroadcastOpGpuKernel, int) | BroadcastOpGpuKernel, int) | ||||
| MS_REG_GPU_KERNEL_ONE( | MS_REG_GPU_KERNEL_ONE( | ||||
| Minimum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | Minimum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | ||||
| @@ -171,7 +169,7 @@ MS_REG_GPU_KERNEL_ONE( | |||||
| Equal, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), | Equal, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), | ||||
| BroadcastOpGpuKernel, int64_t) | BroadcastOpGpuKernel, int64_t) | ||||
| MS_REG_GPU_KERNEL_ONE( | MS_REG_GPU_KERNEL_ONE( | ||||
| TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), | |||||
| Add, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), | |||||
| BroadcastOpGpuKernel, int64_t) | BroadcastOpGpuKernel, int64_t) | ||||
| MS_REG_GPU_KERNEL_ONE( | MS_REG_GPU_KERNEL_ONE( | ||||
| Minimum, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), | Minimum, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), | ||||
| @@ -145,7 +145,7 @@ class BroadcastOpGpuKernel : public GpuKernel { | |||||
| static std::map<std::string, BroadcastOpType> kBroadcastArithmetricTypeMap = { | static std::map<std::string, BroadcastOpType> kBroadcastArithmetricTypeMap = { | ||||
| {"Maximum", BROADCAST_TYPE_MAXIMUM}, {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER}, | {"Maximum", BROADCAST_TYPE_MAXIMUM}, {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER}, | ||||
| {"RealDiv", BROADCAST_TYPE_REALDIV}, {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB}, | {"RealDiv", BROADCAST_TYPE_REALDIV}, {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB}, | ||||
| {"TensorAdd", BROADCAST_TYPE_ADD}, {"FloorDiv", BROADCAST_TYPE_FLOORDIV}, {"AbsGrad", BROADCAST_TYPE_ABSGRAD}, | |||||
| {"Add", BROADCAST_TYPE_ADD}, {"FloorDiv", BROADCAST_TYPE_FLOORDIV}, {"AbsGrad", BROADCAST_TYPE_ABSGRAD}, | |||||
| {"Div", BROADCAST_TYPE_DIV}, {"DivNoNan", BROADCAST_TYPE_DIVNONAN}, | {"Div", BROADCAST_TYPE_DIV}, {"DivNoNan", BROADCAST_TYPE_DIVNONAN}, | ||||
| }; | }; | ||||
| @@ -1063,7 +1063,7 @@ size_t TbeKernelBuild::GetOptionalInput(const mindspore::CNodePtr &cnode, bool i | |||||
| std::string TbeKernelBuild::GetRealOpType(const std::string &origin_type) { | std::string TbeKernelBuild::GetRealOpType(const std::string &origin_type) { | ||||
| static std::map<std::string, std::string> buffer_fussion_op_map = { | static std::map<std::string, std::string> buffer_fussion_op_map = { | ||||
| {parallel::DEPTHWISE_CONV2D_NATIVE, parallel::DEPTHWISE_CONV2D}, {parallel::TENSOR_ADD, parallel::ADD}}; | |||||
| {parallel::DEPTHWISE_CONV2D_NATIVE, parallel::DEPTHWISE_CONV2D}}; | |||||
| string result = origin_type; | string result = origin_type; | ||||
| auto iter = buffer_fussion_op_map.find(origin_type); | auto iter = buffer_fussion_op_map.find(origin_type); | ||||
| if (iter != buffer_fussion_op_map.end()) { | if (iter != buffer_fussion_op_map.end()) { | ||||
| @@ -99,7 +99,7 @@ void BnupdateEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::K | |||||
| AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && CheckEltwiseInputAndOutputSize(cnode)) { | AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && CheckEltwiseInputAndOutputSize(cnode)) { | ||||
| auto eltwise_input = cnode->input(1); | auto eltwise_input = cnode->input(1); | ||||
| MS_EXCEPTION_IF_NULL(eltwise_input); | MS_EXCEPTION_IF_NULL(eltwise_input); | ||||
| if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTensorAdd)) { | |||||
| if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimAdd)) { | |||||
| MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion); | MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion); | ||||
| } | } | ||||
| } | } | ||||
| @@ -28,7 +28,7 @@ const BaseRef AdamApplyOneFusion::DefinePattern() const { | |||||
| VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | ||||
| VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | ||||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | ||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); | |||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})}); | |||||
| return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})}); | return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})}); | ||||
| } | } | ||||
| @@ -41,7 +41,7 @@ const BaseRef AdamApplyOneCond1Fusion::DefinePattern() const { | |||||
| VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | ||||
| VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | ||||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | ||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})}); | |||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, add2_y_, sqrt0})}); | |||||
| return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})}); | return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})}); | ||||
| } | } | ||||
| @@ -54,7 +54,7 @@ const BaseRef AdamApplyOneCond2Fusion::DefinePattern() const { | |||||
| VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | ||||
| VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | ||||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | ||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); | |||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})}); | |||||
| return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); | return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); | ||||
| } | } | ||||
| @@ -67,7 +67,7 @@ const BaseRef AdamApplyOneCond3Fusion::DefinePattern() const { | |||||
| VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | ||||
| VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | ||||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | ||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); | |||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})}); | |||||
| return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); | return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); | ||||
| } | } | ||||
| @@ -80,7 +80,7 @@ const BaseRef AdamApplyOneCond4Fusion::DefinePattern() const { | |||||
| VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | ||||
| VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | ||||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | ||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})}); | |||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, add2_y_, sqrt0})}); | |||||
| return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); | return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); | ||||
| } | } | ||||
| @@ -94,7 +94,7 @@ const BaseRef AdamApplyOneAssignFusion::DefinePattern() const { | |||||
| VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | ||||
| VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | ||||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | ||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); | |||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})}); | |||||
| VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})}); | VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})}); | ||||
| VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0}); | VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0}); | ||||
| VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); | VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); | ||||
| @@ -114,7 +114,7 @@ const BaseRef AdamApplyOneAssignCond1Fusion::DefinePattern() const { | |||||
| VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | ||||
| VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | ||||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | ||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})}); | |||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, add2_y_, sqrt0})}); | |||||
| VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})}); | VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})}); | ||||
| VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0}); | VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0}); | ||||
| VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); | VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); | ||||
| @@ -134,7 +134,7 @@ const BaseRef AdamApplyOneAssignCond2Fusion::DefinePattern() const { | |||||
| VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | ||||
| VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | ||||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | ||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); | |||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})}); | |||||
| VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); | VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); | ||||
| VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0}); | VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0}); | ||||
| VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); | VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); | ||||
| @@ -154,7 +154,7 @@ const BaseRef AdamApplyOneAssignCond3Fusion::DefinePattern() const { | |||||
| VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | ||||
| VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | ||||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | ||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); | |||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})}); | |||||
| VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); | VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); | ||||
| VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0}); | VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0}); | ||||
| VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); | VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); | ||||
| @@ -174,7 +174,7 @@ const BaseRef AdamApplyOneAssignCond4Fusion::DefinePattern() const { | |||||
| VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | ||||
| VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | ||||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | ||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})}); | |||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, add2_y_, sqrt0})}); | |||||
| VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); | VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); | ||||
| VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0}); | VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0}); | ||||
| VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); | VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); | ||||
| @@ -38,8 +38,8 @@ class AdamApplyOneFusion : public PatternProcessPass { | |||||
| mul_x_input_vars_.push_back(std::make_shared<Var>()); | mul_x_input_vars_.push_back(std::make_shared<Var>()); | ||||
| } | } | ||||
| add2_y_ = std::make_shared<Var>(); | add2_y_ = std::make_shared<Var>(); | ||||
| add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); | |||||
| add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); | |||||
| add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name())); | |||||
| add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name())); | |||||
| sub0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimSub->name())); | sub0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimSub->name())); | ||||
| } | } | ||||
| @@ -59,10 +59,10 @@ const BaseRef AdamApplyOneWithDecayRuleCond1::DefinePattern() const { | |||||
| VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); | VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); | ||||
| VectorRef add1({add1_var_, mul2, mul3}); | VectorRef add1({add1_var_, mul2, mul3}); | ||||
| VectorRef sqrt0({sqrt, add1}); | VectorRef sqrt0({sqrt, add1}); | ||||
| VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0}); | |||||
| VectorRef add2({prim::kPrimAdd, add2_y_, sqrt0}); | |||||
| VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); | VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); | ||||
| VectorRef real_div0({real_div, add0, add2}); | VectorRef real_div0({real_div, add0, add2}); | ||||
| VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); | |||||
| VectorRef add3({prim::kPrimAdd, mul4, real_div0}); | |||||
| VectorRef mul5({prim::kPrimMul, input4_, add3}); | VectorRef mul5({prim::kPrimMul, input4_, add3}); | ||||
| VectorRef sub0({prim::kPrimSub, input3_, mul5}); | VectorRef sub0({prim::kPrimSub, input3_, mul5}); | ||||
| return sub0; | return sub0; | ||||
| @@ -79,10 +79,10 @@ const BaseRef AdamApplyOneWithDecayRuleCond2::DefinePattern() const { | |||||
| VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); | VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); | ||||
| VectorRef add1({add1_var_, mul2, mul3}); | VectorRef add1({add1_var_, mul2, mul3}); | ||||
| VectorRef sqrt0({sqrt, add1}); | VectorRef sqrt0({sqrt, add1}); | ||||
| VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); | |||||
| VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_}); | |||||
| VectorRef mul4({prim::kPrimMul, input3_, mul4_x_}); | VectorRef mul4({prim::kPrimMul, input3_, mul4_x_}); | ||||
| VectorRef real_div0({real_div, add0, add2}); | VectorRef real_div0({real_div, add0, add2}); | ||||
| VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); | |||||
| VectorRef add3({prim::kPrimAdd, mul4, real_div0}); | |||||
| VectorRef mul5({prim::kPrimMul, add3, input4_}); | VectorRef mul5({prim::kPrimMul, add3, input4_}); | ||||
| VectorRef sub0({prim::kPrimSub, input3_, mul5}); | VectorRef sub0({prim::kPrimSub, input3_, mul5}); | ||||
| return sub0; | return sub0; | ||||
| @@ -99,10 +99,10 @@ const BaseRef AdamApplyOneWithDecayRuleCond3::DefinePattern() const { | |||||
| VectorRef mul3({prim::kPrimMul, square0, mul3_x_}); | VectorRef mul3({prim::kPrimMul, square0, mul3_x_}); | ||||
| VectorRef add1({add1_var_, mul2, mul3}); | VectorRef add1({add1_var_, mul2, mul3}); | ||||
| VectorRef sqrt0({sqrt, add1}); | VectorRef sqrt0({sqrt, add1}); | ||||
| VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); | |||||
| VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_}); | |||||
| VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); | VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); | ||||
| VectorRef real_div0({real_div, add0, add2}); | VectorRef real_div0({real_div, add0, add2}); | ||||
| VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); | |||||
| VectorRef add3({prim::kPrimAdd, mul4, real_div0}); | |||||
| VectorRef mul5({prim::kPrimMul, add3, input4_}); | VectorRef mul5({prim::kPrimMul, add3, input4_}); | ||||
| VectorRef sub0({prim::kPrimSub, input3_, mul5}); | VectorRef sub0({prim::kPrimSub, input3_, mul5}); | ||||
| return sub0; | return sub0; | ||||
| @@ -119,10 +119,10 @@ const BaseRef AdamApplyOneWithDecayRuleCond4::DefinePattern() const { | |||||
| VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); | VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); | ||||
| VectorRef add1({add1_var_, mul2, mul3}); | VectorRef add1({add1_var_, mul2, mul3}); | ||||
| VectorRef sqrt0({sqrt, add1}); | VectorRef sqrt0({sqrt, add1}); | ||||
| VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0}); | |||||
| VectorRef add2({prim::kPrimAdd, add2_y_, sqrt0}); | |||||
| VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); | VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); | ||||
| VectorRef real_div0({real_div, add0, add2}); | VectorRef real_div0({real_div, add0, add2}); | ||||
| VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); | |||||
| VectorRef add3({prim::kPrimAdd, mul4, real_div0}); | |||||
| VectorRef mul5({prim::kPrimMul, add3, input4_}); | VectorRef mul5({prim::kPrimMul, add3, input4_}); | ||||
| VectorRef sub0({prim::kPrimSub, input3_, mul5}); | VectorRef sub0({prim::kPrimSub, input3_, mul5}); | ||||
| return sub0; | return sub0; | ||||
| @@ -139,10 +139,10 @@ const BaseRef AdamApplyOneWithDecayRuleCond5::DefinePattern() const { | |||||
| VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); | VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); | ||||
| VectorRef add1({add1_var_, mul2, mul3}); | VectorRef add1({add1_var_, mul2, mul3}); | ||||
| VectorRef sqrt0({sqrt, add1}); | VectorRef sqrt0({sqrt, add1}); | ||||
| VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); | |||||
| VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_}); | |||||
| VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); | VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); | ||||
| VectorRef real_div0({real_div, add0, add2}); | VectorRef real_div0({real_div, add0, add2}); | ||||
| VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); | |||||
| VectorRef add3({prim::kPrimAdd, mul4, real_div0}); | |||||
| VectorRef mul5({prim::kPrimMul, add3, input4_}); | VectorRef mul5({prim::kPrimMul, add3, input4_}); | ||||
| VectorRef sub0({prim::kPrimSub, input3_, mul5}); | VectorRef sub0({prim::kPrimSub, input3_, mul5}); | ||||
| return sub0; | return sub0; | ||||
| @@ -159,10 +159,10 @@ const BaseRef AdamApplyOneWithDecayAssignRuleCond1::DefinePattern() const { | |||||
| VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); | VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); | ||||
| VectorRef add1({add1_var_, mul2, mul3}); | VectorRef add1({add1_var_, mul2, mul3}); | ||||
| VectorRef sqrt0({sqrt, add1}); | VectorRef sqrt0({sqrt, add1}); | ||||
| VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0}); | |||||
| VectorRef add2({prim::kPrimAdd, add2_y_, sqrt0}); | |||||
| VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); | VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); | ||||
| VectorRef real_div0({real_div, add0, add2}); | VectorRef real_div0({real_div, add0, add2}); | ||||
| VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); | |||||
| VectorRef add3({prim::kPrimAdd, mul4, real_div0}); | |||||
| VectorRef mul5({prim::kPrimMul, input4_, add3}); | VectorRef mul5({prim::kPrimMul, input4_, add3}); | ||||
| VectorRef sub0({sub0_var_, input3_, mul5}); | VectorRef sub0({sub0_var_, input3_, mul5}); | ||||
| VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0}); | VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0}); | ||||
| @@ -184,10 +184,10 @@ const BaseRef AdamApplyOneWithDecayAssignRuleCond2::DefinePattern() const { | |||||
| VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); | VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); | ||||
| VectorRef add1({add1_var_, mul2, mul3}); | VectorRef add1({add1_var_, mul2, mul3}); | ||||
| VectorRef sqrt0({sqrt, add1}); | VectorRef sqrt0({sqrt, add1}); | ||||
| VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); | |||||
| VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_}); | |||||
| VectorRef mul4({prim::kPrimMul, input3_, mul4_x_}); | VectorRef mul4({prim::kPrimMul, input3_, mul4_x_}); | ||||
| VectorRef real_div0({real_div, add0, add2}); | VectorRef real_div0({real_div, add0, add2}); | ||||
| VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); | |||||
| VectorRef add3({prim::kPrimAdd, mul4, real_div0}); | |||||
| VectorRef mul5({prim::kPrimMul, add3, input4_}); | VectorRef mul5({prim::kPrimMul, add3, input4_}); | ||||
| VectorRef sub0({sub0_var_, input3_, mul5}); | VectorRef sub0({sub0_var_, input3_, mul5}); | ||||
| VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0}); | VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0}); | ||||
| @@ -209,10 +209,10 @@ const BaseRef AdamApplyOneWithDecayAssignRuleCond3::DefinePattern() const { | |||||
| VectorRef mul3({prim::kPrimMul, square0, mul3_x_}); | VectorRef mul3({prim::kPrimMul, square0, mul3_x_}); | ||||
| VectorRef add1({add1_var_, mul2, mul3}); | VectorRef add1({add1_var_, mul2, mul3}); | ||||
| VectorRef sqrt0({sqrt, add1}); | VectorRef sqrt0({sqrt, add1}); | ||||
| VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); | |||||
| VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_}); | |||||
| VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); | VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); | ||||
| VectorRef real_div0({real_div, add0, add2}); | VectorRef real_div0({real_div, add0, add2}); | ||||
| VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); | |||||
| VectorRef add3({prim::kPrimAdd, mul4, real_div0}); | |||||
| VectorRef mul5({prim::kPrimMul, add3, input4_}); | VectorRef mul5({prim::kPrimMul, add3, input4_}); | ||||
| VectorRef sub0({sub0_var_, input3_, mul5}); | VectorRef sub0({sub0_var_, input3_, mul5}); | ||||
| VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0}); | VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0}); | ||||
| @@ -234,10 +234,10 @@ const BaseRef AdamApplyOneWithDecayAssignRuleCond4::DefinePattern() const { | |||||
| VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); | VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); | ||||
| VectorRef add1({add1_var_, mul2, mul3}); | VectorRef add1({add1_var_, mul2, mul3}); | ||||
| VectorRef sqrt0({sqrt, add1}); | VectorRef sqrt0({sqrt, add1}); | ||||
| VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0}); | |||||
| VectorRef add2({prim::kPrimAdd, add2_y_, sqrt0}); | |||||
| VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); | VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); | ||||
| VectorRef real_div0({real_div, add0, add2}); | VectorRef real_div0({real_div, add0, add2}); | ||||
| VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); | |||||
| VectorRef add3({prim::kPrimAdd, mul4, real_div0}); | |||||
| VectorRef mul5({prim::kPrimMul, add3, input4_}); | VectorRef mul5({prim::kPrimMul, add3, input4_}); | ||||
| VectorRef sub0({sub0_var_, input3_, mul5}); | VectorRef sub0({sub0_var_, input3_, mul5}); | ||||
| VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0}); | VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0}); | ||||
| @@ -259,10 +259,10 @@ const BaseRef AdamApplyOneWithDecayAssignRuleCond5::DefinePattern() const { | |||||
| VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); | VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); | ||||
| VectorRef add1({add1_var_, mul2, mul3}); | VectorRef add1({add1_var_, mul2, mul3}); | ||||
| VectorRef sqrt0({sqrt, add1}); | VectorRef sqrt0({sqrt, add1}); | ||||
| VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); | |||||
| VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_}); | |||||
| VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); | VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); | ||||
| VectorRef real_div0({real_div, add0, add2}); | VectorRef real_div0({real_div, add0, add2}); | ||||
| VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); | |||||
| VectorRef add3({prim::kPrimAdd, mul4, real_div0}); | |||||
| VectorRef mul5({prim::kPrimMul, add3, input4_}); | VectorRef mul5({prim::kPrimMul, add3, input4_}); | ||||
| VectorRef sub0({sub0_var_, input3_, mul5}); | VectorRef sub0({sub0_var_, input3_, mul5}); | ||||
| VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0}); | VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0}); | ||||
| @@ -38,8 +38,8 @@ class AdamApplyOneWithDecayRule : public PatternProcessPass { | |||||
| mul3_x_ = std::make_shared<Var>(); | mul3_x_ = std::make_shared<Var>(); | ||||
| mul4_x_ = std::make_shared<Var>(); | mul4_x_ = std::make_shared<Var>(); | ||||
| add2_y_ = std::make_shared<Var>(); | add2_y_ = std::make_shared<Var>(); | ||||
| add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); | |||||
| add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); | |||||
| add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name())); | |||||
| add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name())); | |||||
| sub0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimSub->name())); | sub0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimSub->name())); | ||||
| } | } | ||||
| ~AdamApplyOneWithDecayRule() override = default; | ~AdamApplyOneWithDecayRule() override = default; | ||||
| @@ -130,11 +130,11 @@ const BaseRef LambNextMVRuleCond1::DefinePattern() const { | |||||
| auto real_div0 = VectorRef({real_div0_var_, add0, input5_}); | auto real_div0 = VectorRef({real_div0_var_, add0, input5_}); | ||||
| auto real_div1 = VectorRef({real_div1_var_, add1, input2_}); | auto real_div1 = VectorRef({real_div1_var_, add1, input2_}); | ||||
| auto add2 = VectorRef({prim::kPrimTensorAdd, add2_y_, real_div1}); | |||||
| auto add2 = VectorRef({prim::kPrimAdd, add2_y_, real_div1}); | |||||
| auto sqrt0 = VectorRef({prim_rsqrt, add2}); | auto sqrt0 = VectorRef({prim_rsqrt, add2}); | ||||
| auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0}); | auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0}); | ||||
| return VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); | |||||
| return VectorRef({prim::kPrimAdd, mul4, real_div2}); | |||||
| } | } | ||||
| BaseRef LambNextMVRuleCond1::DefineAnotherPattern() const { | BaseRef LambNextMVRuleCond1::DefineAnotherPattern() const { | ||||
| @@ -147,7 +147,7 @@ BaseRef LambNextMVRuleCond1::DefineAnotherPattern() const { | |||||
| VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); | VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); | ||||
| VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); | VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); | ||||
| VectorRef add4 = VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt1}); | |||||
| VectorRef add4 = VectorRef({prim::kPrimAdd, add2_y_, sqrt1}); | |||||
| VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4}); | VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4}); | ||||
| return real_div4; | return real_div4; | ||||
| } | } | ||||
| @@ -166,11 +166,11 @@ const BaseRef LambNextMVRuleCond2::DefinePattern() const { | |||||
| auto real_div0 = VectorRef({real_div0_var_, add0, input5_}); | auto real_div0 = VectorRef({real_div0_var_, add0, input5_}); | ||||
| auto real_div1 = VectorRef({real_div1_var_, add1, input2_}); | auto real_div1 = VectorRef({real_div1_var_, add1, input2_}); | ||||
| auto add2 = VectorRef({prim::kPrimTensorAdd, add2_y_, real_div1}); | |||||
| auto add2 = VectorRef({prim::kPrimAdd, add2_y_, real_div1}); | |||||
| auto sqrt0 = VectorRef({prim_rsqrt, add2}); | auto sqrt0 = VectorRef({prim_rsqrt, add2}); | ||||
| auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0}); | auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0}); | ||||
| return VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); | |||||
| return VectorRef({prim::kPrimAdd, mul4, real_div2}); | |||||
| } | } | ||||
| BaseRef LambNextMVRuleCond2::DefineAnotherPattern() const { | BaseRef LambNextMVRuleCond2::DefineAnotherPattern() const { | ||||
| @@ -183,7 +183,7 @@ BaseRef LambNextMVRuleCond2::DefineAnotherPattern() const { | |||||
| VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); | VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); | ||||
| VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); | VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); | ||||
| VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_}); | |||||
| VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, add2_y_}); | |||||
| VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4}); | VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4}); | ||||
| return real_div4; | return real_div4; | ||||
| } | } | ||||
| @@ -202,11 +202,11 @@ const BaseRef LambNextMVRuleCond3::DefinePattern() const { | |||||
| auto real_div0 = VectorRef({real_div0_var_, add0, input5_}); | auto real_div0 = VectorRef({real_div0_var_, add0, input5_}); | ||||
| auto real_div1 = VectorRef({real_div1_var_, add1, input2_}); | auto real_div1 = VectorRef({real_div1_var_, add1, input2_}); | ||||
| auto add2 = VectorRef({prim::kPrimTensorAdd, real_div1, add2_y_}); | |||||
| auto add2 = VectorRef({prim::kPrimAdd, real_div1, add2_y_}); | |||||
| auto sqrt0 = VectorRef({prim_rsqrt, add2}); | auto sqrt0 = VectorRef({prim_rsqrt, add2}); | ||||
| auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0}); | auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0}); | ||||
| return VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); | |||||
| return VectorRef({prim::kPrimAdd, mul4, real_div2}); | |||||
| } | } | ||||
| BaseRef LambNextMVRuleCond3::DefineAnotherPattern() const { | BaseRef LambNextMVRuleCond3::DefineAnotherPattern() const { | ||||
| @@ -219,7 +219,7 @@ BaseRef LambNextMVRuleCond3::DefineAnotherPattern() const { | |||||
| VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); | VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); | ||||
| VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); | VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); | ||||
| VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_}); | |||||
| VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, add2_y_}); | |||||
| VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4}); | VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4}); | ||||
| return real_div4; | return real_div4; | ||||
| } | } | ||||
| @@ -238,11 +238,11 @@ const BaseRef LambNextMVRuleCond4::DefinePattern() const { | |||||
| auto real_div0 = VectorRef({real_div0_var_, add0, input5_}); | auto real_div0 = VectorRef({real_div0_var_, add0, input5_}); | ||||
| auto real_div1 = VectorRef({real_div1_var_, add1, input2_}); | auto real_div1 = VectorRef({real_div1_var_, add1, input2_}); | ||||
| auto add2 = VectorRef({prim::kPrimTensorAdd, real_div1, add2_y_}); | |||||
| auto add2 = VectorRef({prim::kPrimAdd, real_div1, add2_y_}); | |||||
| auto sqrt0 = VectorRef({prim_rsqrt, add2}); | auto sqrt0 = VectorRef({prim_rsqrt, add2}); | ||||
| auto real_div2 = VectorRef({real_div2_var_, real_div0, sqrt0}); | auto real_div2 = VectorRef({real_div2_var_, real_div0, sqrt0}); | ||||
| return VectorRef({prim::kPrimTensorAdd, real_div2, mul4}); | |||||
| return VectorRef({prim::kPrimAdd, real_div2, mul4}); | |||||
| } | } | ||||
| BaseRef LambNextMVRuleCond4::DefineAnotherPattern() const { | BaseRef LambNextMVRuleCond4::DefineAnotherPattern() const { | ||||
| @@ -255,7 +255,7 @@ BaseRef LambNextMVRuleCond4::DefineAnotherPattern() const { | |||||
| VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); | VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); | ||||
| VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); | VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); | ||||
| VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_}); | |||||
| VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, add2_y_}); | |||||
| VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4}); | VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4}); | ||||
| return real_div4; | return real_div4; | ||||
| } | } | ||||
| @@ -49,8 +49,8 @@ class LambNextMVRule : public MultipleOutputPatternProcessPass { | |||||
| real_div0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName)); | real_div0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName)); | ||||
| real_div1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName)); | real_div1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName)); | ||||
| real_div2_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimMul->name())); | real_div2_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimMul->name())); | ||||
| add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); | |||||
| add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); | |||||
| add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name())); | |||||
| add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name())); | |||||
| } | } | ||||
| ~LambNextMVRule() override = default; | ~LambNextMVRule() override = default; | ||||
| const BaseRef DefinePattern() const override = 0; | const BaseRef DefinePattern() const override = 0; | ||||
| @@ -124,10 +124,10 @@ BaseRef LambNextMVWithDecayRuleCond1::DefineAnotherPattern() const { | |||||
| VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); | VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); | ||||
| VectorRef mul4 = VectorRef({mul4_var_, Zs}); | VectorRef mul4 = VectorRef({mul4_var_, Zs}); | ||||
| VectorRef add2 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, real_div1}); | |||||
| VectorRef add2 = VectorRef({prim::kPrimAdd, constant_add2_y_, real_div1}); | |||||
| VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); | VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); | ||||
| VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0}); | VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0}); | ||||
| VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); | |||||
| VectorRef add3 = VectorRef({prim::kPrimAdd, mul4, real_div2}); | |||||
| return add3; | return add3; | ||||
| } | } | ||||
| @@ -141,14 +141,14 @@ const BaseRef LambNextMVWithDecayRuleCond1::DefinePattern() const { | |||||
| VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); | VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); | ||||
| VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); | VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); | ||||
| VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); | VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); | ||||
| VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_}); | |||||
| VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, constant_add2_y_}); | |||||
| VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]}); | VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]}); | ||||
| VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]}); | VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]}); | ||||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | ||||
| VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); | VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); | ||||
| VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); | VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); | ||||
| VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]}); | VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]}); | ||||
| VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4}); | |||||
| VectorRef add5 = VectorRef({prim::kPrimAdd, mul4, real_div4}); | |||||
| return add5; | return add5; | ||||
| } | } | ||||
| @@ -165,10 +165,10 @@ BaseRef LambNextMVWithDecayRuleCond2::DefineAnotherPattern() const { | |||||
| VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); | VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); | ||||
| VectorRef mul4 = VectorRef({mul4_var_, Zs}); | VectorRef mul4 = VectorRef({mul4_var_, Zs}); | ||||
| VectorRef add2 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, real_div1}); | |||||
| VectorRef add2 = VectorRef({prim::kPrimAdd, constant_add2_y_, real_div1}); | |||||
| VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); | VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); | ||||
| VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0}); | VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0}); | ||||
| VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); | |||||
| VectorRef add3 = VectorRef({prim::kPrimAdd, mul4, real_div2}); | |||||
| return add3; | return add3; | ||||
| } | } | ||||
| @@ -182,14 +182,14 @@ const BaseRef LambNextMVWithDecayRuleCond2::DefinePattern() const { | |||||
| VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); | VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); | ||||
| VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); | VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); | ||||
| VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); | VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); | ||||
| VectorRef add4 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, sqrt1}); | |||||
| VectorRef add4 = VectorRef({prim::kPrimAdd, constant_add2_y_, sqrt1}); | |||||
| VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]}); | VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]}); | ||||
| VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]}); | VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]}); | ||||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | ||||
| VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); | VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); | ||||
| VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); | VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); | ||||
| VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]}); | VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]}); | ||||
| VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4}); | |||||
| VectorRef add5 = VectorRef({prim::kPrimAdd, mul4, real_div4}); | |||||
| return add5; | return add5; | ||||
| } | } | ||||
| @@ -206,10 +206,10 @@ BaseRef LambNextMVWithDecayRuleCond3::DefineAnotherPattern() const { | |||||
| VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); | VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); | ||||
| VectorRef mul4 = VectorRef({mul4_var_, Zs}); | VectorRef mul4 = VectorRef({mul4_var_, Zs}); | ||||
| VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_}); | |||||
| VectorRef add2 = VectorRef({prim::kPrimAdd, real_div1, constant_add2_y_}); | |||||
| VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); | VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); | ||||
| VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0}); | VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0}); | ||||
| VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); | |||||
| VectorRef add3 = VectorRef({prim::kPrimAdd, mul4, real_div2}); | |||||
| return add3; | return add3; | ||||
| } | } | ||||
| @@ -223,14 +223,14 @@ const BaseRef LambNextMVWithDecayRuleCond3::DefinePattern() const { | |||||
| VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); | VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); | ||||
| VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); | VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); | ||||
| VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); | VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); | ||||
| VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_}); | |||||
| VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, constant_add2_y_}); | |||||
| VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]}); | VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]}); | ||||
| VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]}); | VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]}); | ||||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | ||||
| VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); | VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); | ||||
| VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); | VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); | ||||
| VectorRef mul4 = VectorRef({mul4_var_, input_vars_[6], constant_mul_input_vars_[4]}); | VectorRef mul4 = VectorRef({mul4_var_, input_vars_[6], constant_mul_input_vars_[4]}); | ||||
| VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4}); | |||||
| VectorRef add5 = VectorRef({prim::kPrimAdd, mul4, real_div4}); | |||||
| return add5; | return add5; | ||||
| } | } | ||||
| @@ -248,10 +248,10 @@ BaseRef LambNextMVWithDecayRuleCond4::DefineAnotherPattern() const { | |||||
| VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); | VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); | ||||
| VectorRef mul4 = VectorRef({mul4_var_, Zs}); | VectorRef mul4 = VectorRef({mul4_var_, Zs}); | ||||
| VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_}); | |||||
| VectorRef add2 = VectorRef({prim::kPrimAdd, real_div1, constant_add2_y_}); | |||||
| VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); | VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); | ||||
| VectorRef real_div2 = VectorRef({prim::kPrimMul, real_div0, sqrt0}); | VectorRef real_div2 = VectorRef({prim::kPrimMul, real_div0, sqrt0}); | ||||
| VectorRef add3 = VectorRef({prim::kPrimTensorAdd, real_div2, mul4}); | |||||
| VectorRef add3 = VectorRef({prim::kPrimAdd, real_div2, mul4}); | |||||
| return add3; | return add3; | ||||
| } | } | ||||
| @@ -265,14 +265,14 @@ const BaseRef LambNextMVWithDecayRuleCond4::DefinePattern() const { | |||||
| VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); | VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); | ||||
| VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); | VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); | ||||
| VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); | VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); | ||||
| VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_}); | |||||
| VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, constant_add2_y_}); | |||||
| VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]}); | VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]}); | ||||
| VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]}); | VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]}); | ||||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | ||||
| VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); | VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); | ||||
| VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); | VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); | ||||
| VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]}); | VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]}); | ||||
| VectorRef add5 = VectorRef({prim::kPrimTensorAdd, real_div4, mul4}); | |||||
| VectorRef add5 = VectorRef({prim::kPrimAdd, real_div4, mul4}); | |||||
| return add5; | return add5; | ||||
| } | } | ||||
| } // namespace opt | } // namespace opt | ||||
| @@ -38,8 +38,8 @@ class LambNextMVWithDecayRule : public MultipleOutputPatternProcessPass { | |||||
| mul4_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimMul->name())); | mul4_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimMul->name())); | ||||
| real_div0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName)); | real_div0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName)); | ||||
| real_div1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName)); | real_div1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName)); | ||||
| add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); | |||||
| add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); | |||||
| add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name())); | |||||
| add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name())); | |||||
| } | } | ||||
| ~LambNextMVWithDecayRule() override = default; | ~LambNextMVWithDecayRule() override = default; | ||||
| @@ -66,7 +66,7 @@ bool MatchAdd5Pattern(const AnfNodePtr &node, const AnfNodePtr &mul4, const AnfN | |||||
| return false; | return false; | ||||
| } | } | ||||
| auto add5 = node->cast<CNodePtr>(); | auto add5 = node->cast<CNodePtr>(); | ||||
| if (AnfAlgo::GetCNodeName(add5) != prim::kPrimTensorAdd->name() || add5->inputs().size() != kAddInputNum) { | |||||
| if (AnfAlgo::GetCNodeName(add5) != prim::kPrimAdd->name() || add5->inputs().size() != kAddInputNum) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| auto real_div4_anf = add5->input(1); | auto real_div4_anf = add5->input(1); | ||||
| @@ -82,7 +82,7 @@ bool MatchAdd5Pattern(const AnfNodePtr &node, const AnfNodePtr &mul4, const AnfN | |||||
| return false; | return false; | ||||
| } | } | ||||
| auto add4 = add4_anf->cast<CNodePtr>(); | auto add4 = add4_anf->cast<CNodePtr>(); | ||||
| if (AnfAlgo::GetCNodeName(add4) != prim::kPrimTensorAdd->name() || add4->inputs().size() != kAddInputNum) { | |||||
| if (AnfAlgo::GetCNodeName(add4) != prim::kPrimAdd->name() || add4->inputs().size() != kAddInputNum) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| auto sqrt1_anf = add4->input(1); | auto sqrt1_anf = add4->input(1); | ||||
| @@ -140,17 +140,17 @@ const BaseRef LambNextMVWithDecayV1Rule::DefinePattern() const { | |||||
| const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName); | const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName); | ||||
| VectorRef mul3({prim::kPrimMul, mul3_sub1_, input0_}); | VectorRef mul3({prim::kPrimMul, mul3_sub1_, input0_}); | ||||
| VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); | VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); | ||||
| VectorRef add1({prim::kPrimTensorAdd, mul2, mul3}); | |||||
| VectorRef add1({prim::kPrimAdd, mul2, mul3}); | |||||
| VectorRef real_div1({prim_real_div, add1, input2_}); | VectorRef real_div1({prim_real_div, add1, input2_}); | ||||
| VectorRef add2({prim::kPrimTensorAdd, real_div1, add2_y_}); | |||||
| VectorRef add2({prim::kPrimAdd, real_div1, add2_y_}); | |||||
| VectorRef mul0({prim::kPrimMul, mul0_x_, input4_}); | VectorRef mul0({prim::kPrimMul, mul0_x_, input4_}); | ||||
| VectorRef mul1({prim::kPrimMul, mul1_sub_, input3_}); | VectorRef mul1({prim::kPrimMul, mul1_sub_, input3_}); | ||||
| VectorRef sqrt0({prim_rsqrt, add2}); | VectorRef sqrt0({prim_rsqrt, add2}); | ||||
| VectorRef add0({prim::kPrimTensorAdd, mul0, mul1}); | |||||
| VectorRef add0({prim::kPrimAdd, mul0, mul1}); | |||||
| VectorRef real_div0({prim_real_div, add0, input5_}); | VectorRef real_div0({prim_real_div, add0, input5_}); | ||||
| VectorRef real_div2({prim::kPrimMul, real_div0, sqrt0}); | VectorRef real_div2({prim::kPrimMul, real_div0, sqrt0}); | ||||
| VectorRef mul4({prim::kPrimMul, mul4_x_, input6_}); | VectorRef mul4({prim::kPrimMul, mul4_x_, input6_}); | ||||
| VectorRef add3({prim::kPrimTensorAdd, real_div2, mul4}); | |||||
| VectorRef add3({prim::kPrimAdd, real_div2, mul4}); | |||||
| return add3; | return add3; | ||||
| } | } | ||||
| @@ -54,7 +54,7 @@ const BaseRef LambNextRightRule::DefinePattern() const { | |||||
| VectorRef mul3 = VectorRef({prim::kPrimMul, mul3_x_, VectorRef({prim::kPrimSquare, input0_})}); | VectorRef mul3 = VectorRef({prim::kPrimMul, mul3_x_, VectorRef({prim::kPrimSquare, input0_})}); | ||||
| VectorRef add1 = VectorRef({add1_var_, VectorRef({prim::kPrimMul, mul2_x_, input1_}), mul3}); | VectorRef add1 = VectorRef({add1_var_, VectorRef({prim::kPrimMul, mul2_x_, input1_}), mul3}); | ||||
| return VectorRef( | return VectorRef( | ||||
| {prim::kPrimTensorAdd, VectorRef({prim_sqrt, VectorRef({prim::kPrimMul, add1, true_div1_recip_})}), add2_y_}); | |||||
| {prim::kPrimAdd, VectorRef({prim_sqrt, VectorRef({prim::kPrimMul, add1, true_div1_recip_})}), add2_y_}); | |||||
| } | } | ||||
| const AnfNodePtr LambNextRightRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | const AnfNodePtr LambNextRightRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | ||||
| @@ -32,7 +32,7 @@ class LambNextRightRule : public PatternProcessPass { | |||||
| mul3_x_(std::make_shared<Var>()), | mul3_x_(std::make_shared<Var>()), | ||||
| true_div1_recip_(std::make_shared<Var>()), | true_div1_recip_(std::make_shared<Var>()), | ||||
| add2_y_(std::make_shared<Var>()), | add2_y_(std::make_shared<Var>()), | ||||
| add1_var_(std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()))) {} | |||||
| add1_var_(std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()))) {} | |||||
| ~LambNextRightRule() override = default; | ~LambNextRightRule() override = default; | ||||
| const BaseRef DefinePattern() const override; | const BaseRef DefinePattern() const override; | ||||
| @@ -58,7 +58,7 @@ bool GetMul(const FuncGraphPtr &graph, const CNodePtr &add, CNodePtr *mul, size_ | |||||
| const BaseRef MulAddFusion::DefinePattern() const { | const BaseRef MulAddFusion::DefinePattern() const { | ||||
| VarPtr x = std::make_shared<Var>(); | VarPtr x = std::make_shared<Var>(); | ||||
| VarPtr y = std::make_shared<Var>(); | VarPtr y = std::make_shared<Var>(); | ||||
| VectorRef pattern({prim::kPrimTensorAdd, x, y}); | |||||
| VectorRef pattern({prim::kPrimAdd, x, y}); | |||||
| return pattern; | return pattern; | ||||
| } | } | ||||
| @@ -51,13 +51,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { | |||||
| } // namespace | } // namespace | ||||
| const BaseRef AdamFusion::DefinePattern() const { | const BaseRef AdamFusion::DefinePattern() const { | ||||
| VectorRef next_m = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta1_, m_}), | |||||
| VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})}); | |||||
| VectorRef next_m = VectorRef( | |||||
| {prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, m_}), VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})}); | |||||
| VectorRef next_v = | VectorRef next_v = | ||||
| VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta2_, v_}), | |||||
| VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, v_}), | |||||
| VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})}); | VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})}); | ||||
| VectorRef update = VectorRef( | |||||
| {prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})}); | |||||
| VectorRef update = | |||||
| VectorRef({prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})}); | |||||
| VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update}); | VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update}); | ||||
| VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); | VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); | ||||
| @@ -51,14 +51,14 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { | |||||
| } // namespace | } // namespace | ||||
| const BaseRef AdamWeightDecayFusion::DefinePattern() const { | const BaseRef AdamWeightDecayFusion::DefinePattern() const { | ||||
| VectorRef next_m = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta1_, m_}), | |||||
| VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})}); | |||||
| VectorRef next_m = VectorRef( | |||||
| {prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, m_}), VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})}); | |||||
| VectorRef next_v = | VectorRef next_v = | ||||
| VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta2_, v_}), | |||||
| VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, v_}), | |||||
| VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})}); | VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})}); | ||||
| VectorRef update = VectorRef( | |||||
| {prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})}); | |||||
| VectorRef new_update = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, weight_decay_, param_}), update}); | |||||
| VectorRef update = | |||||
| VectorRef({prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})}); | |||||
| VectorRef new_update = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, weight_decay_, param_}), update}); | |||||
| VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, new_update}); | VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, new_update}); | ||||
| VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); | VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); | ||||
| @@ -51,7 +51,7 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { | |||||
| } // namespace | } // namespace | ||||
| const BaseRef AddReluGradV2Fusion::DefinePattern() const { | const BaseRef AddReluGradV2Fusion::DefinePattern() const { | ||||
| VectorRef relu_grad = VectorRef({prim::kPrimReluGradV2, VectorRef({prim::kPrimTensorAdd, x1_, x2_}), mask_}); | |||||
| VectorRef relu_grad = VectorRef({prim::kPrimReluGradV2, VectorRef({prim::kPrimAdd, x1_, x2_}), mask_}); | |||||
| return relu_grad; | return relu_grad; | ||||
| } | } | ||||
| @@ -51,7 +51,7 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { | |||||
| } // namespace | } // namespace | ||||
| const BaseRef AddReluV2Fusion::DefinePattern() const { | const BaseRef AddReluV2Fusion::DefinePattern() const { | ||||
| VectorRef relu = VectorRef({prim::kPrimReluV2, VectorRef({prim::kPrimTensorAdd, x1_, x2_})}); | |||||
| VectorRef relu = VectorRef({prim::kPrimReluV2, VectorRef({prim::kPrimAdd, x1_, x2_})}); | |||||
| return relu; | return relu; | ||||
| } | } | ||||
| @@ -30,7 +30,7 @@ namespace opt { | |||||
| const BaseRef BatchNormAddReluFusion::DefinePattern() const { | const BaseRef BatchNormAddReluFusion::DefinePattern() const { | ||||
| VectorRef batch_norm_ex = VectorRef({prim::kPrimFusedBatchNormEx, x_, scale_, bias_, mean_, var_}); | VectorRef batch_norm_ex = VectorRef({prim::kPrimFusedBatchNormEx, x_, scale_, bias_, mean_, var_}); | ||||
| VectorRef tuple_get_item = VectorRef({prim::kPrimTupleGetItem, batch_norm_ex, index_}); | VectorRef tuple_get_item = VectorRef({prim::kPrimTupleGetItem, batch_norm_ex, index_}); | ||||
| VectorRef tensor_add = VectorRef({prim::kPrimTensorAdd, tuple_get_item, z_}); | |||||
| VectorRef tensor_add = VectorRef({prim::kPrimAdd, tuple_get_item, z_}); | |||||
| VectorRef relu = VectorRef({prim::kPrimRelu, tensor_add}); | VectorRef relu = VectorRef({prim::kPrimRelu, tensor_add}); | ||||
| return relu; | return relu; | ||||
| } | } | ||||
| @@ -42,7 +42,7 @@ const AnfNodePtr ReplaceAddNFusion::Process(const FuncGraphPtr &graph, const Anf | |||||
| MS_EXCEPTION_IF_NULL(B); | MS_EXCEPTION_IF_NULL(B); | ||||
| int64_t num_input = AnfAlgo::GetNodeAttr<int64_t>(node, "n"); | int64_t num_input = AnfAlgo::GetNodeAttr<int64_t>(node, "n"); | ||||
| if (num_input == 2) { | if (num_input == 2) { | ||||
| auto prim = std::make_shared<Primitive>(prim::kPrimTensorAdd->name()); | |||||
| auto prim = std::make_shared<Primitive>(prim::kPrimAdd->name()); | |||||
| MS_EXCEPTION_IF_NULL(prim); | MS_EXCEPTION_IF_NULL(prim); | ||||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim), A, B}; | std::vector<AnfNodePtr> inputs = {NewValueNode(prim), A, B}; | ||||
| auto add_new = graph->NewCNode(inputs); | auto add_new = graph->NewCNode(inputs); | ||||
| @@ -47,7 +47,7 @@ AnfNodePtr NewCNodeWithInfo(const AnfNodePtrList &inputs, const AnfNodePtr &ori_ | |||||
| } | } | ||||
| AnfNodePtr SimplifyAdd(const AnfNodePtr &node) { | AnfNodePtr SimplifyAdd(const AnfNodePtr &node) { | ||||
| if (!IsPrimitiveCNode(node, prim::kPrimTensorAdd)) { | |||||
| if (!IsPrimitiveCNode(node, prim::kPrimAdd)) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| PatternNode<AnfNodePtr> x, y, z; | PatternNode<AnfNodePtr> x, y, z; | ||||
| @@ -57,13 +57,13 @@ AnfNodePtr SimplifyAdd(const AnfNodePtr &node) { | |||||
| PConstant<AnfNodePtr> any_const_2(node); | PConstant<AnfNodePtr> any_const_2(node); | ||||
| auto add_distri_lambda = [&node, &x, &y, &any_const]() -> AnfNodePtr { | auto add_distri_lambda = [&node, &x, &y, &any_const]() -> AnfNodePtr { | ||||
| auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimTensorAdd), x.GetNode(node), y.GetNode(node)}, node); | |||||
| auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimAdd), x.GetNode(node), y.GetNode(node)}, node); | |||||
| auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), node_tmp, any_const.GetNode(node)}, node); | auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), node_tmp, any_const.GetNode(node)}, node); | ||||
| return new_cnode; | return new_cnode; | ||||
| }; | }; | ||||
| auto add_union_lambda = [&node, &x, &any_const, &any_const_2]() -> AnfNodePtr { | auto add_union_lambda = [&node, &x, &any_const, &any_const_2]() -> AnfNodePtr { | ||||
| auto new_rhs = any_const.AddByPatternConst(any_const_2, x.GetNode(node)); | auto new_rhs = any_const.AddByPatternConst(any_const_2, x.GetNode(node)); | ||||
| auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimTensorAdd), x.GetNode(node), new_rhs}, node); | |||||
| auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimAdd), x.GetNode(node), new_rhs}, node); | |||||
| return new_cnode; | return new_cnode; | ||||
| }; | }; | ||||
| // A + 0 = A | // A + 0 = A | ||||
| @@ -88,7 +88,7 @@ AnfNodePtr SimplifySub(const AnfNodePtr &node) { | |||||
| PConstant<AnfNodePtr> any_const(node); | PConstant<AnfNodePtr> any_const(node); | ||||
| auto sub_toadd_lambda = [&node, &x, &any_const]() -> AnfNodePtr { | auto sub_toadd_lambda = [&node, &x, &any_const]() -> AnfNodePtr { | ||||
| auto new_rhs = any_const.ValueNodeWithOprations(prim::kPrimNeg); | auto new_rhs = any_const.ValueNodeWithOprations(prim::kPrimNeg); | ||||
| auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimTensorAdd), x.GetNode(node), new_rhs}, node); | |||||
| auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimAdd), x.GetNode(node), new_rhs}, node); | |||||
| return new_cnode; | return new_cnode; | ||||
| }; | }; | ||||
| // A - 0 = A | // A - 0 = A | ||||
| @@ -269,7 +269,7 @@ AnfNodePtr SimplifyMul(const AnfNodePtr &node) { | |||||
| return new_cnode; | return new_cnode; | ||||
| }; | }; | ||||
| auto exp_merge_lambda = [&node, &x, &y]() -> AnfNodePtr { | auto exp_merge_lambda = [&node, &x, &y]() -> AnfNodePtr { | ||||
| auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimTensorAdd), x.GetNode(node), y.GetNode(node)}, node); | |||||
| auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimAdd), x.GetNode(node), y.GetNode(node)}, node); | |||||
| auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimExp), node_tmp}, node); | auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimExp), node_tmp}, node); | ||||
| return new_cnode; | return new_cnode; | ||||
| }; | }; | ||||
| @@ -741,14 +741,14 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p | |||||
| std::vector<PrimitivePtr> GetFusibleOpList() { | std::vector<PrimitivePtr> GetFusibleOpList() { | ||||
| #if ENABLE_D | #if ENABLE_D | ||||
| std::vector<PrimitivePtr> fusible_basic_ops = { | std::vector<PrimitivePtr> fusible_basic_ops = { | ||||
| prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd, | |||||
| prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd, | |||||
| prim::kPrimExpandDims, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog, | prim::kPrimExpandDims, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog, | ||||
| prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN, | prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN, | ||||
| prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose, | prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose, | ||||
| prim::kPrimCast, prim::kPrimRealDiv}; | prim::kPrimCast, prim::kPrimRealDiv}; | ||||
| #elif ENABLE_GPU | #elif ENABLE_GPU | ||||
| std::vector<PrimitivePtr> fusible_basic_ops = { | std::vector<PrimitivePtr> fusible_basic_ops = { | ||||
| prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd, | |||||
| prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd, | |||||
| prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog, | prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog, | ||||
| prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN, | prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN, | ||||
| prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater, | prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater, | ||||
| @@ -52,7 +52,7 @@ namespace opt { | |||||
| namespace irpass { | namespace irpass { | ||||
| OptimizeIRPassLib::OptimizeIRPassLib() { | OptimizeIRPassLib::OptimizeIRPassLib() { | ||||
| arithmetic_simplify_ = MakeSubstitution(std::make_shared<ArithmeticSimplify>(), "arithmetic_simplify", | arithmetic_simplify_ = MakeSubstitution(std::make_shared<ArithmeticSimplify>(), "arithmetic_simplify", | ||||
| {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, | |||||
| {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimAdd, | |||||
| prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow}); | prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow}); | ||||
| arithmetic_simplify2_ = | arithmetic_simplify2_ = | ||||
| MakeSubstitution(std::make_shared<ArithmeticSimplify2>(), "arithmetic_simplify2", {prim::kPrimMul}); | MakeSubstitution(std::make_shared<ArithmeticSimplify2>(), "arithmetic_simplify2", {prim::kPrimMul}); | ||||
| @@ -272,7 +272,7 @@ class AddNEliminater : public AnfVisitor { | |||||
| if (tuple_inputs.size() == 3) { | if (tuple_inputs.size() == 3) { | ||||
| // case2: inputs size = 2, -> TensorAdd(Tensor, Tensor) | // case2: inputs size = 2, -> TensorAdd(Tensor, Tensor) | ||||
| MS_LOG(DEBUG) << "Replace AddN with two inputs with TensorAdd. " << cnode->DebugString(2); | MS_LOG(DEBUG) << "Replace AddN with two inputs with TensorAdd. " << cnode->DebugString(2); | ||||
| ValuePtr prim_tensoradd = prim::GetPythonOps("TensorAdd", "mindspore.ops.operations"); | |||||
| ValuePtr prim_tensoradd = prim::GetPythonOps("Add", "mindspore.ops.operations"); | |||||
| std::vector<AnfNodePtr> new_xs{func_graph->NewCNode({NewValueNode(prim_tensoradd)}), tuple_inputs[1], | std::vector<AnfNodePtr> new_xs{func_graph->NewCNode({NewValueNode(prim_tensoradd)}), tuple_inputs[1], | ||||
| tuple_inputs[2]}; | tuple_inputs[2]}; | ||||
| mng->Replace(node, func_graph->NewCNode(new_xs)); | mng->Replace(node, func_graph->NewCNode(new_xs)); | ||||
| @@ -299,7 +299,7 @@ class AddNEliminater : public AnfVisitor { | |||||
| ValuePtr prim_addn = prim::GetPythonOps("AddN", "mindspore.ops.operations"); | ValuePtr prim_addn = prim::GetPythonOps("AddN", "mindspore.ops.operations"); | ||||
| auto new_addn = func_graph->NewCNode( | auto new_addn = func_graph->NewCNode( | ||||
| {func_graph->NewCNode({NewValueNode(prim_addn)}), func_graph->NewCNode(make_tuple_new_xs)}); | {func_graph->NewCNode({NewValueNode(prim_addn)}), func_graph->NewCNode(make_tuple_new_xs)}); | ||||
| ValuePtr prim_tensoradd = prim::GetPythonOps("TensorAdd", "mindspore.ops.operations"); | |||||
| ValuePtr prim_tensoradd = prim::GetPythonOps("Add", "mindspore.ops.operations"); | |||||
| auto new_add = | auto new_add = | ||||
| func_graph->NewCNode({func_graph->NewCNode({NewValueNode(prim_tensoradd)}), *first_valuenode, new_addn}); | func_graph->NewCNode({func_graph->NewCNode({NewValueNode(prim_tensoradd)}), *first_valuenode, new_addn}); | ||||
| (void)mng->Replace(node, new_add); | (void)mng->Replace(node, new_add); | ||||
| @@ -860,7 +860,7 @@ Strategys GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<Opera | |||||
| if (ops[iter_ops]->type() == L2_NORMALIZE) { | if (ops[iter_ops]->type() == L2_NORMALIZE) { | ||||
| return PrepareL2Normalize(ops, iter_ops, basic_stra); | return PrepareL2Normalize(ops, iter_ops, basic_stra); | ||||
| } | } | ||||
| if (ops[iter_ops]->type() == TENSOR_ADD || ops[iter_ops]->type() == SUB || ops[iter_ops]->type() == MUL || | |||||
| if (ops[iter_ops]->type() == ADD || ops[iter_ops]->type() == SUB || ops[iter_ops]->type() == MUL || | |||||
| ops[iter_ops]->type() == DIV) { | ops[iter_ops]->type() == DIV) { | ||||
| return CheckBroadcast(ops, iter_ops, basic_stra); | return CheckBroadcast(ops, iter_ops, basic_stra); | ||||
| } | } | ||||
| @@ -78,7 +78,7 @@ const std::map<std::string, OperatorType> DictOpType{ | |||||
| // Elm-wise OP | // Elm-wise OP | ||||
| {TRANSPOSE, OperatorType::kRecElmWiseOp}, | {TRANSPOSE, OperatorType::kRecElmWiseOp}, | ||||
| {L2_NORMALIZE, OperatorType::kRecElmWiseOp}, | {L2_NORMALIZE, OperatorType::kRecElmWiseOp}, | ||||
| {TENSOR_ADD, OperatorType::kRecElmWiseOp}, | |||||
| {ADD, OperatorType::kRecElmWiseOp}, | |||||
| {TENSOR_DOT, OperatorType::kRecElmWiseOp}, | {TENSOR_DOT, OperatorType::kRecElmWiseOp}, | ||||
| {SUB, OperatorType::kRecElmWiseOp}, | {SUB, OperatorType::kRecElmWiseOp}, | ||||
| {MUL, OperatorType::kRecElmWiseOp}, | {MUL, OperatorType::kRecElmWiseOp}, | ||||
| @@ -86,7 +86,7 @@ REGISTER(LogSoftmaxInfo); | |||||
| REGISTER(ActivationInfo); | REGISTER(ActivationInfo); | ||||
| REGISTER(SoftmaxCrossEntropyWithLogitsInfo); | REGISTER(SoftmaxCrossEntropyWithLogitsInfo); | ||||
| REGISTER(SubInfo); | REGISTER(SubInfo); | ||||
| REGISTER(TensorAddInfo); | |||||
| REGISTER(AddInfo); | |||||
| REGISTER(BiasAddInfo); | REGISTER(BiasAddInfo); | ||||
| REGISTER(MulInfo); | REGISTER(MulInfo); | ||||
| REGISTER(DivInfo); | REGISTER(DivInfo); | ||||
| @@ -60,12 +60,11 @@ class SubInfo : public ArithmeticBase { | |||||
| ~SubInfo() override = default; | ~SubInfo() override = default; | ||||
| }; | }; | ||||
| class TensorAddInfo : public ArithmeticBase { | |||||
| class AddInfo : public ArithmeticBase { | |||||
| public: | public: | ||||
| TensorAddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||||
| const PrimitiveAttrs &attrs) | |||||
| AddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) | |||||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<TensorAddCost>()) {} | : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<TensorAddCost>()) {} | ||||
| ~TensorAddInfo() override = default; | |||||
| ~AddInfo() override = default; | |||||
| }; | }; | ||||
| class MulInfo : public ArithmeticBase { | class MulInfo : public ArithmeticBase { | ||||
| @@ -191,7 +191,7 @@ Status OneHotInfo::ComputeReplaceGraph(const CNodePtr &cnode) { | |||||
| auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), floor_div, CreateInt32Tensor(mod_rank_)}); | auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), floor_div, CreateInt32Tensor(mod_rank_)}); | ||||
| auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, CreatTypeInt(32)}); | auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, CreatTypeInt(32)}); | ||||
| auto mul2 = gen_g.PushBack({gen_g.NewOpInst(MUL), sub1, cast}); | auto mul2 = gen_g.PushBack({gen_g.NewOpInst(MUL), sub1, cast}); | ||||
| auto tensor_add = gen_g.PushBack({gen_g.NewOpInst(TENSOR_ADD), mul2, CreateInt32Tensor(1)}); | |||||
| auto tensor_add = gen_g.PushBack({gen_g.NewOpInst(ADD), mul2, CreateInt32Tensor(1)}); | |||||
| auto mul3 = gen_g.PushBack({gen_g.NewOpInst(MUL), cast, tensor_add}); | auto mul3 = gen_g.PushBack({gen_g.NewOpInst(MUL), cast, tensor_add}); | ||||
| auto sub2 = gen_g.PushBack({gen_g.NewOpInst(SUB), mul3, CreateInt32Tensor(1)}); | auto sub2 = gen_g.PushBack({gen_g.NewOpInst(SUB), mul3, CreateInt32Tensor(1)}); | ||||
| Attr attr_onehot_axis = std::make_pair(AXIS, axis_value_ptr_); | Attr attr_onehot_axis = std::make_pair(AXIS, axis_value_ptr_); | ||||
| @@ -200,7 +200,7 @@ constexpr char MAXPOOLV2[] = "MaxPoolV2"; | |||||
| constexpr char L2_NORMALIZE[] = "L2Normalize"; | constexpr char L2_NORMALIZE[] = "L2Normalize"; | ||||
| constexpr char TRANSPOSE[] = "Transpose"; | constexpr char TRANSPOSE[] = "Transpose"; | ||||
| constexpr char RESHAPE[] = "Reshape"; | constexpr char RESHAPE[] = "Reshape"; | ||||
| constexpr char TENSOR_ADD[] = "TensorAdd"; | |||||
| constexpr char ADD[] = "Add"; | |||||
| constexpr char BIAS_ADD[] = "BiasAdd"; | constexpr char BIAS_ADD[] = "BiasAdd"; | ||||
| constexpr char SUB[] = "Sub"; | constexpr char SUB[] = "Sub"; | ||||
| constexpr char MUL[] = "Mul"; | constexpr char MUL[] = "Mul"; | ||||
| @@ -315,7 +315,6 @@ constexpr char UNSORTED_SEGMENT_MIN[] = "UnsortedSegmentMin"; | |||||
| constexpr char UNSORTED_SEGMENT_MAX[] = "UnsortedSegmentMax"; | constexpr char UNSORTED_SEGMENT_MAX[] = "UnsortedSegmentMax"; | ||||
| constexpr char DEPTHWISE_CONV2D_NATIVE[] = "DepthwiseConv2dNative"; | constexpr char DEPTHWISE_CONV2D_NATIVE[] = "DepthwiseConv2dNative"; | ||||
| constexpr char DEPTHWISE_CONV2D[] = "DepthwiseConv2D"; | constexpr char DEPTHWISE_CONV2D[] = "DepthwiseConv2D"; | ||||
| constexpr char ADD[] = "Add"; | |||||
| constexpr char DROPOUT[] = "Dropout"; | constexpr char DROPOUT[] = "Dropout"; | ||||
| constexpr char KStridedSlice[] = "StridedSlice"; | constexpr char KStridedSlice[] = "StridedSlice"; | ||||
| constexpr char UNIQUE[] = "Unique"; | constexpr char UNIQUE[] = "Unique"; | ||||
| @@ -151,7 +151,7 @@ bool IsSplittableOperator(const std::string &op_name) { | |||||
| // clang-format off | // clang-format off | ||||
| static const std::set<std::string> splittable_op = | static const std::set<std::string> splittable_op = | ||||
| {MATMUL, TRANSPOSE, GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU, | {MATMUL, TRANSPOSE, GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU, | ||||
| FLOORDIV, L2_NORMALIZE, TENSOR_ADD, MAXPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK, | |||||
| FLOORDIV, L2_NORMALIZE, ADD, MAXPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK, | |||||
| REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING, | REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING, | ||||
| MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP, PACK, | MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP, PACK, | ||||
| LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT, | LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT, | ||||
| @@ -165,7 +165,7 @@ class OpNameInfo { | |||||
| #define OPERATOR_ONNX_CONVERT_DEFINE(name, onnx_name, impl) \ | #define OPERATOR_ONNX_CONVERT_DEFINE(name, onnx_name, impl) \ | ||||
| OpNameInfo GetOpOnnxConvertInfo_##name() { return impl.set_op_type(#name).set_onnx_type(#onnx_name); } | OpNameInfo GetOpOnnxConvertInfo_##name() { return impl.set_op_type(#name).set_onnx_type(#onnx_name); } | ||||
| OPERATOR_ONNX_CONVERT_DEFINE(TensorAdd, Add, OpNameInfo()) | |||||
| OPERATOR_ONNX_CONVERT_DEFINE(Add, Add, OpNameInfo()) | |||||
| OPERATOR_ONNX_CONVERT_DEFINE(Mul, Mul, OpNameInfo()) | OPERATOR_ONNX_CONVERT_DEFINE(Mul, Mul, OpNameInfo()) | ||||
| OPERATOR_ONNX_CONVERT_DEFINE(ReLU, Relu, OpNameInfo()) | OPERATOR_ONNX_CONVERT_DEFINE(ReLU, Relu, OpNameInfo()) | ||||
| @@ -257,7 +257,7 @@ OPERATOR_ONNX_CONVERT_DEFINE(Sub, Sub, OpNameInfo()) | |||||
| #define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name | #define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name | ||||
| void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) { | void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) { | ||||
| fn(OP_CONVERT_FUNCTION_NAME(TensorAdd)()); | |||||
| fn(OP_CONVERT_FUNCTION_NAME(Add)()); | |||||
| fn(OP_CONVERT_FUNCTION_NAME(Mul)()); | fn(OP_CONVERT_FUNCTION_NAME(Mul)()); | ||||
| fn(OP_CONVERT_FUNCTION_NAME(ReLU)()); | fn(OP_CONVERT_FUNCTION_NAME(ReLU)()); | ||||
| @@ -29,7 +29,7 @@ REG_ADPT_DESC(StateSetItem, prim::kPrimStateSetItem->name(), ADPT_DESC(Assign)) | |||||
| INPUT_MAP(Add) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; | INPUT_MAP(Add) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; | ||||
| ATTR_MAP(Add) = EMPTY_ATTR_MAP; | ATTR_MAP(Add) = EMPTY_ATTR_MAP; | ||||
| OUTPUT_MAP(Add) = {{0, OUTPUT_DESC(y)}}; | OUTPUT_MAP(Add) = {{0, OUTPUT_DESC(y)}}; | ||||
| REG_ADPT_DESC(Add, prim::kPrimTensorAdd->name(), | |||||
| REG_ADPT_DESC(Add, prim::kPrimAdd->name(), | |||||
| std::make_shared<OpAdapterDesc>( | std::make_shared<OpAdapterDesc>( | ||||
| std::make_shared<OpAdapter<Add>>(ExtraAttr({{"mode", MakeValue(static_cast<int64_t>(1))}})), | std::make_shared<OpAdapter<Add>>(ExtraAttr({{"mode", MakeValue(static_cast<int64_t>(1))}})), | ||||
| std::make_shared<OpAdapter<Add>>(ExtraAttr({{"mode", MakeValue(static_cast<int64_t>(1))}})))) | std::make_shared<OpAdapter<Add>>(ExtraAttr({{"mode", MakeValue(static_cast<int64_t>(1))}})))) | ||||
| @@ -215,7 +215,7 @@ constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu"; | |||||
| constexpr auto kmaxPoolGradOpName = "MaxPoolGrad"; | constexpr auto kmaxPoolGradOpName = "MaxPoolGrad"; | ||||
| constexpr auto kMaxPoolWithArgmaxOpName = "MaxPoolWithArgmax"; | constexpr auto kMaxPoolWithArgmaxOpName = "MaxPoolWithArgmax"; | ||||
| constexpr auto kMaxPoolGradWithArgmaxOpName = "MaxPoolGradWithArgmax"; | constexpr auto kMaxPoolGradWithArgmaxOpName = "MaxPoolGradWithArgmax"; | ||||
| constexpr auto kTensorAddOpName = "TensorAdd"; | |||||
| constexpr auto kTensorAddOpName = "Add"; | |||||
| constexpr auto kCastOpName = "Cast"; | constexpr auto kCastOpName = "Cast"; | ||||
| constexpr auto kGreaterEqualOpName = "GreaterEqual"; | constexpr auto kGreaterEqualOpName = "GreaterEqual"; | ||||
| constexpr auto kAbsOpName = "Abs"; | constexpr auto kAbsOpName = "Abs"; | ||||
| @@ -46,7 +46,7 @@ class ExportToQuantInferNetwork: | |||||
| Returns: | Returns: | ||||
| Cell, Infer network. | Cell, Infer network. | ||||
| """ | """ | ||||
| __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] | |||||
| __quant_op_name__ = ["Add", "Sub", "Mul", "RealDiv"] | |||||
| def __init__(self, network, mean, std_dev, *inputs, is_mindir=False): | def __init__(self, network, mean, std_dev, *inputs, is_mindir=False): | ||||
| network = Validator.check_isinstance('network', network, (nn.Cell,)) | network = Validator.check_isinstance('network', network, (nn.Cell,)) | ||||
| @@ -225,7 +225,7 @@ class ExportManualQuantNetwork(ExportToQuantInferNetwork): | |||||
| Returns: | Returns: | ||||
| Cell, Infer network. | Cell, Infer network. | ||||
| """ | """ | ||||
| __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] | |||||
| __quant_op_name__ = ["Add", "Sub", "Mul", "RealDiv"] | |||||
| def __init__(self, network, mean, std_dev, *inputs, is_mindir=False): | def __init__(self, network, mean, std_dev, *inputs, is_mindir=False): | ||||
| super(ExportManualQuantNetwork, self).__init__(network, mean, std_dev, *inputs, is_mindir) | super(ExportManualQuantNetwork, self).__init__(network, mean, std_dev, *inputs, is_mindir) | ||||
| @@ -173,7 +173,7 @@ class QuantizationAwareTraining(Quantizer): | |||||
| >>> quantizer = QuantizationAwareTraining(bn_fold=False, per_channel=[True, False], symmetric=[True, False]) | >>> quantizer = QuantizationAwareTraining(bn_fold=False, per_channel=[True, False], symmetric=[True, False]) | ||||
| >>> net_qat = quantizer.quantize(net) | >>> net_qat = quantizer.quantize(net) | ||||
| """ | """ | ||||
| __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] | |||||
| __quant_op_name__ = ["Add", "Sub", "Mul", "RealDiv"] | |||||
| def __init__(self, | def __init__(self, | ||||
| bn_fold=True, | bn_fold=True, | ||||
| @@ -91,8 +91,8 @@ AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const Primitive | |||||
| AbstractBasePtr InferImplMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplTensorAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplSquare(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplSquare(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| @@ -60,8 +60,8 @@ AbstractBasePtr InferImplSqrtGrad(const AnalysisEnginePtr &, const PrimitivePtr | |||||
| return out->Broaden(); | return out->Broaden(); | ||||
| } | } | ||||
| AbstractBasePtr InferImplTensorAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| AbstractBasePtr InferImplAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| // Inputs: two tensors. | // Inputs: two tensors. | ||||
| const std::string op_name = primitive->name(); | const std::string op_name = primitive->name(); | ||||
| CheckArgsSize(op_name, args_spec_list, 2); | CheckArgsSize(op_name, args_spec_list, 2); | ||||
| @@ -37,7 +37,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||||
| {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}}, | {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}}, | ||||
| {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}}, | {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}}, | ||||
| {prim::kPrimMul, {InferImplMul, true}}, | {prim::kPrimMul, {InferImplMul, true}}, | ||||
| {prim::kPrimTensorAdd, {InferImplTensorAdd, true}}, | |||||
| {prim::kPrimAdd, {InferImplAdd, true}}, | |||||
| {prim::kPrimSquare, {InferImplSquare, true}}, | {prim::kPrimSquare, {InferImplSquare, true}}, | ||||
| {prim::kPrimSqrt, {InferImplSqrt, true}}, | {prim::kPrimSqrt, {InferImplSqrt, true}}, | ||||
| {prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}}, | {prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}}, | ||||
| @@ -236,7 +236,7 @@ inline const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primiti | |||||
| inline const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape"); | inline const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape"); | ||||
| // Maths | // Maths | ||||
| inline const PrimitivePtr kPrimTensorAdd = std::make_shared<Primitive>("TensorAdd"); | |||||
| inline const PrimitivePtr kPrimAdd = std::make_shared<Primitive>("Add"); | |||||
| inline const PrimitivePtr kPrimMatMul = std::make_shared<Primitive>("MatMul"); | inline const PrimitivePtr kPrimMatMul = std::make_shared<Primitive>("MatMul"); | ||||
| inline const PrimitivePtr kPrimBatchMatMul = std::make_shared<Primitive>("BatchMatMul"); | inline const PrimitivePtr kPrimBatchMatMul = std::make_shared<Primitive>("BatchMatMul"); | ||||
| inline const PrimitivePtr kPrimMaximumGrad = std::make_shared<Primitive>("MaximumGrad"); | inline const PrimitivePtr kPrimMaximumGrad = std::make_shared<Primitive>("MaximumGrad"); | ||||
| @@ -49,6 +49,6 @@ AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr | |||||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | ||||
| InferShape(primitive, input_args)->shape()); | InferShape(primitive, input_args)->shape()); | ||||
| } | } | ||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Add, prim::kPrimTensorAdd, AddInfer); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Add, prim::kPrimAdd, AddInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameAdd, Add); | REGISTER_PRIMITIVE_C(kNameAdd, Add); | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -989,7 +989,7 @@ class PConstant : public PBase<PConstant<T> > { | |||||
| } | } | ||||
| // Arithmetic operations | // Arithmetic operations | ||||
| BIN_OPERATION_PATTERN(operator+, prim::kPrimTensorAdd, true); | |||||
| BIN_OPERATION_PATTERN(operator+, prim::kPrimAdd, true); | |||||
| BIN_OPERATION_PATTERN(operator*, prim::kPrimMul, true); | BIN_OPERATION_PATTERN(operator*, prim::kPrimMul, true); | ||||
| BIN_OPERATION_PATTERN(operator/, prim::kPrimRealDiv, false); | BIN_OPERATION_PATTERN(operator/, prim::kPrimRealDiv, false); | ||||
| BIN_OPERATION_PATTERN(operator-, prim::kPrimSub, false); | BIN_OPERATION_PATTERN(operator-, prim::kPrimSub, false); | ||||
| @@ -225,7 +225,7 @@ class LambNextMV(GraphKernel): | |||||
| def __init__(self): | def __init__(self): | ||||
| super(LambNextMV, self).__init__() | super(LambNextMV, self).__init__() | ||||
| self.mul = P.Mul() | self.mul = P.Mul() | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| self.div = P.RealDiv() | self.div = P.RealDiv() | ||||
| self.sqrt = P.Sqrt() | self.sqrt = P.Sqrt() | ||||
| self.rsqrt = P.Rsqrt() | self.rsqrt = P.Rsqrt() | ||||
| @@ -651,7 +651,7 @@ class LogSigmoid(Cell): | |||||
| super(LogSigmoid, self).__init__() | super(LogSigmoid, self).__init__() | ||||
| self.mul = P.Mul() | self.mul = P.Mul() | ||||
| self.exp = P.Exp() | self.exp = P.Exp() | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| self.rec = P.Reciprocal() | self.rec = P.Reciprocal() | ||||
| self.log = P.Log() | self.log = P.Log() | ||||
| @@ -441,13 +441,13 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup): | |||||
| self.mul = P.Mul() | self.mul = P.Mul() | ||||
| self.inf_mask_mul = P.Mul() | self.inf_mask_mul = P.Mul() | ||||
| self.bias_add = P.TensorAdd() | |||||
| self.inf_add = P.TensorAdd() | |||||
| self.bias_add = P.Add() | |||||
| self.inf_add = P.Add() | |||||
| self.merge_op = None | self.merge_op = None | ||||
| self.count_op = P.UnsortedSegmentSum() | self.count_op = P.UnsortedSegmentSum() | ||||
| self.abs = P.Abs() | self.abs = P.Abs() | ||||
| self.equal = P.Equal() | self.equal = P.Equal() | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| self.cast = P.Cast() | self.cast = P.Cast() | ||||
| self.div_no_nan = P.DivNoNan() | self.div_no_nan = P.DivNoNan() | ||||
| self.expand = P.ExpandDims() | self.expand = P.ExpandDims() | ||||
| @@ -99,8 +99,8 @@ class BatchNormFoldCell(Cell): | |||||
| else: | else: | ||||
| batch_mean = P.ZerosLike()(variance) | batch_mean = P.ZerosLike()(variance) | ||||
| batch_std = P.OnesLike()(variance) | batch_std = P.OnesLike()(variance) | ||||
| running_mean = P.TensorAdd()(mean, 0.) | |||||
| running_std = P.Sqrt()(P.TensorAdd()(variance, self.epsilon)) | |||||
| running_mean = P.Add()(mean, 0.) | |||||
| running_std = P.Sqrt()(P.Add()(variance, self.epsilon)) | |||||
| return batch_mean, batch_std, running_mean, running_std | return batch_mean, batch_std, running_mean, running_std | ||||
| @@ -559,7 +559,7 @@ class Conv2dBnFoldQuantOneConv(Cell): | |||||
| return s | return s | ||||
| def construct(self, x): | def construct(self, x): | ||||
| running_std = P.Sqrt()(P.TensorAdd()(self.moving_variance, self.eps)) | |||||
| running_std = P.Sqrt()(P.Add()(self.moving_variance, self.eps)) | |||||
| scale_factor = self.gamma / running_std | scale_factor = self.gamma / running_std | ||||
| if self.channel_axis: | if self.channel_axis: | ||||
| scale_factor = self.reshape(scale_factor, (1, -1, 1, 1)) | scale_factor = self.reshape(scale_factor, (1, -1, 1, 1)) | ||||
| @@ -1236,7 +1236,7 @@ class TensorAddQuant(Cell): | |||||
| ema=True, | ema=True, | ||||
| ema_decay=ema_decay, | ema_decay=ema_decay, | ||||
| quant_dtype=quant_dtype) | quant_dtype=quant_dtype) | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| def construct(self, x1, x2): | def construct(self, x1, x2): | ||||
| x = self.add(x1, x2) | x = self.add(x1, x2) | ||||
| @@ -155,9 +155,9 @@ def bprop_batchmatmul(self): | |||||
| return bprop | return bprop | ||||
| @bprop_getters.register(P.TensorAdd) | |||||
| @bprop_getters.register(P.Add) | |||||
| def get_bprop_tensor_add(self): | def get_bprop_tensor_add(self): | ||||
| """Grad definition for `TensorAdd` operation.""" | |||||
| """Grad definition for `Add` operation.""" | |||||
| def bprop(x, y, out, dout): | def bprop(x, y, out, dout): | ||||
| return binop_grad_common(x, y, dout, dout) | return binop_grad_common(x, y, dout, dout) | ||||
| @@ -13,10 +13,10 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """TensorAdd op""" | |||||
| """Add op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT | from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT | ||||
| op_info = AkgAscendRegOp("TensorAdd") \ | |||||
| op_info = AkgAscendRegOp("Add") \ | |||||
| .fusion_type("ELEMWISE") \ | .fusion_type("ELEMWISE") \ | ||||
| .input(0, "x") \ | .input(0, "x") \ | ||||
| .input(1, "y") \ | .input(1, "y") \ | ||||
| @@ -38,5 +38,5 @@ op_info = AkgAscendRegOp("TensorAdd") \ | |||||
| @op_info_register(op_info) | @op_info_register(op_info) | ||||
| def _add_akg(): | def _add_akg(): | ||||
| """TensorAdd Akg register""" | |||||
| """Add Akg register""" | |||||
| return | return | ||||
| @@ -13,10 +13,10 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """TensorAdd op""" | |||||
| """Add op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | ||||
| tensor_add_op_info = TBERegOp("TensorAdd") \ | |||||
| tensor_add_op_info = TBERegOp("Add") \ | |||||
| .fusion_type("ELEMWISE") \ | .fusion_type("ELEMWISE") \ | ||||
| .async_flag(False) \ | .async_flag(False) \ | ||||
| .binfile_name("add.so") \ | .binfile_name("add.so") \ | ||||
| @@ -16,7 +16,7 @@ | |||||
| """TensorAdd op""" | """TensorAdd op""" | ||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | ||||
| tensor_add_op_info = TBERegOp("TensorAdd") \ | |||||
| tensor_add_op_info = TBERegOp("Add") \ | |||||
| .fusion_type("ELEMWISE") \ | .fusion_type("ELEMWISE") \ | ||||
| .async_flag(False) \ | .async_flag(False) \ | ||||
| .binfile_name("add.so") \ | .binfile_name("add.so") \ | ||||
| @@ -395,7 +395,7 @@ class MultitypeFuncGraph(MultitypeFuncGraph_): | |||||
| >>> from mindspore.ops import Primitive, operations as P | >>> from mindspore.ops import Primitive, operations as P | ||||
| >>> from mindspore import dtype as mstype | >>> from mindspore import dtype as mstype | ||||
| >>> | >>> | ||||
| >>> tensor_add = P.TensorAdd() | |||||
| >>> tensor_add = P.Add() | |||||
| >>> add = MultitypeFuncGraph('add') | >>> add = MultitypeFuncGraph('add') | ||||
| >>> @add.register("Number", "Number") | >>> @add.register("Number", "Number") | ||||
| ... def add_scala(x, y): | ... def add_scala(x, y): | ||||
| @@ -51,7 +51,7 @@ merge = P.Merge() | |||||
| geswitch = P.GeSwitch() | geswitch = P.GeSwitch() | ||||
| addn = P.AddN() | addn = P.AddN() | ||||
| absolute = P.Abs() | absolute = P.Abs() | ||||
| tensor_add = P.TensorAdd() | |||||
| tensor_add = P.Add() | |||||
| neg_tensor = P.Neg() | neg_tensor = P.Neg() | ||||
| tensor_lt = P.Less() | tensor_lt = P.Less() | ||||
| tensor_le = P.LessEqual() | tensor_le = P.LessEqual() | ||||
| @@ -54,7 +54,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A | |||||
| NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus, | NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus, | ||||
| Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy, | Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy, | ||||
| Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, | Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, | ||||
| Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan, | |||||
| Square, Sub, TensorAdd, Add, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan, | |||||
| MatrixInverse) | MatrixInverse) | ||||
| from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal, | from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal, | ||||
| @@ -102,6 +102,7 @@ __all__ = [ | |||||
| 'Sort', | 'Sort', | ||||
| 'EditDistance', | 'EditDistance', | ||||
| 'CropAndResize', | 'CropAndResize', | ||||
| 'Add', | |||||
| 'TensorAdd', | 'TensorAdd', | ||||
| 'Argmax', | 'Argmax', | ||||
| 'Argmin', | 'Argmin', | ||||
| @@ -106,7 +106,7 @@ class GeSwitch(PrimitiveWithInfer): | |||||
| ... def __init__(self): | ... def __init__(self): | ||||
| ... super(Net, self).__init__() | ... super(Net, self).__init__() | ||||
| ... self.square = ops.Square() | ... self.square = ops.Square() | ||||
| ... self.add = ops.TensorAdd() | |||||
| ... self.add = ops.Add() | |||||
| ... self.value = Tensor(np.full((1), 3), mindspore.float32) | ... self.value = Tensor(np.full((1), 3), mindspore.float32) | ||||
| ... self.switch = ops.GeSwitch() | ... self.switch = ops.GeSwitch() | ||||
| ... self.merge = ops.Merge() | ... self.merge = ops.Merge() | ||||
| @@ -66,7 +66,7 @@ class ScalarSummary(PrimitiveWithInfer): | |||||
| ... def __init__(self,): | ... def __init__(self,): | ||||
| ... super(SummaryDemo, self).__init__() | ... super(SummaryDemo, self).__init__() | ||||
| ... self.summary = ops.ScalarSummary() | ... self.summary = ops.ScalarSummary() | ||||
| ... self.add = ops.TensorAdd() | |||||
| ... self.add = ops.Add() | |||||
| ... | ... | ||||
| ... def construct(self, x, y): | ... def construct(self, x, y): | ||||
| ... name = "x" | ... name = "x" | ||||
| @@ -149,7 +149,7 @@ class TensorSummary(PrimitiveWithInfer): | |||||
| ... def __init__(self,): | ... def __init__(self,): | ||||
| ... super(SummaryDemo, self).__init__() | ... super(SummaryDemo, self).__init__() | ||||
| ... self.summary = ops.TensorSummary() | ... self.summary = ops.TensorSummary() | ||||
| ... self.add = ops.TensorAdd() | |||||
| ... self.add = ops.Add() | |||||
| ... | ... | ||||
| ... def construct(self, x, y): | ... def construct(self, x, y): | ||||
| ... x = self.add(x, y) | ... x = self.add(x, y) | ||||
| @@ -191,7 +191,7 @@ class HistogramSummary(PrimitiveWithInfer): | |||||
| ... def __init__(self,): | ... def __init__(self,): | ||||
| ... super(SummaryDemo, self).__init__() | ... super(SummaryDemo, self).__init__() | ||||
| ... self.summary = ops.HistogramSummary() | ... self.summary = ops.HistogramSummary() | ||||
| ... self.add = ops.TensorAdd() | |||||
| ... self.add = ops.Add() | |||||
| ... | ... | ||||
| ... def construct(self, x, y): | ... def construct(self, x, y): | ||||
| ... x = self.add(x, y) | ... x = self.add(x, y) | ||||
| @@ -409,7 +409,7 @@ class Assert(PrimitiveWithInfer): | |||||
| ... def __init__(self): | ... def __init__(self): | ||||
| ... super(AssertDemo, self).__init__() | ... super(AssertDemo, self).__init__() | ||||
| ... self.assert1 = ops.Assert(summarize=10) | ... self.assert1 = ops.Assert(summarize=10) | ||||
| ... self.add = ops.TensorAdd() | |||||
| ... self.add = ops.Add() | |||||
| ... | ... | ||||
| ... def construct(self, x, y): | ... def construct(self, x, y): | ||||
| ... data = self.add(x, y) | ... data = self.add(x, y) | ||||
| @@ -18,6 +18,7 @@ | |||||
| import copy | import copy | ||||
| import numpy as np | import numpy as np | ||||
| from mindspore import log as logger | |||||
| from ... import context | from ... import context | ||||
| from .. import signature as sig | from .. import signature as sig | ||||
| from ..._checkparam import Validator as validator | from ..._checkparam import Validator as validator | ||||
| @@ -114,7 +115,7 @@ class _BitwiseBinaryOp(_MathBinaryOp): | |||||
| return _BitwiseBinaryOp._check_bitwise_op_input_type(x1_type, x2_type, self.name) | return _BitwiseBinaryOp._check_bitwise_op_input_type(x1_type, x2_type, self.name) | ||||
| class TensorAdd(_MathBinaryOp): | |||||
| class Add(_MathBinaryOp): | |||||
| r""" | r""" | ||||
| Adds two input tensors element-wise. | Adds two input tensors element-wise. | ||||
| @@ -143,7 +144,7 @@ class TensorAdd(_MathBinaryOp): | |||||
| ``Ascend`` ``GPU`` ``CPU`` | ``Ascend`` ``GPU`` ``CPU`` | ||||
| Examples: | Examples: | ||||
| >>> add = ops.TensorAdd() | |||||
| >>> add = ops.Add() | |||||
| >>> input_x = Tensor(np.array([1, 2, 3]).astype(np.float32)) | >>> input_x = Tensor(np.array([1, 2, 3]).astype(np.float32)) | ||||
| >>> input_y = Tensor(np.array([4, 5, 6]).astype(np.float32)) | >>> input_y = Tensor(np.array([4, 5, 6]).astype(np.float32)) | ||||
| >>> output = add(input_x, input_y) | >>> output = add(input_x, input_y) | ||||
| @@ -160,6 +161,10 @@ class TensorAdd(_MathBinaryOp): | |||||
| return Tensor(out) | return Tensor(out) | ||||
| return None | return None | ||||
| def TensorAdd(): | |||||
| """Warning: This will be changed later""" | |||||
| logger.warning("WARN_DEPRECATED: The usage of TensorAdd is deprecated. Please use Add.") | |||||
| return Add() | |||||
| class AssignAdd(PrimitiveWithInfer): | class AssignAdd(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| @@ -16,7 +16,7 @@ | |||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops.operations import TensorAdd | |||||
| from mindspore.ops.operations import Add | |||||
| from src.var_init import KaimingNormal | from src.var_init import KaimingNormal | ||||
| @@ -91,7 +91,7 @@ class InvertedResidual(nn.Cell): | |||||
| ]) | ]) | ||||
| self.conv = nn.SequentialCell(layers) | self.conv = nn.SequentialCell(layers) | ||||
| self.add = TensorAdd() | |||||
| self.add = Add() | |||||
| self.cast = P.Cast() | self.cast = P.Cast() | ||||
| def construct(self, x): | def construct(self, x): | ||||
| @@ -198,7 +198,7 @@ class BasicBlock(nn.Cell): | |||||
| self.bn2 = ms_fused_bn(planes) | self.bn2 = ms_fused_bn(planes) | ||||
| self.relu = P.ReLU() | self.relu = P.ReLU() | ||||
| self.downsample = downsample | self.downsample = downsample | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| def construct(self, x): | def construct(self, x): | ||||
| residual = x | residual = x | ||||
| @@ -102,7 +102,7 @@ class Bottleneck(nn.Cell): | |||||
| self.relu = nn.ReLU() | self.relu = nn.ReLU() | ||||
| self.downsample = downsample | self.downsample = downsample | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| def construct(self, x): | def construct(self, x): | ||||
| identity = x | identity = x | ||||
| @@ -222,7 +222,7 @@ class ResidualBlockUsing(nn.Cell): | |||||
| self.bn_down_sample = self.bn_down_sample.set_train() | self.bn_down_sample = self.bn_down_sample.set_train() | ||||
| if not weights_update: | if not weights_update: | ||||
| self.conv_down_sample.weight.requires_grad = False | self.conv_down_sample.weight.requires_grad = False | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| def construct(self, x): | def construct(self, x): | ||||
| identity = x | identity = x | ||||
| @@ -218,7 +218,7 @@ class ResidualBlockUsing(nn.Cell): | |||||
| self.bn_down_sample = self.bn_down_sample.set_train() | self.bn_down_sample = self.bn_down_sample.set_train() | ||||
| if not weights_update: | if not weights_update: | ||||
| self.conv_down_sample.weight.requires_grad = False | self.conv_down_sample.weight.requires_grad = False | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| def construct(self, x): | def construct(self, x): | ||||
| identity = x | identity = x | ||||
| @@ -16,7 +16,7 @@ | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops.operations import TensorAdd | |||||
| from mindspore.ops.operations import Add | |||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| __all__ = ['MobileNetV2', 'MobileNetV2Backbone', 'MobileNetV2Head', 'mobilenet_v2'] | __all__ = ['MobileNetV2', 'MobileNetV2Backbone', 'MobileNetV2Head', 'mobilenet_v2'] | ||||
| @@ -129,7 +129,7 @@ class InvertedResidual(nn.Cell): | |||||
| nn.BatchNorm2d(oup), | nn.BatchNorm2d(oup), | ||||
| ]) | ]) | ||||
| self.conv = nn.SequentialCell(layers) | self.conv = nn.SequentialCell(layers) | ||||
| self.add = TensorAdd() | |||||
| self.add = Add() | |||||
| self.cast = P.Cast() | self.cast = P.Cast() | ||||
| def construct(self, x): | def construct(self, x): | ||||
| @@ -120,7 +120,7 @@ class InvertedResidual(nn.Cell): | |||||
| nn.Conv2dBnAct(hidden_dim, oup, kernel_size=1, stride=1, pad_mode='pad', padding=0, group=1, has_bn=True) | nn.Conv2dBnAct(hidden_dim, oup, kernel_size=1, stride=1, pad_mode='pad', padding=0, group=1, has_bn=True) | ||||
| ]) | ]) | ||||
| self.conv = nn.SequentialCell(layers) | self.conv = nn.SequentialCell(layers) | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| def construct(self, x): | def construct(self, x): | ||||
| out = self.conv(x) | out = self.conv(x) | ||||
| @@ -123,7 +123,7 @@ class InvertedResidual(nn.Cell): | |||||
| nn.Conv2dBnAct(hidden_dim, oup, kernel_size=1, stride=1, pad_mode='pad', padding=0, group=1, has_bn=True) | nn.Conv2dBnAct(hidden_dim, oup, kernel_size=1, stride=1, pad_mode='pad', padding=0, group=1, has_bn=True) | ||||
| ]) | ]) | ||||
| self.conv = nn.SequentialCell(layers) | self.conv = nn.SequentialCell(layers) | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| def construct(self, x): | def construct(self, x): | ||||
| out = self.conv(x) | out = self.conv(x) | ||||
| @@ -197,7 +197,7 @@ class ResUnit(nn.Cell): | |||||
| padding=0, act_type=act_type, use_act=False) | padding=0, act_type=act_type, use_act=False) | ||||
| if num_in != num_out or stride != 1: | if num_in != num_out or stride != 1: | ||||
| self.use_short_cut_conv = False | self.use_short_cut_conv = False | ||||
| self.add = P.TensorAdd() if self.use_short_cut_conv else None | |||||
| self.add = P.Add() if self.use_short_cut_conv else None | |||||
| def construct(self, x): | def construct(self, x): | ||||
| """construct""" | """construct""" | ||||
| @@ -49,7 +49,7 @@ class DiceLoss(_Loss): | |||||
| self.logical_or = P.LogicalOr() | self.logical_or = P.LogicalOr() | ||||
| self.equal = P.Equal() | self.equal = P.Equal() | ||||
| self.zeros_like = P.ZerosLike() | self.zeros_like = P.ZerosLike() | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| self.gather = P.Gather() | self.gather = P.Gather() | ||||
| def ohem_batch(self, scores, gt_texts, training_masks): | def ohem_batch(self, scores, gt_texts, training_masks): | ||||
| @@ -61,7 +61,7 @@ class ResidualBlock(nn.Cell): | |||||
| kernel_size=1, stride=stride) | kernel_size=1, stride=stride) | ||||
| self.bn_down_sample = _bn(out_channels, momentum=momentum) | self.bn_down_sample = _bn(out_channels, momentum=momentum) | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| def construct(self, x): | def construct(self, x): | ||||
| identity = x | identity = x | ||||
| @@ -152,7 +152,7 @@ class ResidualBlock(nn.Cell): | |||||
| else: | else: | ||||
| self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride, | self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride, | ||||
| use_se=self.use_se), _bn(out_channel)]) | use_se=self.use_se), _bn(out_channel)]) | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| def construct(self, x): | def construct(self, x): | ||||
| identity = x | identity = x | ||||
| @@ -119,7 +119,7 @@ class ResidualBlock(nn.Cell): | |||||
| if self.down_sample: | if self.down_sample: | ||||
| self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride), _bn(out_channel)]) | self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride), _bn(out_channel)]) | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| def construct(self, x): | def construct(self, x): | ||||
| identity = x | identity = x | ||||
| @@ -85,7 +85,7 @@ class ResidualBlock(nn.Cell): | |||||
| self.down_sample_layer = nn.Conv2dBnAct(in_channel, out_channel, | self.down_sample_layer = nn.Conv2dBnAct(in_channel, out_channel, | ||||
| kernel_size=1, stride=stride, | kernel_size=1, stride=stride, | ||||
| pad_mode='same', padding=0, has_bn=True, activation='relu') | pad_mode='same', padding=0, has_bn=True, activation='relu') | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| self.relu = P.ReLU() | self.relu = P.ReLU() | ||||
| def construct(self, x): | def construct(self, x): | ||||
| @@ -215,7 +215,7 @@ class ResidualBlock(nn.Cell): | |||||
| frequency=frequency, | frequency=frequency, | ||||
| batch_size=batch_size), | batch_size=batch_size), | ||||
| _bn(out_channel)]) | _bn(out_channel)]) | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| def construct(self, x): | def construct(self, x): | ||||
| identity = x | identity = x | ||||
| @@ -333,7 +333,7 @@ class Dense_Thor_GPU(Cell): | |||||
| self.gather = P.Gather() | self.gather = P.Gather() | ||||
| self.freq = Tensor(frequency, mstype.int32) | self.freq = Tensor(frequency, mstype.int32) | ||||
| self.axis = 0 | self.axis = 0 | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| self.sqrt = P.Sqrt() | self.sqrt = P.Sqrt() | ||||
| self.cholesky = P.CholeskyTrsm(split_dim=split_dim) | self.cholesky = P.CholeskyTrsm(split_dim=split_dim) | ||||
| self.vector_matmul = P.BatchMatMul(transpose_a=True) | self.vector_matmul = P.BatchMatMul(transpose_a=True) | ||||
| @@ -690,7 +690,7 @@ class Dense_Thor(Cell): | |||||
| self.exp = P.Exp() | self.exp = P.Exp() | ||||
| self.dampingA = Tensor(np.identity(2048), mstype.float32) | self.dampingA = Tensor(np.identity(2048), mstype.float32) | ||||
| self.dampingG = Tensor(np.identity(1024), mstype.float32) | self.dampingG = Tensor(np.identity(1024), mstype.float32) | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| self.sqrt = P.Sqrt() | self.sqrt = P.Sqrt() | ||||
| self.getG = P.InsertGradientOf(self.save_gradient) | self.getG = P.InsertGradientOf(self.save_gradient) | ||||
| @@ -16,7 +16,7 @@ | |||||
| ResNet based ResNext | ResNet based ResNext | ||||
| """ | """ | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore.ops.operations import TensorAdd, Split, Concat | |||||
| from mindspore.ops.operations import Add, Split, Concat | |||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.common.initializer import TruncatedNormal | from mindspore.common.initializer import TruncatedNormal | ||||
| @@ -105,7 +105,7 @@ class BasicBlock(nn.Cell): | |||||
| self.down_sample = down_sample | self.down_sample = down_sample | ||||
| self.down_sample_flag = True | self.down_sample_flag = True | ||||
| self.add = TensorAdd() | |||||
| self.add = Add() | |||||
| def construct(self, x): | def construct(self, x): | ||||
| identity = x | identity = x | ||||
| @@ -176,7 +176,7 @@ class Bottleneck(nn.Cell): | |||||
| self.down_sample_flag = True | self.down_sample_flag = True | ||||
| self.cast = P.Cast() | self.cast = P.Cast() | ||||
| self.add = TensorAdd() | |||||
| self.add = Add() | |||||
| def construct(self, x): | def construct(self, x): | ||||
| identity = x | identity = x | ||||
| @@ -95,7 +95,7 @@ class ResidualBlock(nn.Cell): | |||||
| if self.down_sample: | if self.down_sample: | ||||
| self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride), | self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride), | ||||
| _bn(out_channel)]) | _bn(out_channel)]) | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| def construct(self, x): | def construct(self, x): | ||||
| identity = x | identity = x | ||||
| @@ -68,7 +68,7 @@ class ShuffleV1Block(nn.Cell): | |||||
| outputs = oup | outputs = oup | ||||
| self.relu = nn.ReLU() | self.relu = nn.ReLU() | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| self.concat = P.Concat(1) | self.concat = P.Concat(1) | ||||
| self.shape = P.Shape() | self.shape = P.Shape() | ||||
| self.transpose = P.Transpose() | self.transpose = P.Transpose() | ||||
| @@ -170,7 +170,7 @@ class SqueezeNet_Residual(nn.Cell): | |||||
| self.relu = nn.ReLU() | self.relu = nn.ReLU() | ||||
| self.max_pool2d = nn.MaxPool2d(kernel_size=3, stride=2) | self.max_pool2d = nn.MaxPool2d(kernel_size=3, stride=2) | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| self.dropout = nn.Dropout(keep_prob=0.5) | self.dropout = nn.Dropout(keep_prob=0.5) | ||||
| self.mean = P.ReduceMean(keep_dims=True) | self.mean = P.ReduceMean(keep_dims=True) | ||||
| self.flatten = nn.Flatten() | self.flatten = nn.Flatten() | ||||
| @@ -133,7 +133,7 @@ class InvertedResidual(nn.Cell): | |||||
| _bn(oup), | _bn(oup), | ||||
| ]) | ]) | ||||
| self.conv = nn.SequentialCell(layers) | self.conv = nn.SequentialCell(layers) | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| self.cast = P.Cast() | self.cast = P.Cast() | ||||
| self.last_relu = last_relu | self.last_relu = last_relu | ||||
| self.relu = nn.ReLU6() | self.relu = nn.ReLU6() | ||||
| @@ -68,7 +68,7 @@ class Block(nn.Cell): | |||||
| if strides != 1: | if strides != 1: | ||||
| rep.append(nn.MaxPool2d(3, strides, pad_mode="same")) | rep.append(nn.MaxPool2d(3, strides, pad_mode="same")) | ||||
| self.rep = nn.SequentialCell(*rep) | self.rep = nn.SequentialCell(*rep) | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| def construct(self, inp): | def construct(self, inp): | ||||
| x = self.rep(inp) | x = self.rep(inp) | ||||
| @@ -62,7 +62,7 @@ class ResidualBlock(nn.Cell): | |||||
| out_chls = out_channels//2 | out_chls = out_channels//2 | ||||
| self.conv1 = conv_block(in_channels, out_chls, kernel_size=1, stride=1) | self.conv1 = conv_block(in_channels, out_chls, kernel_size=1, stride=1) | ||||
| self.conv2 = conv_block(out_chls, out_channels, kernel_size=3, stride=1) | self.conv2 = conv_block(out_chls, out_channels, kernel_size=3, stride=1) | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| def construct(self, x): | def construct(self, x): | ||||
| identity = x | identity = x | ||||
| @@ -59,7 +59,7 @@ class ResidualBlock(nn.Cell): | |||||
| out_chls = out_channels//2 | out_chls = out_channels//2 | ||||
| self.conv1 = conv_block(in_channels, out_chls, kernel_size=1, stride=1) | self.conv1 = conv_block(in_channels, out_chls, kernel_size=1, stride=1) | ||||
| self.conv2 = conv_block(out_chls, out_channels, kernel_size=3, stride=1) | self.conv2 = conv_block(out_chls, out_channels, kernel_size=3, stride=1) | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| def construct(self, x): | def construct(self, x): | ||||
| identity = x | identity = x | ||||
| @@ -107,7 +107,7 @@ class BasicBlock(nn.Cell): | |||||
| self.downsample = (in_channels != out_channels) | self.downsample = (in_channels != out_channels) | ||||
| if self.downsample: | if self.downsample: | ||||
| self.down_sample_layer = _conv2d(in_channels, out_channels, 1, stride=stride) | self.down_sample_layer = _conv2d(in_channels, out_channels, 1, stride=stride) | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| def construct(self, x): | def construct(self, x): | ||||
| identity = x | identity = x | ||||
| @@ -76,7 +76,7 @@ class ResidualBlock(nn.Cell): | |||||
| out_chls = out_channels | out_chls = out_channels | ||||
| self.conv1 = conv_block(in_channels, out_chls, kernel_size=1, stride=1) | self.conv1 = conv_block(in_channels, out_chls, kernel_size=1, stride=1) | ||||
| self.conv2 = conv_block(out_chls, out_channels, kernel_size=3, stride=1) | self.conv2 = conv_block(out_chls, out_channels, kernel_size=3, stride=1) | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| def construct(self, x): | def construct(self, x): | ||||
| identity = x | identity = x | ||||
| @@ -111,7 +111,7 @@ class CspDarkNet53(nn.Cell): | |||||
| self.outchannel = 1024 | self.outchannel = 1024 | ||||
| self.detect = detect | self.detect = detect | ||||
| self.concat = P.Concat(axis=1) | self.concat = P.Concat(axis=1) | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| self.conv0 = conv_block(3, 32, kernel_size=3, stride=1) | self.conv0 = conv_block(3, 32, kernel_size=3, stride=1) | ||||
| self.conv1 = conv_block(32, 64, kernel_size=3, stride=2) | self.conv1 = conv_block(32, 64, kernel_size=3, stride=2) | ||||
| @@ -188,7 +188,7 @@ class EmbeddingPostprocessor(nn.Cell): | |||||
| use_one_hot=False) | use_one_hot=False) | ||||
| self.layernorm = nn.LayerNorm((embedding_size,)) | self.layernorm = nn.LayerNorm((embedding_size,)) | ||||
| self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32)) | self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32)) | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| def construct(self, token_type_ids, word_embeddings): | def construct(self, token_type_ids, word_embeddings): | ||||
| """Postprocessors apply positional and token type embeddings to word embeddings.""" | """Postprocessors apply positional and token type embeddings to word embeddings.""" | ||||
| @@ -226,7 +226,7 @@ class BertOutput(nn.Cell): | |||||
| weight_init=TruncatedNormal(initializer_range)).to_float(compute_type) | weight_init=TruncatedNormal(initializer_range)).to_float(compute_type) | ||||
| self.dropout = nn.Dropout(1 - dropout_prob) | self.dropout = nn.Dropout(1 - dropout_prob) | ||||
| self.dropout_prob = dropout_prob | self.dropout_prob = dropout_prob | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type) | self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type) | ||||
| self.cast = P.Cast() | self.cast = P.Cast() | ||||
| @@ -444,7 +444,7 @@ class BertAttention(nn.Cell): | |||||
| if self.has_attention_mask: | if self.has_attention_mask: | ||||
| self.expand_dims = P.ExpandDims() | self.expand_dims = P.ExpandDims() | ||||
| self.sub = P.Sub() | self.sub = P.Sub() | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| self.cast = P.Cast() | self.cast = P.Cast() | ||||
| self.get_dtype = P.DType() | self.get_dtype = P.DType() | ||||
| if do_return_2d_tensor: | if do_return_2d_tensor: | ||||
| @@ -227,7 +227,7 @@ class EmbeddingPostprocessor(nn.Cell): | |||||
| frequency=frequency) | frequency=frequency) | ||||
| self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32)) | self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32)) | ||||
| self.layernorm = nn.LayerNorm((embedding_size,)) | self.layernorm = nn.LayerNorm((embedding_size,)) | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| def construct(self, token_type_ids, word_embeddings): | def construct(self, token_type_ids, word_embeddings): | ||||
| """construct of EmbeddingPostprocessor""" | """construct of EmbeddingPostprocessor""" | ||||
| @@ -275,7 +275,7 @@ class BertOutput(nn.Cell): | |||||
| batch_size=batch_size).to_float(compute_type) | batch_size=batch_size).to_float(compute_type) | ||||
| self.dropout = nn.Dropout(1 - dropout_prob) | self.dropout = nn.Dropout(1 - dropout_prob) | ||||
| self.dropout_prob = dropout_prob | self.dropout_prob = dropout_prob | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type) | self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type) | ||||
| self.cast = P.Cast() | self.cast = P.Cast() | ||||
| @@ -522,7 +522,7 @@ class BertAttention(nn.Cell): | |||||
| if self.has_attention_mask: | if self.has_attention_mask: | ||||
| self.expand_dims = P.ExpandDims() | self.expand_dims = P.ExpandDims() | ||||
| self.sub = P.Sub() | self.sub = P.Sub() | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| self.cast = P.Cast() | self.cast = P.Cast() | ||||
| self.get_dtype = P.DType() | self.get_dtype = P.DType() | ||||
| if do_return_2d_tensor: | if do_return_2d_tensor: | ||||
| @@ -35,7 +35,7 @@ class LengthPenalty(nn.Cell): | |||||
| def __init__(self, weight=1.0, compute_type=mstype.float32): | def __init__(self, weight=1.0, compute_type=mstype.float32): | ||||
| super(LengthPenalty, self).__init__() | super(LengthPenalty, self).__init__() | ||||
| self.weight = weight | self.weight = weight | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| self.pow = P.Pow() | self.pow = P.Pow() | ||||
| self.div = P.RealDiv() | self.div = P.RealDiv() | ||||
| self.five = Tensor(5.0, mstype.float32) | self.five = Tensor(5.0, mstype.float32) | ||||
| @@ -188,7 +188,7 @@ class BeamSearchDecoder(nn.Cell): | |||||
| self.decoder = decoder | self.decoder = decoder | ||||
| self.is_using_while = is_using_while | self.is_using_while = is_using_while | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| self.expand = P.ExpandDims() | self.expand = P.ExpandDims() | ||||
| self.reshape = P.Reshape() | self.reshape = P.Reshape() | ||||
| self.shape_flat = (-1,) | self.shape_flat = (-1,) | ||||
| @@ -36,7 +36,7 @@ class LengthPenalty(nn.Cell): | |||||
| super(LengthPenalty, self).__init__() | super(LengthPenalty, self).__init__() | ||||
| self.weight = weight | self.weight = weight | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| self.pow = P.Pow() | self.pow = P.Pow() | ||||
| self.div = P.RealDiv() | self.div = P.RealDiv() | ||||
| @@ -178,7 +178,7 @@ class BeamSearchDecoder(nn.Cell): | |||||
| self.decoder = decoder | self.decoder = decoder | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| self.expand = P.ExpandDims() | self.expand = P.ExpandDims() | ||||
| self.reshape = P.Reshape() | self.reshape = P.Reshape() | ||||
| self.shape_flat = (-1,) | self.shape_flat = (-1,) | ||||
| @@ -138,7 +138,7 @@ class MultiHeadAttention(nn.Cell): | |||||
| if self.has_attention_mask: | if self.has_attention_mask: | ||||
| self.expand_dims = P.ExpandDims() | self.expand_dims = P.ExpandDims() | ||||
| self.sub = P.Sub() | self.sub = P.Sub() | ||||
| self.add = P.TensorAdd() | |||||
| self.add = P.Add() | |||||
| self.cast = P.Cast() | self.cast = P.Cast() | ||||
| self.get_dtype = P.DType() | self.get_dtype = P.DType() | ||||