Browse Source

!952 Simplify the `ZeroLikeFillZero` optimization pass

Merge pull request !952 from thlinh/dev_May6th_improve_zero_fill_like_zero
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
c176bbe4c8
2 changed files with 21 additions and 6 deletions
  1. +1
    -2
      mindspore/ccsrc/optimizer/irpass.cc
  2. +20
    -4
      mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h

+ 1
- 2
mindspore/ccsrc/optimizer/irpass.cc View File

@@ -52,8 +52,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
{prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType,
prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
zero_like_fill_zero_ =
MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor, opt::FORCE_RENORM);
zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor);

// ops eliminate
item_tuple_eliminate_ =


+ 20
- 4
mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h View File

@@ -30,6 +30,7 @@
namespace mindspore {
namespace opt {
namespace irpass {

class SpecialOpEliminater {
public:
SpecialOpEliminater()
@@ -156,12 +157,27 @@ class ZeroLikeFillZero : public AnfVisitor {
if (y_ == nullptr || node->func_graph() == nullptr) {
return nullptr;
}
if ((y_->abstract() == nullptr) || !y_->abstract()->isa<abstract::AbstractTensor>()) {
auto fg = node->func_graph();
auto dtype = fg->NewCNode({NewValueNode(PrimDType_), y_});
auto shape = fg->NewCNode({NewValueNode(PrimShape_), y_});
return fg->NewCNode({NewValueNode(PrimFill_), dtype, shape, NewValueNode(MakeValue(0))});
}

abstract::AbstractTensorPtr tensor_abstract = y_->abstract()->cast<abstract::AbstractTensorPtr>();

TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType();
std::vector<int> tensor_shape = tensor_abstract->shape()->shape();

tensor::TensorPtr new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c(true));
std::memset(data, 0, mem_size);

auto fg = node->func_graph();
auto dtype = fg->NewCNode({NewValueNode(PrimDType_), y_});
auto shape = fg->NewCNode({NewValueNode(PrimShape_), y_});
auto new_cnode = NewValueNode(new_tensor_ptr);
new_cnode->set_abstract(new_tensor_ptr->ToAbstract());

return fg->NewCNode({NewValueNode(PrimFill_), dtype, shape, NewValueNode(MakeValue(0))});
return new_cnode;
}

void Visit(const AnfNodePtr &node) override { y_ = node; }


Loading…
Cancel
Save