|
|
|
@@ -115,10 +115,11 @@ class LayerNorm(Expander): |
|
|
|
reduce_elts = 1.0 |
|
|
|
for i in reduce_axis: |
|
|
|
reduce_elts *= ori_shape_x[i] |
|
|
|
# after reduced |
|
|
|
ori_reduced_shape_x = self.get_reduced_ori_shape(ori_shape_x, reduce_axis) |
|
|
|
|
|
|
|
if input_x.data_format == DF.FRAC_NZ: |
|
|
|
reduce_axis = self.to_frac_z_axis(ori_shape_x, reduce_axis) |
|
|
|
ori_shape_x = self.get_reduced_ori_shape(ori_shape_x, reduce_axis) # after reduced |
|
|
|
|
|
|
|
mean_cof = 1.0 / reduce_elts |
|
|
|
mean_cof_v = graph_builder.value(input_x.dtype, mean_cof) |
|
|
|
@@ -128,7 +129,7 @@ class LayerNorm(Expander): |
|
|
|
attrs={'reduce_axis': reduce_axis, 'keep_dims': True}) |
|
|
|
mean = graph_builder.emit('Mul', [mean_red, mean_cof_v]) |
|
|
|
if input_x.data_format == DF.FRAC_NZ: |
|
|
|
mean = graph_builder.emit('Reshape', [mean], attrs={'shape': ori_shape_x}) |
|
|
|
mean = graph_builder.emit('Reshape', [mean], attrs={'shape': ori_reduced_shape_x}) |
|
|
|
|
|
|
|
# Calculate variance |
|
|
|
variance_sub = graph_builder.emit('Sub', [input_x, mean]) |
|
|
|
@@ -137,7 +138,7 @@ class LayerNorm(Expander): |
|
|
|
attrs={'reduce_axis': reduce_axis, 'keep_dims': True}) |
|
|
|
variance = graph_builder.emit('Mul', [variance_red, mean_cof_v]) |
|
|
|
if input_x.data_format == DF.FRAC_NZ: |
|
|
|
variance = graph_builder.emit('Reshape', [variance], attrs={'shape': ori_shape_x}) |
|
|
|
variance = graph_builder.emit('Reshape', [variance], attrs={'shape': ori_reduced_shape_x}) |
|
|
|
|
|
|
|
# Calculate normalize |
|
|
|
normalize_sub = graph_builder.emit('Sub', [input_x, mean]) |
|
|
|
|