|
|
|
@@ -33,6 +33,14 @@ class BatchNorm(Expander): |
|
|
|
input_variance = self.inputs[4] |
|
|
|
epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon']) |
|
|
|
|
|
|
|
input_x_ori_type = input_x.dtype |
|
|
|
input_x_new_type = input_x.dtype |
|
|
|
if input_x.dtype == "float16" and input_scale.dtype == "float32" and input_offset.dtype == "float32" and \ |
|
|
|
input_mean.dtype == "float32" and input_variance.dtype == "float32": |
|
|
|
input_x_new_type = "float32" |
|
|
|
if input_x_new_type != input_x_ori_type: |
|
|
|
input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': input_x_new_type}) |
|
|
|
|
|
|
|
if self.attrs['is_training']: |
|
|
|
reduce_axis = () |
|
|
|
shape_x = input_x.shape |
|
|
|
@@ -109,7 +117,8 @@ class BatchNorm(Expander): |
|
|
|
variance_res = graph_builder.emit( |
|
|
|
'InplaceAssign', [input_variance, updated_moving_variance, updated_moving_variance], |
|
|
|
attrs={'fake_output': True}) |
|
|
|
|
|
|
|
if input_x_new_type != input_x_ori_type: |
|
|
|
res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type}) |
|
|
|
return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec |
|
|
|
# infer mode |
|
|
|
if input_x.data_format in (DF.DEFAULT, DF.NCHW): |
|
|
|
@@ -128,4 +137,6 @@ class BatchNorm(Expander): |
|
|
|
'Reshape', [var_add_sqrt], attrs={'shape': ExpandDims.infer_shape(var_add_sqrt.shape, [-1, -1])}) |
|
|
|
x_div = graph_builder.emit('RealDiv', [x_sub_mul, var_add_sqrt]) |
|
|
|
res_y = graph_builder.emit('Add', [input_offset, x_div]) |
|
|
|
if input_x_new_type != input_x_ori_type: |
|
|
|
res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type}) |
|
|
|
return res_y, var_add, var_add, var_add, var_add |