Browse Source

Added improvement for ZeroLikeFillZero optimization pass: The old algorithm convert the PrimitivePy op (with 3 nodes) into a

new subtree with 9 nodes and after that a Renormalize pass is needed to simplified it back to a tensor. The new algorithm will
create the tensor while visiting in the pass, therefore only a single node is created and no Renormalize will be needed for this pass
(if other passes requires then Renormalize will still be called but no further infer is needed for the created tensor)

Signed-off-by: Hoai Linh Tran h00472437 <hoai.linh.tran@huawei.com>

Code review
tags/v0.3.0-alpha
Hoai Linh Tran h00472437 5 years ago
parent
commit
d82bbe34c8
2 changed files with 21 additions and 6 deletions
  1. +1
    -2
      mindspore/ccsrc/optimizer/irpass.cc
  2. +20
    -4
      mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h

+ 1
- 2
mindspore/ccsrc/optimizer/irpass.cc View File

@@ -52,8 +52,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
{prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType,
prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
zero_like_fill_zero_ =
MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor, opt::FORCE_RENORM);
zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor);

// ops eliminate
item_tuple_eliminate_ =


+ 20
- 4
mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h View File

@@ -30,6 +30,7 @@
namespace mindspore {
namespace opt {
namespace irpass {

class SpecialOpEliminater {
public:
SpecialOpEliminater()
@@ -156,12 +157,27 @@ class ZeroLikeFillZero : public AnfVisitor {
if (y_ == nullptr || node->func_graph() == nullptr) {
return nullptr;
}
if ((y_->abstract() == nullptr) || !y_->abstract()->isa<abstract::AbstractTensor>()) {
auto fg = node->func_graph();
auto dtype = fg->NewCNode({NewValueNode(PrimDType_), y_});
auto shape = fg->NewCNode({NewValueNode(PrimShape_), y_});
return fg->NewCNode({NewValueNode(PrimFill_), dtype, shape, NewValueNode(MakeValue(0))});
}

abstract::AbstractTensorPtr tensor_abstract = y_->abstract()->cast<abstract::AbstractTensorPtr>();

TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType();
std::vector<int> tensor_shape = tensor_abstract->shape()->shape();

tensor::TensorPtr new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c(true));
std::memset(data, 0, mem_size);

auto fg = node->func_graph();
auto dtype = fg->NewCNode({NewValueNode(PrimDType_), y_});
auto shape = fg->NewCNode({NewValueNode(PrimShape_), y_});
auto new_cnode = NewValueNode(new_tensor_ptr);
new_cnode->set_abstract(new_tensor_ptr->ToAbstract());

return fg->NewCNode({NewValueNode(PrimFill_), dtype, shape, NewValueNode(MakeValue(0))});
return new_cnode;
}

void Visit(const AnfNodePtr &node) override { y_ = node; }


Loading…
Cancel
Save