Browse Source

batchnorm expander supports when first input is float16

pull/15741/head
looop5 4 years ago
parent
commit
24f441ba33
1 changed files with 12 additions and 1 deletions
  1. +12
    -1
      mindspore/_extends/graph_kernel/expanders/batchnorm.py

+ 12
- 1
mindspore/_extends/graph_kernel/expanders/batchnorm.py View File

@@ -33,6 +33,14 @@ class BatchNorm(Expander):
input_variance = self.inputs[4]
epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'])

input_x_ori_type = input_x.dtype
input_x_new_type = input_x.dtype
if input_x.dtype == "float16" and input_scale.dtype == "float32" and input_offset.dtype == "float32" and \
input_mean.dtype == "float32" and input_variance.dtype == "float32":
input_x_new_type = "float32"
if input_x_new_type != input_x_ori_type:
input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': input_x_new_type})

if self.attrs['is_training']:
reduce_axis = ()
shape_x = input_x.shape
@@ -109,7 +117,8 @@ class BatchNorm(Expander):
variance_res = graph_builder.emit(
'InplaceAssign', [input_variance, updated_moving_variance, updated_moving_variance],
attrs={'fake_output': True})

if input_x_new_type != input_x_ori_type:
res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec
# infer mode
if input_x.data_format in (DF.DEFAULT, DF.NCHW):
@@ -128,4 +137,6 @@ class BatchNorm(Expander):
'Reshape', [var_add_sqrt], attrs={'shape': ExpandDims.infer_shape(var_add_sqrt.shape, [-1, -1])})
x_div = graph_builder.emit('RealDiv', [x_sub_mul, var_add_sqrt])
res_y = graph_builder.emit('Add', [input_offset, x_div])
if input_x_new_type != input_x_ori_type:
res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
return res_y, var_add, var_add, var_add, var_add

Loading…
Cancel
Save