Browse Source

if the last input of EnvironGet/EnvironSet CNode is a tuple, then split this CNode to multiple CNode with value as non-tuple

feature/build-system-rewrite
zhousiyi 4 years ago
parent
commit
7d0bee5ba1
5 changed files with 115 additions and 10 deletions
  1. +4
    -0
      mindspore/ccsrc/frontend/optimizer/irpass.cc
  2. +2
    -1
      mindspore/ccsrc/frontend/optimizer/irpass.h
  3. +83
    -1
      mindspore/ccsrc/frontend/optimizer/irpass/environ_eliminate.h
  4. +2
    -1
      mindspore/ccsrc/pipeline/jit/pass.cc
  5. +24
    -7
      mindspore/core/utils/symbolic.h

+ 4
- 0
mindspore/ccsrc/frontend/optimizer/irpass.cc View File

@@ -142,6 +142,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
MakeSubstitution(std::make_shared<IncorporateEnvironGetSwitchLayer>(), "incorporate_environ_get_switch_layer",
prim::kPrimEnvironGet);

split_environ_get_set_with_tuple_value_ =
MakeSubstitution(std::make_shared<SplitEnvironGetSetWithTupleValue>(), "split_environ_get_set_with_tuple_value",
{prim::kPrimEnvironGet, prim::kPrimEnvironSet});

// Ref eliminate
make_ref_eliminate_ =
MakeSubstitution(std::make_shared<MakeRefEliminater>(), "make_ref_eliminate", prim::kPrimMakeRef);


+ 2
- 1
mindspore/ccsrc/frontend/optimizer/irpass.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -74,6 +74,7 @@ class OptimizeIRPassLib {
SubstitutionPtr incorporate_environ_get_bypass_recursive_;
SubstitutionPtr incorporate_environ_get_switch_;
SubstitutionPtr incorporate_environ_get_switch_layer_;
SubstitutionPtr split_environ_get_set_with_tuple_value_;

// Ref eliminate
SubstitutionPtr make_ref_eliminate_;


+ 83
- 1
mindspore/ccsrc/frontend/optimizer/irpass/environ_eliminate.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -161,6 +161,23 @@ class EnvironGetTransformACrossGraph {
mindspore::HashMap<std::pair<SymbolicKeyInstancePtr, AnfNodePtr>, FuncGraphPtr, PairHasher>>
cache_;
};

AnfNodePtr GetIndexedEnvironValueNode(const FuncGraphPtr &fg, const AnfNodePtr &origin_value_node,
const std::size_t index) {
AnfNodePtr new_value_node;
if (IsValueNode<ValueTuple>(origin_value_node)) {
auto origin_value_tuple = GetValueNode<ValueTuplePtr>(origin_value_node);
if (index >= origin_value_tuple->size()) {
MS_LOG(EXCEPTION) << "Index: " << index << " is greater than Value size: " << origin_value_tuple->size()
<< ", Default Value: " << origin_value_node->ToString();
}
new_value_node = NewValueNode((*origin_value_tuple)[index]);
} else {
new_value_node = fg->NewCNode(
{NewValueNode(prim::kPrimTupleGetItem), origin_value_node, NewValueNode(MakeValue(static_cast<int64_t>(index)))});
}
return new_value_node;
}
} // namespace internal

// {prim::kPrimEnvironGet, C1, C2, Y} -> Y
@@ -515,6 +532,71 @@ class IncorporateEnvironGetSwitchLayer : public AnfVisitor {
bool is_match_{false};
internal::EnvironGetTransformACrossGraph environ_get_transform_;
};

// {prim::kPrimEnvironSet, E, K, V} ->
// E1 = {prim::kPrimEnvironSet, E, K1, V1},
// E2 = {prim::kPrimEnvironSet, E1, K2, V2},
// ...
// {prim::kPrimEnvironGet, E, K, V} ->
// v1 = {prim::kPrimEnvironGet, E, K1, default_v1},
// v2 = {prim::kPrimEnvironGet, E, K2, devault_v2},
// ...
// v_tuple = {prim::kPrimMakeTuple, v1, v2, ...}
class SplitEnvironGetSetWithTupleValue : public AnfVisitor {
public:
SplitEnvironGetSetWithTupleValue() = default;
~SplitEnvironGetSetWithTupleValue() override = default;

AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!(IsPrimitiveCNode(node, prim::kPrimEnvironSet) || IsPrimitiveCNode(node, prim::kPrimEnvironGet))) {
return nullptr;
}
// {prim::kPrimEnvironSet, E, key, node_with_abstract_is_tuple} or
// {prim::kPrimEnvironGet, E, key, node_with_abstract_is_tuple}
const auto &cnode = node->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
auto &environ_node = inputs[internal::kEnvironOffset];
const auto &origin_value_node = inputs[internal::kValueOffset];
const auto &origin_key_node = GetValueNode<SymbolicKeyInstancePtr>(inputs[internal::kSymbolicKeyOffset]);

if (origin_key_node == nullptr || origin_value_node->abstract() == nullptr ||
!origin_value_node->abstract()->isa<abstract::AbstractTuple>()) {
return nullptr;
}

const auto &origin_value_abstract = origin_value_node->abstract()->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(origin_value_abstract);

