Browse Source

!952 Simplify the `ZeroLikeFillZero` optimization pass

Merge pull request !952 from thlinh/dev_May6th_improve_zero_fill_like_zero
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
c176bbe4c8
2 changed files with 21 additions and 6 deletions
  1. +1
    -2
      mindspore/ccsrc/optimizer/irpass.cc
  2. +20
    -4
      mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h

+ 1
- 2
mindspore/ccsrc/optimizer/irpass.cc View File

@@ -52,8 +52,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
{prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType, {prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType,
prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv}); prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
zero_like_fill_zero_ =
MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor, opt::FORCE_RENORM);
zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor);


// ops eliminate // ops eliminate
item_tuple_eliminate_ = item_tuple_eliminate_ =


+ 20
- 4
mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h View File

@@ -30,6 +30,7 @@
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace irpass { namespace irpass {

class SpecialOpEliminater { class SpecialOpEliminater {
public: public:
SpecialOpEliminater() SpecialOpEliminater()
@@ -156,12 +157,27 @@ class ZeroLikeFillZero : public AnfVisitor {
if (y_ == nullptr || node->func_graph() == nullptr) { if (y_ == nullptr || node->func_graph() == nullptr) {
return nullptr; return nullptr;
} }
if ((y_->abstract() == nullptr) || !y_->abstract()->isa<abstract::AbstractTensor>()) {
auto fg = node->func_graph();
auto dtype = fg->NewCNode({NewValueNode(PrimDType_), y_});
auto shape = fg->NewCNode({NewValueNode(PrimShape_), y_});
return fg->NewCNode({NewValueNode(PrimFill_), dtype, shape, NewValueNode(MakeValue(0))});
}

abstract::AbstractTensorPtr tensor_abstract = y_->abstract()->cast<abstract::AbstractTensorPtr>();

TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType();
std::vector<int> tensor_shape = tensor_abstract->shape()->shape();

tensor::TensorPtr new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c(true));
std::memset(data, 0, mem_size);


auto fg = node->func_graph();
auto dtype = fg->NewCNode({NewValueNode(PrimDType_), y_});
auto shape = fg->NewCNode({NewValueNode(PrimShape_), y_});
auto new_cnode = NewValueNode(new_tensor_ptr);
new_cnode->set_abstract(new_tensor_ptr->ToAbstract());


return fg->NewCNode({NewValueNode(PrimFill_), dtype, shape, NewValueNode(MakeValue(0))});
return new_cnode;
} }


void Visit(const AnfNodePtr &node) override { y_ = node; } void Visit(const AnfNodePtr &node) override { y_ = node; }


Loading…
Cancel
Save