| @@ -142,6 +142,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||
| MakeSubstitution(std::make_shared<IncorporateEnvironGetSwitchLayer>(), "incorporate_environ_get_switch_layer", | |||
| prim::kPrimEnvironGet); | |||
| split_environ_get_set_with_tuple_value_ = | |||
| MakeSubstitution(std::make_shared<SplitEnvironGetSetWithTupleValue>(), "split_environ_get_set_with_tuple_value", | |||
| {prim::kPrimEnvironGet, prim::kPrimEnvironSet}); | |||
| // Ref eliminate | |||
| make_ref_eliminate_ = | |||
| MakeSubstitution(std::make_shared<MakeRefEliminater>(), "make_ref_eliminate", prim::kPrimMakeRef); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -74,6 +74,7 @@ class OptimizeIRPassLib { | |||
| SubstitutionPtr incorporate_environ_get_bypass_recursive_; | |||
| SubstitutionPtr incorporate_environ_get_switch_; | |||
| SubstitutionPtr incorporate_environ_get_switch_layer_; | |||
| SubstitutionPtr split_environ_get_set_with_tuple_value_; | |||
| // Ref eliminate | |||
| SubstitutionPtr make_ref_eliminate_; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -161,6 +161,23 @@ class EnvironGetTransformACrossGraph { | |||
| mindspore::HashMap<std::pair<SymbolicKeyInstancePtr, AnfNodePtr>, FuncGraphPtr, PairHasher>> | |||
| cache_; | |||
| }; | |||
| AnfNodePtr GetIndexedEnvironValueNode(const FuncGraphPtr &fg, const AnfNodePtr &origin_value_node, | |||
| const std::size_t index) { | |||
| AnfNodePtr new_value_node; | |||
| if (IsValueNode<ValueTuple>(origin_value_node)) { | |||
| auto origin_value_tuple = GetValueNode<ValueTuplePtr>(origin_value_node); | |||
| if (index >= origin_value_tuple->size()) { | |||
| MS_LOG(EXCEPTION) << "Index: " << index << " is greater than Value size: " << origin_value_tuple->size() | |||
| << ", Default Value: " << origin_value_node->ToString(); | |||
| } | |||
| new_value_node = NewValueNode((*origin_value_tuple)[index]); | |||
| } else { | |||
| new_value_node = fg->NewCNode( | |||
| {NewValueNode(prim::kPrimTupleGetItem), origin_value_node, NewValueNode(MakeValue(static_cast<int64_t>(index)))}); | |||
| } | |||
| return new_value_node; | |||
| } | |||
| } // namespace internal | |||
| // {prim::kPrimEnvironGet, C1, C2, Y} -> Y | |||
| @@ -515,6 +532,71 @@ class IncorporateEnvironGetSwitchLayer : public AnfVisitor { | |||
| bool is_match_{false}; | |||
| internal::EnvironGetTransformACrossGraph environ_get_transform_; | |||
| }; | |||
| // {prim::kPrimEnvironSet, E, K, V} -> | |||
| // E1 = {prim::kPrimEnvironSet, E, K1, V1}, | |||
| // E2 = {prim::kPrimEnvironSet, E1, K2, V2}, | |||
| // ... | |||
| // {prim::kPrimEnvironGet, E, K, V} -> | |||
| // v1 = {prim::kPrimEnvironGet, E, K1, default_v1}, | |||
| // v2 = {prim::kPrimEnvironGet, E, K2, devault_v2}, | |||
| // ... | |||
| // v_tuple = {prim::kPrimMakeTuple, v1, v2, ...} | |||
| class SplitEnvironGetSetWithTupleValue : public AnfVisitor { | |||
| public: | |||
| SplitEnvironGetSetWithTupleValue() = default; | |||
| ~SplitEnvironGetSetWithTupleValue() override = default; | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| if (!(IsPrimitiveCNode(node, prim::kPrimEnvironSet) || IsPrimitiveCNode(node, prim::kPrimEnvironGet))) { | |||
| return nullptr; | |||
| } | |||
| // {prim::kPrimEnvironSet, E, key, node_with_abstract_is_tuple} or | |||
| // {prim::kPrimEnvironGet, E, key, node_with_abstract_is_tuple} | |||
| const auto &cnode = node->cast<CNodePtr>(); | |||
| const auto &inputs = cnode->inputs(); | |||
| auto &environ_node = inputs[internal::kEnvironOffset]; | |||
| const auto &origin_value_node = inputs[internal::kValueOffset]; | |||
| const auto &origin_key_node = GetValueNode<SymbolicKeyInstancePtr>(inputs[internal::kSymbolicKeyOffset]); | |||
| if (origin_key_node == nullptr || origin_value_node->abstract() == nullptr || | |||
| !origin_value_node->abstract()->isa<abstract::AbstractTuple>()) { | |||
| return nullptr; | |||
| } | |||
| const auto &origin_value_abstract = origin_value_node->abstract()->cast<abstract::AbstractTuplePtr>(); | |||
| MS_EXCEPTION_IF_NULL(origin_value_abstract); | |||
| AnfNodePtr prev_environ_node = environ_node; | |||
| auto fg = node->func_graph(); | |||
| if (IsPrimitiveCNode(node, prim::kPrimEnvironSet)) { | |||
| CNodePtr new_cnode = cnode; | |||
| // Cascade the split CNode of EnvironSet. | |||
| for (std::size_t index = 0; index < origin_value_abstract->elements().size(); ++index) { | |||
| auto new_key = std::make_shared<SymbolicKeyInstance>( | |||
| origin_key_node->node(), origin_value_abstract->elements()[index], static_cast<int64_t>(index)); | |||
| AnfNodePtr new_value_node = internal::GetIndexedEnvironValueNode(fg, origin_value_node, index); | |||
| new_cnode = fg->NewCNode({inputs[0], prev_environ_node, NewValueNode(new_key), new_value_node}); | |||
| prev_environ_node = new_cnode; | |||
| } | |||
| return new_cnode; | |||
| } else { | |||
| // MakeTuple the split CNode of EnvironGet. | |||
| AnfNodePtrList tuple_item_list{NewValueNode(prim::kPrimMakeTuple)}; | |||
| for (std::size_t index = 0; index < origin_value_abstract->elements().size(); ++index) { | |||
| auto new_key = std::make_shared<SymbolicKeyInstance>( | |||
| origin_key_node->node(), origin_value_abstract->elements()[index], static_cast<int64_t>(index)); | |||
| AnfNodePtr new_value_node = internal::GetIndexedEnvironValueNode(fg, origin_value_node, index); | |||
| auto new_item_cnode = fg->NewCNode({inputs[0], environ_node, NewValueNode(new_key), new_value_node}); | |||
| tuple_item_list.push_back(new_item_cnode); | |||
| } | |||
| auto new_cnode = fg->NewCNode(tuple_item_list); | |||
| return new_cnode; | |||
| } | |||
| } | |||
| }; | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -332,6 +332,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| irpass.row_tensor_add_zeros_like_, | |||
| irpass.mini_step_allgather_replace_, | |||
| irpass.micro_step_allgather_replace_, | |||
| irpass.split_environ_get_set_with_tuple_value_, | |||
| }, | |||
| false, true); | |||
| opt::OptPassConfig accelerated_algorithm = opt::OptPassConfig({irpass.less_batch_normalization_}); | |||
| @@ -1,7 +1,7 @@ | |||
| /** | |||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||
| * | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -33,17 +33,21 @@ | |||
| namespace mindspore { | |||
| class SymbolicKeyInstance : public Value { | |||
| public: | |||
| SymbolicKeyInstance(const AnfNodePtr &node, const abstract::AbstractBasePtr &abstract) | |||
| : node_(node), abstract_(abstract) {} | |||
| SymbolicKeyInstance(const AnfNodePtr &node, const abstract::AbstractBasePtr &abstract, const int64_t index = -1) | |||
| : node_(node), abstract_(abstract), index_(index) {} | |||
| ~SymbolicKeyInstance() override = default; | |||
| MS_DECLARE_PARENT(SymbolicKeyInstance, Value); | |||
| AnfNodePtr node() const { return node_; } | |||
| abstract::AbstractBasePtr abstract() const { return abstract_; } | |||
| bool operator==(const SymbolicKeyInstance &other) const { | |||
| return (*node_ == *other.node_) && (*abstract_ == *other.abstract_); | |||
| return (*node_ == *other.node_) && (*abstract_ == *other.abstract_) && (index_ == other.index_); | |||
| } | |||
| std::size_t hash() const override { | |||
| auto hash_value = hash_combine(std::hash<AnfNodePtr>{}(node_), std::hash<int64_t>{}(index_)); | |||
| return hash_value; | |||
| } | |||
| std::size_t hash() const override { return std::hash<AnfNodePtr>{}(node_); } | |||
| friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<SymbolicKeyInstance> &inst) { | |||
| if (inst == nullptr) { | |||
| os << "[Key][" | |||
| @@ -55,8 +59,17 @@ class SymbolicKeyInstance : public Value { | |||
| return os; | |||
| } | |||
| std::string ToString() const override { | |||
| return node_ == nullptr ? "Invalid node" : "[Key][" + node_->type_name() + "]" + node_->ToString(); | |||
| std::ostringstream oss; | |||
| if (node_ == nullptr) { | |||
| return "Invalid node"; | |||
| } | |||
| oss << "[Key][" << node_->type_name() + "]" << node_->ToString(); | |||
| if (index_ != -1) { | |||
| oss << "[" << index_ << "]"; | |||
| } | |||
| return oss.str(); | |||
| } | |||
| bool operator==(const Value &other) const override { | |||
| if (other.isa<SymbolicKeyInstance>()) { | |||
| auto other_ = static_cast<const SymbolicKeyInstance &>(other); | |||
| @@ -73,6 +86,10 @@ class SymbolicKeyInstance : public Value { | |||
| private: | |||
| AnfNodePtr node_; | |||
| abstract::AbstractBasePtr abstract_; | |||
| // If the Value in EnvironGet/EnvironSet of one SymbolicKey is Tuple, that SymbolicKey will be split | |||
| // to multiple SymbolicKey, this index is used to discriminate those SymbolicKey derived from the same | |||
| // one. | |||
| int64_t index_{-1}; | |||
| }; | |||
| using SymbolicKeyInstancePtr = std::shared_ptr<SymbolicKeyInstance>; | |||
| @@ -95,7 +112,7 @@ struct SymbolicKeyInstanceEqual { | |||
| MS_EXCEPTION_IF_NULL(rhs->node()); | |||
| MS_EXCEPTION_IF_NULL(lhs->abstract()); | |||
| MS_EXCEPTION_IF_NULL(rhs->abstract()); | |||
| return (*lhs->node() == *rhs->node()) && (*lhs->abstract() == *rhs->abstract()); | |||
| return *lhs == *rhs; | |||
| } | |||
| }; | |||