|
|
|
@@ -207,8 +207,10 @@ Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) { |
|
|
|
auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size - 1)}); |
|
|
|
auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), sub, minimum}); |
|
|
|
auto unique = gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node()}); |
|
|
|
auto tuple_getitem_0 = gen_g.PushBack({gen_g.NewOpInst(prim::kTupleGetItem), unique, CreatInt64Imm(0)}); |
|
|
|
auto tuple_getitem_1 = gen_g.PushBack({gen_g.NewOpInst(prim::kTupleGetItem), unique, CreatInt64Imm(1)}); |
|
|
|
// Use name of tuple_getitem instance in mindspore.ops.functional, not the Primitive name |
|
|
|
const std::string &tuple_getitem_op = "tuple_getitem"; |
|
|
|
auto tuple_getitem_0 = gen_g.PushBack({gen_g.NewOpInst(tuple_getitem_op), unique, CreatInt64Imm(0)}); |
|
|
|
auto tuple_getitem_1 = gen_g.PushBack({gen_g.NewOpInst(tuple_getitem_op), unique, CreatInt64Imm(1)}); |
|
|
|
auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), tuple_getitem_1}); |
|
|
|
auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype}); |
|
|
|
auto mul = gen_g.PushBack({gen_g.NewOpInst(MUL), tuple_getitem_1, cast}); |
|
|
|
|