| @@ -11,4 +11,4 @@ | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | |||||
| # ============================================================================ | |||||
| @@ -13,7 +13,6 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops.operations import _grad_ops as G | from mindspore.ops.operations import _grad_ops as G | ||||
| from mindspore.ops import Primitive | from mindspore.ops import Primitive | ||||
| @@ -13,7 +13,6 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops.operations import _grad_ops as G | from mindspore.ops.operations import _grad_ops as G | ||||
| from mindspore.ops import Primitive | from mindspore.ops import Primitive | ||||
| @@ -61,8 +61,8 @@ def test_tbe_eltwise_fusion_1(tag): | |||||
| def after(x): | def after(x): | ||||
| fusion = Fusion_relu_relu(x) | fusion = Fusion_relu_relu(x) | ||||
| res = Cast(fusion) | res = Cast(fusion) | ||||
| tuple = make_tuple(res) | |||||
| return tuple | |||||
| output = make_tuple(res) | |||||
| return output | |||||
| return fns[tag] | return fns[tag] | ||||
| @@ -86,8 +86,8 @@ def test_tbe_eltwise_fusion_2(tag): | |||||
| def after(x, y): | def after(x, y): | ||||
| fusion = Fusion_biasadd(x, y) | fusion = Fusion_biasadd(x, y) | ||||
| res = Cast(fusion) | res = Cast(fusion) | ||||
| tuple = make_tuple(res) | |||||
| return tuple | |||||
| output = make_tuple(res) | |||||
| return output | |||||
| return fns[tag] | return fns[tag] | ||||
| @@ -111,8 +111,8 @@ def test_tbe_reduce_eltwise_fusion(tag): | |||||
| def after(x): | def after(x): | ||||
| fusion = Fusion_biasaddgrad(x) | fusion = Fusion_biasaddgrad(x) | ||||
| res = Cast(fusion) | res = Cast(fusion) | ||||
| tuple = make_tuple(res) | |||||
| return tuple | |||||
| output = make_tuple(res) | |||||
| return output | |||||
| return fns[tag] | return fns[tag] | ||||
| @@ -131,8 +131,8 @@ def test_conv_singlein_fusion(tag): | |||||
| def after(x, y): | def after(x, y): | ||||
| fusion = Fusion(x, y) | fusion = Fusion(x, y) | ||||
| res = Cast(fusion) | res = Cast(fusion) | ||||
| tuple = make_tuple(res) | |||||
| return tuple | |||||
| output = make_tuple(res) | |||||
| return output | |||||
| return fns[tag] | return fns[tag] | ||||
| @@ -151,7 +151,7 @@ def test_tbe_matmul_eltwise_fusion(tag): | |||||
| def after(x, y): | def after(x, y): | ||||
| fusion = Fusion_matmul_relu(x, y) | fusion = Fusion_matmul_relu(x, y) | ||||
| res = Cast(fusion) | res = Cast(fusion) | ||||
| tuple = make_tuple(res) | |||||
| return tuple | |||||
| output = make_tuple(res) | |||||
| return output | |||||
| return fns[tag] | return fns[tag] | ||||
| @@ -40,17 +40,17 @@ def test_clip_by_norm_no_div_square_sum_fusion(tag): | |||||
| fns = FnDict() | fns = FnDict() | ||||
| @fns | @fns | ||||
| def before(input, constant_select, constant_greater, constant_maximum): | |||||
| greater_output = greater(input, constant_greater) | |||||
| res = select(greater_output, input, constant_select) | |||||
| def before(x, constant_select, constant_greater, constant_maximum): | |||||
| greater_output = greater(x, constant_greater) | |||||
| res = select(greater_output, x, constant_select) | |||||
| res = sqrt(res) | res = sqrt(res) | ||||
| res = select(greater_output, res, input) | |||||
| res = select(greater_output, res, x) | |||||
| res = maximum(res, constant_maximum) | res = maximum(res, constant_maximum) | ||||
| return res | return res | ||||
| @fns | @fns | ||||
| def after(input, constant_select, constant_greater, constant_maximum): | |||||
| res = clip_by_norm_no_div_square_sum(input, constant_select, constant_greater, constant_maximum) | |||||
| def after(x, constant_select, constant_greater, constant_maximum): | |||||
| res = clip_by_norm_no_div_square_sum(x, constant_select, constant_greater, constant_maximum) | |||||
| return make_tuple(res) | return make_tuple(res) | ||||
| return fns[tag] | return fns[tag] | ||||
| @@ -38,6 +38,7 @@ depth = Tensor(2, mstype.int32) | |||||
| shape = (2, 4, 2, 2) | shape = (2, 4, 2, 2) | ||||
| dropout_gen_mask = P.DropoutGenMask() | dropout_gen_mask = P.DropoutGenMask() | ||||
| class FnDict: | class FnDict: | ||||
| def __init__(self): | def __init__(self): | ||||
| self.fnDict = {} | self.fnDict = {} | ||||
| @@ -114,7 +115,7 @@ def test_convert_strided_slice_grad_input_to_attr(tag): | |||||
| @fns | @fns | ||||
| def before(x): | def before(x): | ||||
| return stridedslicegrad(x, (16, 128, 1024), (0, 0 , 0), (16, 1, 1024), (1, 1,1)) | |||||
| return stridedslicegrad(x, (16, 128, 1024), (0, 0, 0), (16, 1, 1024), (1, 1, 1)) | |||||
| @fns | @fns | ||||
| def after(x): | def after(x): | ||||
| @@ -12,12 +12,10 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| import numpy as np | |||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops import Primitive | from mindspore.ops import Primitive | ||||
| import mindspore as ms | |||||
| import mindspore.common.dtype as mstype | |||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| import numpy as np | |||||
| make_tuple = Primitive('make_tuple') | make_tuple = Primitive('make_tuple') | ||||
| concat = P.Concat() | concat = P.Concat() | ||||
| @@ -51,7 +49,7 @@ def test_convert_tuple_input_to_dynamic_input(tag): | |||||
| def after(x): | def after(x): | ||||
| res = concat(t1, t2) | res = concat(t1, t2) | ||||
| res = add(x, res) | res = add(x, res) | ||||
| res = make_tuple(res); | |||||
| res = make_tuple(res) | |||||
| return res | return res | ||||
| return fns[tag] | return fns[tag] | ||||
| @@ -14,16 +14,13 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops import Primitive | from mindspore.ops import Primitive | ||||
| import mindspore as ms | |||||
| import mindspore.common.dtype as mstype | |||||
| from mindspore.common.tensor import Tensor | |||||
| import numpy as np | |||||
| make_tuple = Primitive('make_tuple') | make_tuple = Primitive('make_tuple') | ||||
| tuple_get_item = Primitive("tuple_getitem"); | |||||
| LSTM = P.LSTM(input_size=10,hidden_size=2,num_layers=1,has_bias=True,bidirectional=False,dropout=0.0) | |||||
| tuple_get_item = Primitive("tuple_getitem") | |||||
| LSTM = P.LSTM(input_size=10, hidden_size=2, num_layers=1, has_bias=True, bidirectional=False, dropout=0.0) | |||||
| add = P.TensorAdd() | add = P.TensorAdd() | ||||
| class FnDict: | class FnDict: | ||||
| def __init__(self): | def __init__(self): | ||||
| self.fnDict = {} | self.fnDict = {} | ||||
| @@ -48,7 +45,7 @@ def test_convert_tuple_output_to_maketuple(tag): | |||||
| res = LSTM(x, h, c, w) | res = LSTM(x, h, c, w) | ||||
| res = make_tuple( | res = make_tuple( | ||||
| make_tuple(tuple_get_item(res, 0), tuple_get_item(res, 1), tuple_get_item(res, 2), tuple_get_item(res, 3), | make_tuple(tuple_get_item(res, 0), tuple_get_item(res, 1), tuple_get_item(res, 2), tuple_get_item(res, 3), | ||||
| tuple_get_item(res, 4))); | |||||
| tuple_get_item(res, 4))) | |||||
| return res | return res | ||||
| return fns[tag] | return fns[tag] | ||||
| @@ -40,8 +40,8 @@ def test_eliminate_5to4_4to5(tag): | |||||
| @fns | @fns | ||||
| def before(x, y): | def before(x, y): | ||||
| sum = add(x, y) | |||||
| res = sub(sum, y) | |||||
| sum_add = add(x, y) | |||||
| res = sub(sum_add, y) | |||||
| output = make_tuple(res) | output = make_tuple(res) | ||||
| return output | return output | ||||
| @@ -50,8 +50,8 @@ def test_eliminate_5to4_4to5(tag): | |||||
| new_x_sum = transdata(x) | new_x_sum = transdata(x) | ||||
| new_y_sum = transdata(y) | new_y_sum = transdata(y) | ||||
| new_y_sum2 = transdata(y) | new_y_sum2 = transdata(y) | ||||
| sum = add(new_x_sum, new_y_sum) | |||||
| sum_5to4 = transdata(sum) | |||||
| sum_add = add(new_x_sum, new_y_sum) | |||||
| sum_5to4 = transdata(sum_add) | |||||
| sum_4to5 = transdata(sum_5to4) | sum_4to5 = transdata(sum_5to4) | ||||
| res = sub(sum_4to5, new_y_sum2) | res = sub(sum_4to5, new_y_sum2) | ||||
| output = transdata(res) | output = transdata(res) | ||||
| @@ -64,8 +64,8 @@ def test_eliminate_5to4_4to5(tag): | |||||
| new_x_sum = transdata(x) | new_x_sum = transdata(x) | ||||
| new_y_sum = transdata(y) | new_y_sum = transdata(y) | ||||
| new_y_diff = transdata(y) | new_y_diff = transdata(y) | ||||
| sum = add(new_x_sum, new_y_sum) | |||||
| res = sub(sum, new_y_diff) | |||||
| sum_add = add(new_x_sum, new_y_sum) | |||||
| res = sub(sum_add, new_y_diff) | |||||
| output = transdata(res) | output = transdata(res) | ||||
| new_output = make_tuple(output) | new_output = make_tuple(output) | ||||
| ret = make_tuple(new_output) | ret = make_tuple(new_output) | ||||
| @@ -79,8 +79,8 @@ def test_eliminate_cast(tag): | |||||
| @fns | @fns | ||||
| def before(x, y): | def before(x, y): | ||||
| sum = add(x, y) | |||||
| res = sub(sum, y) | |||||
| sum_add = add(x, y) | |||||
| res = sub(sum_add, y) | |||||
| output = make_tuple(res) | output = make_tuple(res) | ||||
| return output | return output | ||||
| @@ -89,8 +89,8 @@ def test_eliminate_cast(tag): | |||||
| new_x_sum = cast(x) | new_x_sum = cast(x) | ||||
| new_y_sum = cast(y) | new_y_sum = cast(y) | ||||
| new_y_sum2 = cast(y) | new_y_sum2 = cast(y) | ||||
| sum = add(new_x_sum, new_y_sum) | |||||
| sum_cast1 = cast(sum) | |||||
| sum_add = add(new_x_sum, new_y_sum) | |||||
| sum_cast1 = cast(sum_add) | |||||
| sum_cast2 = cast(sum_cast1) | sum_cast2 = cast(sum_cast1) | ||||
| res = sub(sum_cast2, new_y_sum2) | res = sub(sum_cast2, new_y_sum2) | ||||
| output = cast(res) | output = cast(res) | ||||
| @@ -103,8 +103,8 @@ def test_eliminate_cast(tag): | |||||
| new_x_sum = cast(x) | new_x_sum = cast(x) | ||||
| new_y_sum = cast(y) | new_y_sum = cast(y) | ||||
| new_y_diff = cast(y) | new_y_diff = cast(y) | ||||
| sum = add(new_x_sum, new_y_sum) | |||||
| res = sub(sum, new_y_diff) | |||||
| sum_add = add(new_x_sum, new_y_sum) | |||||
| res = sub(sum_add, new_y_diff) | |||||
| output = cast(res) | output = cast(res) | ||||
| new_output = make_tuple(output) | new_output = make_tuple(output) | ||||
| ret = make_tuple(new_output) | ret = make_tuple(new_output) | ||||
| @@ -118,8 +118,8 @@ def test_eliminate_5to4_depend_4to5(tag): | |||||
| @fns | @fns | ||||
| def before(x, y): | def before(x, y): | ||||
| sum = add(x, y) | |||||
| sum_depend = depend(sum, x) | |||||
| sum_add = add(x, y) | |||||
| sum_depend = depend(sum_add, x) | |||||
| res = sub(sum_depend, y) | res = sub(sum_depend, y) | ||||
| output = make_tuple(res) | output = make_tuple(res) | ||||
| return output | return output | ||||
| @@ -128,8 +128,8 @@ def test_eliminate_5to4_depend_4to5(tag): | |||||
| def after1(x, y): | def after1(x, y): | ||||
| new_x_sum = transdata(x) | new_x_sum = transdata(x) | ||||
| new_y_sum = transdata(y) | new_y_sum = transdata(y) | ||||
| sum = add(new_x_sum, new_y_sum) | |||||
| sum_trans = transdata(sum) | |||||
| sum_add = add(new_x_sum, new_y_sum) | |||||
| sum_trans = transdata(sum_add) | |||||
| depend_between_trans = depend(sum_trans, x) | depend_between_trans = depend(sum_trans, x) | ||||
| depend_trans = transdata(depend_between_trans) | depend_trans = transdata(depend_between_trans) | ||||
| new_y_diff = transdata(y) | new_y_diff = transdata(y) | ||||
| @@ -143,8 +143,8 @@ def test_eliminate_5to4_depend_4to5(tag): | |||||
| def after2(x, y): | def after2(x, y): | ||||
| new_x_sum = transdata(x) | new_x_sum = transdata(x) | ||||
| new_y_sum = transdata(y) | new_y_sum = transdata(y) | ||||
| sum = add(new_x_sum, new_y_sum) | |||||
| depend_op = depend(sum, x) | |||||
| sum_add = add(new_x_sum, new_y_sum) | |||||
| depend_op = depend(sum_add, x) | |||||
| new_y_diff = transdata(y) | new_y_diff = transdata(y) | ||||
| res = sub(depend_op, new_y_diff) | res = sub(depend_op, new_y_diff) | ||||
| output = transdata(res) | output = transdata(res) | ||||
| @@ -160,8 +160,8 @@ def test_eliminate_cast_depend_cast(tag): | |||||
| @fns | @fns | ||||
| def before(x, y): | def before(x, y): | ||||
| sum = add(x, y) | |||||
| sum_depend = depend(sum, x) | |||||
| sum_add = add(x, y) | |||||
| sum_depend = depend(sum_add, x) | |||||
| sum_depend2 = depend(sum_depend, x) | sum_depend2 = depend(sum_depend, x) | ||||
| sum_depend3 = depend(sum_depend2, x) | sum_depend3 = depend(sum_depend2, x) | ||||
| res = sub(sum_depend3, y) | res = sub(sum_depend3, y) | ||||
| @@ -172,8 +172,8 @@ def test_eliminate_cast_depend_cast(tag): | |||||
| def after1(x, y): | def after1(x, y): | ||||
| new_x_sum = cast(x) | new_x_sum = cast(x) | ||||
| new_y_sum = cast(y) | new_y_sum = cast(y) | ||||
| sum = add(new_x_sum, new_y_sum) | |||||
| sum_cast = cast(sum) | |||||
| sum_add = add(new_x_sum, new_y_sum) | |||||
| sum_cast = cast(sum_add) | |||||
| depend_between_cast = depend(sum_cast, x) | depend_between_cast = depend(sum_cast, x) | ||||
| depend_between_cast2 = depend(depend_between_cast, x) | depend_between_cast2 = depend(depend_between_cast, x) | ||||
| depend_between_cast3 = depend(depend_between_cast2, x) | depend_between_cast3 = depend(depend_between_cast2, x) | ||||
| @@ -189,8 +189,8 @@ def test_eliminate_cast_depend_cast(tag): | |||||
| def after2(x, y): | def after2(x, y): | ||||
| new_x_sum = cast(x) | new_x_sum = cast(x) | ||||
| new_y_sum = cast(y) | new_y_sum = cast(y) | ||||
| sum = add(new_x_sum, new_y_sum) | |||||
| depend_op = depend(sum, x) | |||||
| sum_add = add(new_x_sum, new_y_sum) | |||||
| depend_op = depend(sum_add, x) | |||||
| depend_op2 = depend(depend_op, x) | depend_op2 = depend(depend_op, x) | ||||
| depend_op3 = depend(depend_op2, x) | depend_op3 = depend(depend_op2, x) | ||||
| new_y_diff = cast(y) | new_y_diff = cast(y) | ||||
| @@ -201,4 +201,3 @@ def test_eliminate_cast_depend_cast(tag): | |||||
| return ret | return ret | ||||
| return fns[tag] | return fns[tag] | ||||
| @@ -40,14 +40,14 @@ def test_getnext_memcpy_elimination(tag): | |||||
| fns = FnDict() | fns = FnDict() | ||||
| @fns | @fns | ||||
| def before(x): | |||||
| def before(): | |||||
| res = get_next() | res = get_next() | ||||
| res = memcpy_async_attr(res) | res = memcpy_async_attr(res) | ||||
| res = cast(res) | res = cast(res) | ||||
| return res | return res | ||||
| @fns | @fns | ||||
| def after(x): | |||||
| def after(): | |||||
| res = get_next() | res = get_next() | ||||
| res = cast(res) | res = cast(res) | ||||
| return res | return res | ||||
| @@ -59,14 +59,14 @@ def test_getnext_memcpy_elimination_no_attr(tag): | |||||
| fns = FnDict() | fns = FnDict() | ||||
| @fns | @fns | ||||
| def before(x): | |||||
| def before(): | |||||
| res = get_next() | res = get_next() | ||||
| res = memcpy_async(res) | res = memcpy_async(res) | ||||
| res = cast(res) | res = cast(res) | ||||
| return res | return res | ||||
| @fns | @fns | ||||
| def after(x): | |||||
| def after(): | |||||
| res = get_next() | res = get_next() | ||||
| res = memcpy_async(res) | res = memcpy_async(res) | ||||
| res = cast(res) | res = cast(res) | ||||
| @@ -79,7 +79,7 @@ def test_getnext_memcpy_elimination_memcpy_multi_users(tag): | |||||
| fns = FnDict() | fns = FnDict() | ||||
| @fns | @fns | ||||
| def before(x): | |||||
| def before(): | |||||
| res = get_next() | res = get_next() | ||||
| memcpy_out = memcpy_async_attr(res) | memcpy_out = memcpy_async_attr(res) | ||||
| res = cast(memcpy_out) | res = cast(memcpy_out) | ||||
| @@ -87,7 +87,7 @@ def test_getnext_memcpy_elimination_memcpy_multi_users(tag): | |||||
| return res | return res | ||||
| @fns | @fns | ||||
| def after(x): | |||||
| def after(): | |||||
| res = get_next() | res = get_next() | ||||
| memcpy_out = memcpy_async_attr(res) | memcpy_out = memcpy_async_attr(res) | ||||
| res = cast(memcpy_out) | res = cast(memcpy_out) | ||||
| @@ -101,14 +101,14 @@ def test_getnext_memcpy_elimination_next_multi_inputs(tag): | |||||
| fns = FnDict() | fns = FnDict() | ||||
| @fns | @fns | ||||
| def before(x): | |||||
| def before(): | |||||
| res = get_next() | res = get_next() | ||||
| memcpy_out = memcpy_async_attr(res) | memcpy_out = memcpy_async_attr(res) | ||||
| res = add(memcpy_out, res) | res = add(memcpy_out, res) | ||||
| return res | return res | ||||
| @fns | @fns | ||||
| def after(x): | |||||
| def after(): | |||||
| res = get_next() | res = get_next() | ||||
| memcpy_out = memcpy_async_attr(res) | memcpy_out = memcpy_async_attr(res) | ||||
| res = add(memcpy_out, res) | res = add(memcpy_out, res) | ||||
| @@ -127,14 +127,14 @@ def test_eliminate_depend_input2(tag): | |||||
| def before(x, y, z): | def before(x, y, z): | ||||
| new_z = four2five(z) | new_z = four2five(z) | ||||
| depend_intput = depend(y, new_z) | depend_intput = depend(y, new_z) | ||||
| sum = add(x, depend_intput) | |||||
| return sum | |||||
| sum_add = add(x, depend_intput) | |||||
| return sum_add | |||||
| @fns | @fns | ||||
| def after(x, y, z): | def after(x, y, z): | ||||
| depend_intput = depend(y, z) | depend_intput = depend(y, z) | ||||
| sum = add(x, depend_intput) | |||||
| return sum | |||||
| sum_add = add(x, depend_intput) | |||||
| return sum_add | |||||
| return fns[tag] | return fns[tag] | ||||
| @@ -144,8 +144,8 @@ def test_opt_match(tag): | |||||
| @fns | @fns | ||||
| def graph1(x, y): | def graph1(x, y): | ||||
| sum = add(x, y) | |||||
| output = make_tuple(sum) | |||||
| sum_add = add(x, y) | |||||
| output = make_tuple(sum_add) | |||||
| return output | return output | ||||
| @fns | @fns | ||||
| @@ -178,4 +178,3 @@ def test_func_graph_cse(tag): | |||||
| return d | return d | ||||
| return fns[tag] | return fns[tag] | ||||
| @@ -66,17 +66,17 @@ def test_lamb_next_mv_rule(tag): | |||||
| def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub, | def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub, | ||||
| constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y): | constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y): | ||||
| lamb_next_mv = LambNextMV(input0, input1, input2, input3, input4, input5, input6, | lamb_next_mv = LambNextMV(input0, input1, input2, input3, input4, input5, input6, | ||||
| constant_mul0_x, constant_mul1_sub, | |||||
| constant_mul2_x, constant_mul3_sub1, constant_mul4_x, | |||||
| constant_add2_y) | |||||
| constant_mul0_x, constant_mul1_sub, constant_mul2_x, constant_mul3_sub1, | |||||
| constant_mul4_x, constant_add2_y) | |||||
| outputs = make_tuple(tuple_getitem(lamb_next_mv, 0), tuple_getitem(lamb_next_mv, 1), | outputs = make_tuple(tuple_getitem(lamb_next_mv, 0), tuple_getitem(lamb_next_mv, 1), | ||||
| tuple_getitem(lamb_next_mv, 2), tuple_getitem(lamb_next_mv, 3)) | tuple_getitem(lamb_next_mv, 2), tuple_getitem(lamb_next_mv, 3)) | ||||
| output = tuple_getitem(outputs, 0) | output = tuple_getitem(outputs, 0) | ||||
| return make_tuple(output) | return make_tuple(output) | ||||
| @fns | @fns | ||||
| def before_unmatched_real_div4(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub, | |||||
| constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y): | |||||
| def before_unmatched_real_div4(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, | |||||
| constant_mul1_sub, constant_mul2_x, constant_mul3_sub1, constant_mul4_x, | |||||
| constant_add2_y): | |||||
| mul0 = Mul(constant_mul0_x, input4) | mul0 = Mul(constant_mul0_x, input4) | ||||
| mul1 = Mul(constant_mul1_sub, input3) | mul1 = Mul(constant_mul1_sub, input3) | ||||
| add0 = Add(mul0, mul1) | add0 = Add(mul0, mul1) | ||||
| @@ -98,8 +98,9 @@ def test_lamb_next_mv_rule(tag): | |||||
| return output | return output | ||||
| @fns | @fns | ||||
| def before_unmatched_real_div0(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub, | |||||
| constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y): | |||||
| def before_unmatched_real_div0(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, | |||||
| constant_mul1_sub, constant_mul2_x, constant_mul3_sub1, constant_mul4_x, | |||||
| constant_add2_y): | |||||
| mul0 = Mul(constant_mul0_x, input4) | mul0 = Mul(constant_mul0_x, input4) | ||||
| mul1 = Mul(constant_mul1_sub, input3) | mul1 = Mul(constant_mul1_sub, input3) | ||||
| add0 = Add(mul0, mul1) | add0 = Add(mul0, mul1) | ||||
| @@ -121,8 +122,9 @@ def test_lamb_next_mv_rule(tag): | |||||
| return output | return output | ||||
| @fns | @fns | ||||
| def before_unmatched_real_div1(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub, | |||||
| constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y): | |||||
| def before_unmatched_real_div1(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, | |||||
| constant_mul1_sub, constant_mul2_x, constant_mul3_sub1, constant_mul4_x, | |||||
| constant_add2_y): | |||||
| mul0 = Mul(constant_mul0_x, input4) | mul0 = Mul(constant_mul0_x, input4) | ||||
| mul1 = Mul(constant_mul1_sub, input3) | mul1 = Mul(constant_mul1_sub, input3) | ||||
| add0 = Add(mul0, mul1) | add0 = Add(mul0, mul1) | ||||
| @@ -144,8 +146,9 @@ def test_lamb_next_mv_rule(tag): | |||||
| return output | return output | ||||
| @fns | @fns | ||||
| def before_unmatched_real_div2(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub, | |||||
| constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y): | |||||
| def before_unmatched_real_div2(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, | |||||
| constant_mul1_sub, constant_mul2_x, constant_mul3_sub1, constant_mul4_x, | |||||
| constant_add2_y): | |||||
| mul0 = Mul(constant_mul0_x, input4) | mul0 = Mul(constant_mul0_x, input4) | ||||
| mul1 = Mul(constant_mul1_sub, input3) | mul1 = Mul(constant_mul1_sub, input3) | ||||
| add0 = Add(mul0, mul1) | add0 = Add(mul0, mul1) | ||||
| @@ -13,7 +13,6 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops.operations import _grad_ops as G | from mindspore.ops.operations import _grad_ops as G | ||||
| from mindspore.ops import Primitive | from mindspore.ops import Primitive | ||||
| @@ -82,8 +82,8 @@ def test_eliminate_cast_op(tag): | |||||
| @fns | @fns | ||||
| def before(x, y): | def before(x, y): | ||||
| sum = addn((x, y)) | |||||
| sum_depend = depend(sum, addn((x, y))) | |||||
| sum_add = addn((x, y)) | |||||
| sum_depend = depend(sum_add, addn((x, y))) | |||||
| diff = sub(x, y) | diff = sub(x, y) | ||||
| res = mul(sum_depend, diff) | res = mul(sum_depend, diff) | ||||
| return res | return res | ||||
| @@ -92,8 +92,8 @@ def test_eliminate_cast_op(tag): | |||||
| def after1(x, y): | def after1(x, y): | ||||
| new_x_sum = cast(x) | new_x_sum = cast(x) | ||||
| new_y_sum = cast(y) | new_y_sum = cast(y) | ||||
| sum = addn(new_x_sum, new_y_sum) | |||||
| sum_cast = cast(sum) | |||||
| sum_add = addn(new_x_sum, new_y_sum) | |||||
| sum_cast = cast(sum_add) | |||||
| new_x_depend = cast(x) | new_x_depend = cast(x) | ||||
| new_y_depend = cast(y) | new_y_depend = cast(y) | ||||
| sum_depend = addn(new_x_depend, new_y_depend) | sum_depend = addn(new_x_depend, new_y_depend) | ||||
| @@ -114,12 +114,12 @@ def test_eliminate_cast_op(tag): | |||||
| def after2(x, y): | def after2(x, y): | ||||
| new_x_sum = cast(x) | new_x_sum = cast(x) | ||||
| new_y_sum = cast(y) | new_y_sum = cast(y) | ||||
| sum = addn(new_x_sum, new_y_sum) | |||||
| sum_add = addn(new_x_sum, new_y_sum) | |||||
| new_x_depend = cast(x) | new_x_depend = cast(x) | ||||
| new_y_depend = cast(y) | new_y_depend = cast(y) | ||||
| sum_depend = addn(new_x_depend, new_y_depend) | sum_depend = addn(new_x_depend, new_y_depend) | ||||
| sum_depend_cast = cast(sum_depend) | sum_depend_cast = cast(sum_depend) | ||||
| depend_between_cast = depend(sum, sum_depend_cast) | |||||
| depend_between_cast = depend(sum_add, sum_depend_cast) | |||||
| new_x_diff = cast(x) | new_x_diff = cast(x) | ||||
| new_y_diff = cast(y) | new_y_diff = cast(y) | ||||
| diff = sub(new_x_diff, new_y_diff) | diff = sub(new_x_diff, new_y_diff) | ||||
| @@ -156,8 +156,8 @@ def test_eliminate_cast_new(tag): | |||||
| @fns | @fns | ||||
| def before(x, y): | def before(x, y): | ||||
| sum = add(x, y) | |||||
| res = sub(sum, y) | |||||
| sum_add = add(x, y) | |||||
| res = sub(sum_add, y) | |||||
| output = make_tuple(res) | output = make_tuple(res) | ||||
| return output | return output | ||||
| @@ -166,8 +166,8 @@ def test_eliminate_cast_new(tag): | |||||
| new_x_sum = cast(x) | new_x_sum = cast(x) | ||||
| new_y_sum = cast(y) | new_y_sum = cast(y) | ||||
| new_y_sum2 = cast(y) | new_y_sum2 = cast(y) | ||||
| sum = add(new_x_sum, new_y_sum) | |||||
| sum_5to4 = cast(sum) | |||||
| sum_add = add(new_x_sum, new_y_sum) | |||||
| sum_5to4 = cast(sum_add) | |||||
| sum_4to5 = cast(sum_5to4) | sum_4to5 = cast(sum_5to4) | ||||
| res = sub(sum_4to5, new_y_sum2) | res = sub(sum_4to5, new_y_sum2) | ||||
| output = cast(res) | output = cast(res) | ||||
| @@ -179,11 +179,10 @@ def test_eliminate_cast_new(tag): | |||||
| new_x_sum = cast(x) | new_x_sum = cast(x) | ||||
| new_y_sum = cast(y) | new_y_sum = cast(y) | ||||
| new_y_diff = cast(y) | new_y_diff = cast(y) | ||||
| sum = add(new_x_sum, new_y_sum) | |||||
| res = sub(sum, new_y_diff) | |||||
| sum_add = add(new_x_sum, new_y_sum) | |||||
| res = sub(sum_add, new_y_diff) | |||||
| output = cast(res) | output = cast(res) | ||||
| new_output = make_tuple(output) | new_output = make_tuple(output) | ||||
| return new_output | return new_output | ||||
| return fns[tag] | return fns[tag] | ||||
| @@ -39,14 +39,14 @@ def test_optimize_dependence(tag): | |||||
| def before(x, y, z): | def before(x, y, z): | ||||
| new_z = TransData(z) | new_z = TransData(z) | ||||
| depend_intput = depend(y, new_z) | depend_intput = depend(y, new_z) | ||||
| sum = add(x, depend_intput) | |||||
| return sum | |||||
| sum_add = add(x, depend_intput) | |||||
| return sum_add | |||||
| @fns | @fns | ||||
| def after(x, y, z): | def after(x, y, z): | ||||
| depend_intput = depend(y, z) | depend_intput = depend(y, z) | ||||
| sum = add(x, depend_intput) | |||||
| return sum | |||||
| sum_add = add(x, depend_intput) | |||||
| return sum_add | |||||
| return fns[tag] | return fns[tag] | ||||
| @@ -58,14 +58,14 @@ def test_optimize_dependence_with_make_tuple(tag): | |||||
| def before(x, y, a, b): | def before(x, y, a, b): | ||||
| z = make_tuple(TransData(a), TransData(b)) | z = make_tuple(TransData(a), TransData(b)) | ||||
| depend_intput = depend(y, z) | depend_intput = depend(y, z) | ||||
| sum = add(x, depend_intput) | |||||
| return sum | |||||
| sum_add = add(x, depend_intput) | |||||
| return sum_add | |||||
| @fns | @fns | ||||
| def after(x, y, a, b): | def after(x, y, a, b): | ||||
| z = make_tuple(a, b) | z = make_tuple(a, b) | ||||
| depend_intput = depend(y, z) | depend_intput = depend(y, z) | ||||
| sum = add(x, depend_intput) | |||||
| return sum | |||||
| sum_add = add(x, depend_intput) | |||||
| return sum_add | |||||
| return fns[tag] | return fns[tag] | ||||
| @@ -34,8 +34,8 @@ def test_topk_split(tag): | |||||
| fns = FnDict() | fns = FnDict() | ||||
| @fns | @fns | ||||
| def before(input): | |||||
| topk = TopK(input, 2) | |||||
| def before(x): | |||||
| topk = TopK(x, 2) | |||||
| output = tuple_getitem(topk, 0) | output = tuple_getitem(topk, 0) | ||||
| return output | return output | ||||
| @@ -22,9 +22,10 @@ make_tuple = Primitive('make_tuple') | |||||
| four2five = Primitive('Four2Five') | four2five = Primitive('Four2Five') | ||||
| five2four = Primitive('Five2Four') | five2four = Primitive('Five2Four') | ||||
| transdata = Primitive("TransData") | transdata = Primitive("TransData") | ||||
| transpose = Primitive("Transpose") | |||||
| transpose = Primitive("Transpose") | |||||
| Transpose = P.Transpose() | Transpose = P.Transpose() | ||||
| class FnDict: | class FnDict: | ||||
| def __init__(self): | def __init__(self): | ||||
| self.fnDict = {} | self.fnDict = {} | ||||
| @@ -40,33 +41,35 @@ def test_transdata_split_fraz_nchw(tag): | |||||
| fns = FnDict() | fns = FnDict() | ||||
| @fns | @fns | ||||
| def before(input): | |||||
| res = Transpose(input, (1, 0, 2, 3)) | |||||
| def before(x): | |||||
| res = Transpose(x, (1, 0, 2, 3)) | |||||
| return res | return res | ||||
| @fns | @fns | ||||
| def after(input): | |||||
| res = transpose(input) | |||||
| output = transdata(res) | |||||
| def after(x): | |||||
| res = transpose(x) | |||||
| output = transdata(res) | |||||
| output = transpose(output) | output = transpose(output) | ||||
| res = make_tuple(output) | res = make_tuple(output) | ||||
| return res | return res | ||||
| return fns[tag] | return fns[tag] | ||||
| def test_transdata_split_nchw_fraz(tag): | def test_transdata_split_nchw_fraz(tag): | ||||
| fns = FnDict() | fns = FnDict() | ||||
| @fns | @fns | ||||
| def before(input): | |||||
| res = Transpose(input, (1, 0, 2, 3)) | |||||
| def before(x): | |||||
| res = Transpose(x, (1, 0, 2, 3)) | |||||
| return res | return res | ||||
| @fns | @fns | ||||
| def after(input): | |||||
| res = transpose(input) | |||||
| output = transdata(res) | |||||
| def after(x): | |||||
| res = transpose(x) | |||||
| output = transdata(res) | |||||
| output = transpose(output) | output = transpose(output) | ||||
| res = make_tuple(output) | res = make_tuple(output) | ||||
| return res | return res | ||||
| return fns[tag] | |||||
| return fns[tag] | |||||
| @@ -35,14 +35,14 @@ def test_transpose_reshape_fusion(tag): | |||||
| fns = FnDict() | fns = FnDict() | ||||
| @fns | @fns | ||||
| def before(input): | |||||
| transpose = Transpose(input, (1, 0, 2, 3)) | |||||
| def before(x): | |||||
| transpose = Transpose(x, (1, 0, 2, 3)) | |||||
| reshape = Reshape(transpose, (2, 4, 8, 16)) | reshape = Reshape(transpose, (2, 4, 8, 16)) | ||||
| return reshape | return reshape | ||||
| @fns | @fns | ||||
| def after(input): | |||||
| confusion = ConfusionTransposeD(input) | |||||
| def after(x): | |||||
| confusion = ConfusionTransposeD(x) | |||||
| res = make_tuple(confusion) | res = make_tuple(confusion) | ||||
| return res | return res | ||||
| @@ -37,13 +37,13 @@ def test_transpose_transdata_fusion(tag): | |||||
| fns = FnDict() | fns = FnDict() | ||||
| @fns | @fns | ||||
| def before(input): | |||||
| res = Transpose(input, (1, 0, 2, 3)) | |||||
| def before(x): | |||||
| res = Transpose(x, (1, 0, 2, 3)) | |||||
| return res | return res | ||||
| @fns | @fns | ||||
| def after(input): | |||||
| output = transdata(input) | |||||
| def after(x): | |||||
| output = transdata(x) | |||||
| res = make_tuple(output) | res = make_tuple(output) | ||||
| return res | return res | ||||