diff --git a/mindspore/_extends/graph_kernel/expanders/layernorm.py b/mindspore/_extends/graph_kernel/expanders/layernorm.py index 2e7ff18500..72e8b4dd54 100644 --- a/mindspore/_extends/graph_kernel/expanders/layernorm.py +++ b/mindspore/_extends/graph_kernel/expanders/layernorm.py @@ -115,10 +115,11 @@ class LayerNorm(Expander): reduce_elts = 1.0 for i in reduce_axis: reduce_elts *= ori_shape_x[i] + # after reduced + ori_reduced_shape_x = self.get_reduced_ori_shape(ori_shape_x, reduce_axis) if input_x.data_format == DF.FRAC_NZ: reduce_axis = self.to_frac_z_axis(ori_shape_x, reduce_axis) - ori_shape_x = self.get_reduced_ori_shape(ori_shape_x, reduce_axis) # after reduced mean_cof = 1.0 / reduce_elts mean_cof_v = graph_builder.value(input_x.dtype, mean_cof) @@ -128,7 +129,7 @@ class LayerNorm(Expander): attrs={'reduce_axis': reduce_axis, 'keep_dims': True}) mean = graph_builder.emit('Mul', [mean_red, mean_cof_v]) if input_x.data_format == DF.FRAC_NZ: - mean = graph_builder.emit('Reshape', [mean], attrs={'shape': ori_shape_x}) + mean = graph_builder.emit('Reshape', [mean], attrs={'shape': ori_reduced_shape_x}) # Calculate variance variance_sub = graph_builder.emit('Sub', [input_x, mean]) @@ -137,7 +138,7 @@ class LayerNorm(Expander): attrs={'reduce_axis': reduce_axis, 'keep_dims': True}) variance = graph_builder.emit('Mul', [variance_red, mean_cof_v]) if input_x.data_format == DF.FRAC_NZ: - variance = graph_builder.emit('Reshape', [variance], attrs={'shape': ori_shape_x}) + variance = graph_builder.emit('Reshape', [variance], attrs={'shape': ori_reduced_shape_x}) # Calculate normalize normalize_sub = graph_builder.emit('Sub', [input_x, mean])