diff --git a/mindspore/ccsrc/optimizer/irpass.cc b/mindspore/ccsrc/optimizer/irpass.cc index 2bd013cb08..8070874066 100644 --- a/mindspore/ccsrc/optimizer/irpass.cc +++ b/mindspore/ccsrc/optimizer/irpass.cc @@ -52,8 +52,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() { special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", {prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType, prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv}); - zero_like_fill_zero_ = - MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor, opt::FORCE_RENORM); + zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor); // ops eliminate item_tuple_eliminate_ = diff --git a/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h b/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h index e06ccd862b..00dcbc67b4 100644 --- a/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h @@ -30,6 +30,7 @@ namespace mindspore { namespace opt { namespace irpass { + class SpecialOpEliminater { public: SpecialOpEliminater() @@ -156,12 +157,27 @@ class ZeroLikeFillZero : public AnfVisitor { if (y_ == nullptr || node->func_graph() == nullptr) { return nullptr; } + if ((y_->abstract() == nullptr) || !y_->abstract()->isa()) { + auto fg = node->func_graph(); + auto dtype = fg->NewCNode({NewValueNode(PrimDType_), y_}); + auto shape = fg->NewCNode({NewValueNode(PrimShape_), y_}); + return fg->NewCNode({NewValueNode(PrimFill_), dtype, shape, NewValueNode(MakeValue(0))}); + } + + abstract::AbstractTensorPtr tensor_abstract = y_->abstract()->cast(); + + TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); + std::vector tensor_shape = tensor_abstract->shape()->shape(); + + tensor::TensorPtr new_tensor_ptr = std::make_shared(tensor_type_ptr->type_id(), tensor_shape); + size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); + char *data = reinterpret_cast(new_tensor_ptr->data_c(true)); + std::memset(data, 0, mem_size); - auto fg = node->func_graph(); - auto dtype = fg->NewCNode({NewValueNode(PrimDType_), y_}); - auto shape = fg->NewCNode({NewValueNode(PrimShape_), y_}); + auto new_cnode = NewValueNode(new_tensor_ptr); + new_cnode->set_abstract(new_tensor_ptr->ToAbstract()); - return fg->NewCNode({NewValueNode(PrimFill_), dtype, shape, NewValueNode(MakeValue(0))}); + return new_cnode; } void Visit(const AnfNodePtr &node) override { y_ = node; }