From d82bbe34c80777fbf220c825fb01f46f485ea623 Mon Sep 17 00:00:00 2001 From: Hoai Linh Tran h00472437 Date: Wed, 6 May 2020 11:52:35 -0400 Subject: [PATCH] Added improvement for ZeroLikeFillZero optimization pass: The old algorithm converts the PrimitivePy op (with 3 nodes) into a new subtree with 9 nodes, after which a Renormalize pass is needed to simplify it back to a tensor. The new algorithm creates the tensor while visiting the node in the pass, so only a single node is created and no Renormalize is needed for this pass (if other passes require it, Renormalize will still be called, but no further inference is needed for the created tensor). Signed-off-by: Hoai Linh Tran h00472437 Code review --- mindspore/ccsrc/optimizer/irpass.cc | 3 +-- .../optimizer/irpass/special_op_eliminate.h | 24 +++++++++++++++---- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/mindspore/ccsrc/optimizer/irpass.cc b/mindspore/ccsrc/optimizer/irpass.cc index 2bd013cb08..8070874066 100644 --- a/mindspore/ccsrc/optimizer/irpass.cc +++ b/mindspore/ccsrc/optimizer/irpass.cc @@ -52,8 +52,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() { special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", {prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType, prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv}); - zero_like_fill_zero_ = - MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor, opt::FORCE_RENORM); + zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor); // ops eliminate item_tuple_eliminate_ = diff --git a/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h b/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h index e06ccd862b..00dcbc67b4 100644 --- a/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h @@ -30,6 +30,7 @@ namespace mindspore { namespace opt { namespace
irpass { + class SpecialOpEliminater { public: SpecialOpEliminater() @@ -156,12 +157,27 @@ class ZeroLikeFillZero : public AnfVisitor { if (y_ == nullptr || node->func_graph() == nullptr) { return nullptr; } + if ((y_->abstract() == nullptr) || !y_->abstract()->isa()) { + auto fg = node->func_graph(); + auto dtype = fg->NewCNode({NewValueNode(PrimDType_), y_}); + auto shape = fg->NewCNode({NewValueNode(PrimShape_), y_}); + return fg->NewCNode({NewValueNode(PrimFill_), dtype, shape, NewValueNode(MakeValue(0))}); + } + + abstract::AbstractTensorPtr tensor_abstract = y_->abstract()->cast(); + + TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); + std::vector tensor_shape = tensor_abstract->shape()->shape(); + + tensor::TensorPtr new_tensor_ptr = std::make_shared(tensor_type_ptr->type_id(), tensor_shape); + size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); + char *data = reinterpret_cast(new_tensor_ptr->data_c(true)); + std::memset(data, 0, mem_size); - auto fg = node->func_graph(); - auto dtype = fg->NewCNode({NewValueNode(PrimDType_), y_}); - auto shape = fg->NewCNode({NewValueNode(PrimShape_), y_}); + auto new_cnode = NewValueNode(new_tensor_ptr); + new_cnode->set_abstract(new_tensor_ptr->ToAbstract()); - return fg->NewCNode({NewValueNode(PrimFill_), dtype, shape, NewValueNode(MakeValue(0))}); + return new_cnode; } void Visit(const AnfNodePtr &node) override { y_ = node; }