Browse Source

bugfix in expanders of layernorm

pull/14886/head
wenfangpei 4 years ago
parent
commit
b9715db358
1 changed files with 4 additions and 3 deletions
  1. +4
    -3
      mindspore/_extends/graph_kernel/expanders/layernorm.py

+ 4
- 3
mindspore/_extends/graph_kernel/expanders/layernorm.py View File

@@ -115,10 +115,11 @@ class LayerNorm(Expander):
reduce_elts = 1.0
for i in reduce_axis:
reduce_elts *= ori_shape_x[i]
# after reduced
ori_reduced_shape_x = self.get_reduced_ori_shape(ori_shape_x, reduce_axis)

if input_x.data_format == DF.FRAC_NZ:
reduce_axis = self.to_frac_z_axis(ori_shape_x, reduce_axis)
ori_shape_x = self.get_reduced_ori_shape(ori_shape_x, reduce_axis) # after reduced

mean_cof = 1.0 / reduce_elts
mean_cof_v = graph_builder.value(input_x.dtype, mean_cof)
@@ -128,7 +129,7 @@ class LayerNorm(Expander):
attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
mean = graph_builder.emit('Mul', [mean_red, mean_cof_v])
if input_x.data_format == DF.FRAC_NZ:
mean = graph_builder.emit('Reshape', [mean], attrs={'shape': ori_shape_x})
mean = graph_builder.emit('Reshape', [mean], attrs={'shape': ori_reduced_shape_x})

# Calculate variance
variance_sub = graph_builder.emit('Sub', [input_x, mean])
@@ -137,7 +138,7 @@ class LayerNorm(Expander):
attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
variance = graph_builder.emit('Mul', [variance_red, mean_cof_v])
if input_x.data_format == DF.FRAC_NZ:
variance = graph_builder.emit('Reshape', [variance], attrs={'shape': ori_shape_x})
variance = graph_builder.emit('Reshape', [variance], attrs={'shape': ori_reduced_shape_x})

# Calculate normalize
normalize_sub = graph_builder.emit('Sub', [input_x, mean])


Loading…
Cancel
Save