AnfNodePtr prev_environ_node = environ_node;
auto fg = node->func_graph();

if (IsPrimitiveCNode(node, prim::kPrimEnvironSet)) {
CNodePtr new_cnode = cnode;
// Cascade the split CNode of EnvironSet.
for (std::size_t index = 0; index < origin_value_abstract->elements().size(); ++index) {
auto new_key = std::make_shared<SymbolicKeyInstance>(
origin_key_node->node(), origin_value_abstract->elements()[index], static_cast<int64_t>(index));
AnfNodePtr new_value_node = internal::GetIndexedEnvironValueNode(fg, origin_value_node, index);

new_cnode = fg->NewCNode({inputs[0], prev_environ_node, NewValueNode(new_key), new_value_node});
prev_environ_node = new_cnode;
}
return new_cnode;
} else {
// MakeTuple the split CNode of EnvironGet.
AnfNodePtrList tuple_item_list{NewValueNode(prim::kPrimMakeTuple)};
for (std::size_t index = 0; index < origin_value_abstract->elements().size(); ++index) {
auto new_key = std::make_shared<SymbolicKeyInstance>(
origin_key_node->node(), origin_value_abstract->elements()[index], static_cast<int64_t>(index));
AnfNodePtr new_value_node = internal::GetIndexedEnvironValueNode(fg, origin_value_node, index);
auto new_item_cnode = fg->NewCNode({inputs[0], environ_node, NewValueNode(new_key), new_value_node});
tuple_item_list.push_back(new_item_cnode);
}
auto new_cnode = fg->NewCNode(tuple_item_list);
return new_cnode;
}
}
};
} // namespace irpass
} // namespace opt
} // namespace mindspore


+ 2
- 1
mindspore/ccsrc/pipeline/jit/pass.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -332,6 +332,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.row_tensor_add_zeros_like_,
irpass.mini_step_allgather_replace_,
irpass.micro_step_allgather_replace_,
irpass.split_environ_get_set_with_tuple_value_,
},
false, true);
opt::OptPassConfig accelerated_algorithm = opt::OptPassConfig({irpass.less_batch_normalization_});


+ 24
- 7
mindspore/core/utils/symbolic.h View File

@@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -33,17 +33,21 @@
namespace mindspore {
class SymbolicKeyInstance : public Value {
public:
SymbolicKeyInstance(const AnfNodePtr &node, const abstract::AbstractBasePtr &abstract)
: node_(node), abstract_(abstract) {}
SymbolicKeyInstance(const AnfNodePtr &node, const abstract::AbstractBasePtr &abstract, const int64_t index = -1)
: node_(node), abstract_(abstract), index_(index) {}
~SymbolicKeyInstance() override = default;
MS_DECLARE_PARENT(SymbolicKeyInstance, Value);
AnfNodePtr node() const { return node_; }
abstract::AbstractBasePtr abstract() const { return abstract_; }
bool operator==(const SymbolicKeyInstance &other) const {
return (*node_ == *other.node_) && (*abstract_ == *other.abstract_);
return (*node_ == *other.node_) && (*abstract_ == *other.abstract_) && (index_ == other.index_);
}

std::size_t hash() const override {
auto hash_value = hash_combine(std::hash<AnfNodePtr>{}(node_), std::hash<int64_t>{}(index_));
return hash_value;
}

std::size_t hash() const override { return std::hash<AnfNodePtr>{}(node_); }
friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<SymbolicKeyInstance> &inst) {
if (inst == nullptr) {
os << "[Key]["
@@ -55,8 +59,17 @@ class SymbolicKeyInstance : public Value {
return os;
}
std::string ToString() const override {
return node_ == nullptr ? "Invalid node" : "[Key][" + node_->type_name() + "]" + node_->ToString();
std::ostringstream oss;
if (node_ == nullptr) {
return "Invalid node";
}
oss << "[Key][" << node_->type_name() + "]" << node_->ToString();
if (index_ != -1) {
oss << "[" << index_ << "]";
}
return oss.str();
}

bool operator==(const Value &other) const override {
if (other.isa<SymbolicKeyInstance>()) {
auto other_ = static_cast<const SymbolicKeyInstance &>(other);
@@ -73,6 +86,10 @@ class SymbolicKeyInstance : public Value {
private:
AnfNodePtr node_;
abstract::AbstractBasePtr abstract_;
// If the Value in EnvironGet/EnvironSet of one SymbolicKey is Tuple, that SymbolicKey will be split
// to multiple SymbolicKey, this index is used to discriminate those SymbolicKey derived from the same
// one.
int64_t index_{-1};
};

using SymbolicKeyInstancePtr = std::shared_ptr<SymbolicKeyInstance>;
@@ -95,7 +112,7 @@ struct SymbolicKeyInstanceEqual {
MS_EXCEPTION_IF_NULL(rhs->node());
MS_EXCEPTION_IF_NULL(lhs->abstract());
MS_EXCEPTION_IF_NULL(rhs->abstract());
return (*lhs->node() == *rhs->node()) && (*lhs->abstract() == *rhs->abstract());
return *lhs == *rhs;
}
};



Loading…
Cancel
Save