| @@ -115,10 +115,11 @@ class LayerNorm(Expander): | |||||
| reduce_elts = 1.0 | reduce_elts = 1.0 | ||||
| for i in reduce_axis: | for i in reduce_axis: | ||||
| reduce_elts *= ori_shape_x[i] | reduce_elts *= ori_shape_x[i] | ||||
| # after reduced | |||||
| ori_reduced_shape_x = self.get_reduced_ori_shape(ori_shape_x, reduce_axis) | |||||
| if input_x.data_format == DF.FRAC_NZ: | if input_x.data_format == DF.FRAC_NZ: | ||||
| reduce_axis = self.to_frac_z_axis(ori_shape_x, reduce_axis) | reduce_axis = self.to_frac_z_axis(ori_shape_x, reduce_axis) | ||||
| ori_shape_x = self.get_reduced_ori_shape(ori_shape_x, reduce_axis) # after reduced | |||||
| mean_cof = 1.0 / reduce_elts | mean_cof = 1.0 / reduce_elts | ||||
| mean_cof_v = graph_builder.value(input_x.dtype, mean_cof) | mean_cof_v = graph_builder.value(input_x.dtype, mean_cof) | ||||
| @@ -128,7 +129,7 @@ class LayerNorm(Expander): | |||||
| attrs={'reduce_axis': reduce_axis, 'keep_dims': True}) | attrs={'reduce_axis': reduce_axis, 'keep_dims': True}) | ||||
| mean = graph_builder.emit('Mul', [mean_red, mean_cof_v]) | mean = graph_builder.emit('Mul', [mean_red, mean_cof_v]) | ||||
| if input_x.data_format == DF.FRAC_NZ: | if input_x.data_format == DF.FRAC_NZ: | ||||
| mean = graph_builder.emit('Reshape', [mean], attrs={'shape': ori_shape_x}) | |||||
| mean = graph_builder.emit('Reshape', [mean], attrs={'shape': ori_reduced_shape_x}) | |||||
| # Calculate variance | # Calculate variance | ||||
| variance_sub = graph_builder.emit('Sub', [input_x, mean]) | variance_sub = graph_builder.emit('Sub', [input_x, mean]) | ||||
| @@ -137,7 +138,7 @@ class LayerNorm(Expander): | |||||
| attrs={'reduce_axis': reduce_axis, 'keep_dims': True}) | attrs={'reduce_axis': reduce_axis, 'keep_dims': True}) | ||||
| variance = graph_builder.emit('Mul', [variance_red, mean_cof_v]) | variance = graph_builder.emit('Mul', [variance_red, mean_cof_v]) | ||||
| if input_x.data_format == DF.FRAC_NZ: | if input_x.data_format == DF.FRAC_NZ: | ||||
| variance = graph_builder.emit('Reshape', [variance], attrs={'shape': ori_shape_x}) | |||||
| variance = graph_builder.emit('Reshape', [variance], attrs={'shape': ori_reduced_shape_x}) | |||||
| # Calculate normalize | # Calculate normalize | ||||
| normalize_sub = graph_builder.emit('Sub', [input_x, mean]) | normalize_sub = graph_builder.emit('Sub', [input_x, mean]) | ||||