From 7d0bee5ba18dc6663e1f285652dbc19e0f2d4ffc Mon Sep 17 00:00:00 2001 From: zhousiyi Date: Thu, 20 Jan 2022 03:15:07 +0000 Subject: [PATCH] if the last input of EnvironGet/EnvironSet CNode is a tuple, then split this CNode to multiple CNode with value as non-tuple --- mindspore/ccsrc/frontend/optimizer/irpass.cc | 4 + mindspore/ccsrc/frontend/optimizer/irpass.h | 3 +- .../optimizer/irpass/environ_eliminate.h | 84 ++++++++++++++++++- mindspore/ccsrc/pipeline/jit/pass.cc | 3 +- mindspore/core/utils/symbolic.h | 31 +++++-- 5 files changed, 115 insertions(+), 10 deletions(-) diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc index 4062873e0b..00607cfb45 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -142,6 +142,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() { MakeSubstitution(std::make_shared(), "incorporate_environ_get_switch_layer", prim::kPrimEnvironGet); + split_environ_get_set_with_tuple_value_ = + MakeSubstitution(std::make_shared(), "split_environ_get_set_with_tuple_value", + {prim::kPrimEnvironGet, prim::kPrimEnvironSet}); + // Ref eliminate make_ref_eliminate_ = MakeSubstitution(std::make_shared(), "make_ref_eliminate", prim::kPrimMakeRef); diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h index 88c19a641a..5fa5655d1c 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass.h @@ -1,5 +1,5 @@ /** - * Copyright 2020-2021 Huawei Technologies Co., Ltd + * Copyright 2020-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -74,6 +74,7 @@ class OptimizeIRPassLib { SubstitutionPtr incorporate_environ_get_bypass_recursive_; SubstitutionPtr incorporate_environ_get_switch_; SubstitutionPtr incorporate_environ_get_switch_layer_; + SubstitutionPtr split_environ_get_set_with_tuple_value_; // Ref eliminate SubstitutionPtr make_ref_eliminate_; diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/environ_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/environ_eliminate.h index 799362a88f..ece91dc6cf 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/environ_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/environ_eliminate.h @@ -1,5 +1,5 @@ /** - * Copyright 2020-2021 Huawei Technologies Co., Ltd + * Copyright 2020-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -161,6 +161,23 @@ class EnvironGetTransformACrossGraph { mindspore::HashMap, FuncGraphPtr, PairHasher>> cache_; }; + +AnfNodePtr GetIndexedEnvironValueNode(const FuncGraphPtr &fg, const AnfNodePtr &origin_value_node, + const std::size_t index) { + AnfNodePtr new_value_node; + if (IsValueNode(origin_value_node)) { + auto origin_value_tuple = GetValueNode(origin_value_node); + if (index >= origin_value_tuple->size()) { + MS_LOG(EXCEPTION) << "Index: " << index << " is greater than Value size: " << origin_value_tuple->size() + << ", Default Value: " << origin_value_node->ToString(); + } + new_value_node = NewValueNode((*origin_value_tuple)[index]); + } else { + new_value_node = fg->NewCNode( + {NewValueNode(prim::kPrimTupleGetItem), origin_value_node, NewValueNode(MakeValue(static_cast(index)))}); + } + return new_value_node; +} } // namespace internal // {prim::kPrimEnvironGet, C1, C2, Y} -> Y @@ -515,6 +532,71 @@ class IncorporateEnvironGetSwitchLayer : public AnfVisitor { bool is_match_{false}; internal::EnvironGetTransformACrossGraph environ_get_transform_; }; + +// {prim::kPrimEnvironSet, E, K, V} -> +// E1 = {prim::kPrimEnvironSet, E, K1, V1}, +// E2 = {prim::kPrimEnvironSet, E1, K2, V2}, +// ... +// {prim::kPrimEnvironGet, E, K, V} -> +// v1 = {prim::kPrimEnvironGet, E, K1, default_v1}, +// v2 = {prim::kPrimEnvironGet, E, K2, devault_v2}, +// ... +// v_tuple = {prim::kPrimMakeTuple, v1, v2, ...} +class SplitEnvironGetSetWithTupleValue : public AnfVisitor { + public: + SplitEnvironGetSetWithTupleValue() = default; + ~SplitEnvironGetSetWithTupleValue() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (!(IsPrimitiveCNode(node, prim::kPrimEnvironSet) || IsPrimitiveCNode(node, prim::kPrimEnvironGet))) { + return nullptr; + } + // {prim::kPrimEnvironSet, E, key, node_with_abstract_is_tuple} or + // {prim::kPrimEnvironGet, E, key, node_with_abstract_is_tuple} + const auto &cnode = node->cast(); + const auto &inputs = cnode->inputs(); + auto &environ_node = inputs[internal::kEnvironOffset]; + const auto &origin_value_node = inputs[internal::kValueOffset]; + const auto &origin_key_node = GetValueNode(inputs[internal::kSymbolicKeyOffset]); + + if (origin_key_node == nullptr || origin_value_node->abstract() == nullptr || + !origin_value_node->abstract()->isa()) { + return nullptr; + } + + const auto &origin_value_abstract = origin_value_node->abstract()->cast(); + MS_EXCEPTION_IF_NULL(origin_value_abstract); + + AnfNodePtr prev_environ_node = environ_node; + auto fg = node->func_graph(); + + if (IsPrimitiveCNode(node, prim::kPrimEnvironSet)) { + CNodePtr new_cnode = cnode; + // Cascade the split CNode of EnvironSet. + for (std::size_t index = 0; index < origin_value_abstract->elements().size(); ++index) { + auto new_key = std::make_shared( + origin_key_node->node(), origin_value_abstract->elements()[index], static_cast(index)); + AnfNodePtr new_value_node = internal::GetIndexedEnvironValueNode(fg, origin_value_node, index); + + new_cnode = fg->NewCNode({inputs[0], prev_environ_node, NewValueNode(new_key), new_value_node}); + prev_environ_node = new_cnode; + } + return new_cnode; + } else { + // MakeTuple the split CNode of EnvironGet. + AnfNodePtrList tuple_item_list{NewValueNode(prim::kPrimMakeTuple)}; + for (std::size_t index = 0; index < origin_value_abstract->elements().size(); ++index) { + auto new_key = std::make_shared( + origin_key_node->node(), origin_value_abstract->elements()[index], static_cast(index)); + AnfNodePtr new_value_node = internal::GetIndexedEnvironValueNode(fg, origin_value_node, index); + auto new_item_cnode = fg->NewCNode({inputs[0], environ_node, NewValueNode(new_key), new_value_node}); + tuple_item_list.push_back(new_item_cnode); + } + auto new_cnode = fg->NewCNode(tuple_item_list); + return new_cnode; + } + } +}; } // namespace irpass } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index 39624b6fea..17e0c8c1fb 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019-2021 Huawei Technologies Co., Ltd + * Copyright 2019-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -332,6 +332,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { irpass.row_tensor_add_zeros_like_, irpass.mini_step_allgather_replace_, irpass.micro_step_allgather_replace_, + irpass.split_environ_get_set_with_tuple_value_, }, false, true); opt::OptPassConfig accelerated_algorithm = opt::OptPassConfig({irpass.less_batch_normalization_}); diff --git a/mindspore/core/utils/symbolic.h b/mindspore/core/utils/symbolic.h index c26ca31772..34ac627736 100644 --- a/mindspore/core/utils/symbolic.h +++ b/mindspore/core/utils/symbolic.h @@ -1,7 +1,7 @@ /** * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * - * Copyright 2019-2021 Huawei Technologies Co., Ltd + * Copyright 2019-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -33,17 +33,21 @@ namespace mindspore { class SymbolicKeyInstance : public Value { public: - SymbolicKeyInstance(const AnfNodePtr &node, const abstract::AbstractBasePtr &abstract) - : node_(node), abstract_(abstract) {} + SymbolicKeyInstance(const AnfNodePtr &node, const abstract::AbstractBasePtr &abstract, const int64_t index = -1) + : node_(node), abstract_(abstract), index_(index) {} ~SymbolicKeyInstance() override = default; MS_DECLARE_PARENT(SymbolicKeyInstance, Value); AnfNodePtr node() const { return node_; } abstract::AbstractBasePtr abstract() const { return abstract_; } bool operator==(const SymbolicKeyInstance &other) const { - return (*node_ == *other.node_) && (*abstract_ == *other.abstract_); + return (*node_ == *other.node_) && (*abstract_ == *other.abstract_) && (index_ == other.index_); + } + + std::size_t hash() const override { + auto hash_value = hash_combine(std::hash{}(node_), std::hash{}(index_)); + return hash_value; } - std::size_t hash() const override { return std::hash{}(node_); } friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr &inst) { if (inst == nullptr) { os << "[Key][" @@ -55,8 +59,17 @@ class SymbolicKeyInstance : public Value { return os; } std::string ToString() const override { - return node_ == nullptr ? "Invalid node" : "[Key][" + node_->type_name() + "]" + node_->ToString(); + std::ostringstream oss; + if (node_ == nullptr) { + return "Invalid node"; + } + oss << "[Key][" << node_->type_name() + "]" << node_->ToString(); + if (index_ != -1) { + oss << "[" << index_ << "]"; + } + return oss.str(); } + bool operator==(const Value &other) const override { if (other.isa()) { auto other_ = static_cast(other); @@ -73,6 +86,10 @@ class SymbolicKeyInstance : public Value { private: AnfNodePtr node_; abstract::AbstractBasePtr abstract_; + // If the Value in EnvironGet/EnvironSet of one SymbolicKey is Tuple, that SymbolicKey will be split + // to multiple SymbolicKey, this index is used to discriminate those SymbolicKey derived from the same + // one. + int64_t index_{-1}; }; using SymbolicKeyInstancePtr = std::shared_ptr; @@ -95,7 +112,7 @@ struct SymbolicKeyInstanceEqual { MS_EXCEPTION_IF_NULL(rhs->node()); MS_EXCEPTION_IF_NULL(lhs->abstract()); MS_EXCEPTION_IF_NULL(rhs->abstract()); - return (*lhs->node() == *rhs->node()) && (*lhs->abstract() == *rhs->abstract()); + return *lhs == *rhs; } };