|
|
|
@@ -24,6 +24,7 @@ make_tuple = Primitive('make_tuple') |
|
|
|
tuple_getitem = Primitive('tuple_getitem') |
|
|
|
depend = Primitive('depend') |
|
|
|
BatchNorm = P.BatchNorm() |
|
|
|
Cast = P.Cast() |
|
|
|
BNTrainingReduce = Primitive('BNTrainingReduce') |
|
|
|
BNTrainingUpdate = Primitive('BNTrainingUpdate') |
|
|
|
constant0 = Tensor(0.1, mstype.float32) |
|
|
|
@@ -59,6 +60,21 @@ def test_fused_batch_norm_fusion(tag): |
|
|
|
output = tuple_getitem(outputs, 0) |
|
|
|
return output |
|
|
|
|
|
|
|
@fns |
|
|
|
def before_mix_precision(input0, input1, input2, input3, input4, var0, var1): |
|
|
|
batch_norm = BatchNorm(input0, input1, input2, input3, input4) |
|
|
|
sub0 = Sub(Cast(var0, mstype.float32), tuple_getitem(batch_norm, 1)) |
|
|
|
sub1 = Sub(Cast(var1, mstype.float32), tuple_getitem(batch_norm, 2)) |
|
|
|
mul0 = Mul(sub0, constant0) |
|
|
|
mul1 = Mul(sub1, constant1) |
|
|
|
assign_sub0 = AssignSub(var0, Cast(mul0, mstype.float32)) |
|
|
|
assign_sub1 = AssignSub(var1, Cast(mul1, mstype.float32)) |
|
|
|
depend0 = depend(tuple_getitem(batch_norm, 0), assign_sub0) |
|
|
|
depend1 = depend(depend0, assign_sub1) |
|
|
|
outputs = make_tuple(depend1, tuple_getitem(batch_norm, 3), tuple_getitem(batch_norm, 4)) |
|
|
|
output = tuple_getitem(outputs, 0) |
|
|
|
return output |
|
|
|
|
|
|
|
@fns |
|
|
|
def after(input0, input1, input2, input3, input4, var0, var1): |
|
|
|
bn_training_reduce = BNTrainingReduce(input0) |
|
|
|
|