Browse Source

remove CastInfo from CNODE

tags/v1.6.0
Xiaoda Zhang 4 years ago
parent
commit
66c7474e5a
6 changed files with 41 additions and 11 deletions
  1. +4
    -1
      mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc
  2. +1
    -4
      mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
  3. +6
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
  4. +27
    -1
      mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
  5. +2
    -0
      mindspore/ccsrc/frontend/parallel/step_auto_parallel.h
  6. +1
    -3
      tests/ut/python/parallel/test_auto_parallel_cast.py

+ 4
- 1
mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc View File

@@ -529,8 +529,11 @@ bool Edge::CheckStrategyConsistency(StrategyPtr prev_stra, StrategyPtr next_stra
"different sharding strategies. These operators are: ";
auto const &succ_edges = prev_op_->succ_edges();
for (auto const &succ_edge : succ_edges) {
if (succ_edge->next_operator()->cnodes().empty()) {
MS_LOG(EXCEPTION) << "No CNODE info has been set in operator: " << succ_edge->next_operator()->name();
}
MS_LOG(ERROR) << succ_edge->next_operator()->name() << ", the corresponding fullname is: "
<< succ_edge->next_operator()->cnode()->fullname_with_scope();
<< succ_edge->next_operator()->cnodes()[0]->fullname_with_scope();
}
MS_LOG(EXCEPTION) << "Configure these operators with consistent sharding strategies.";
}


+ 1
- 4
mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc View File

@@ -1770,10 +1770,7 @@ void OperatorInfo::set_swc_index(int64_t swc, int64_t depth) {
swc_index_ = swc;
}

CNodePtr OperatorInfo::cnode() {
MS_EXCEPTION_IF_NULL(cnode_);
return cnode_;
}
// Returns, by value, the CNodes stored in 'cnodes_'.
std::vector<CNodePtr> OperatorInfo::cnodes() { return cnodes_; }

double OperatorInfo::GetForwardMemoryCostFromCNode() {
return operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0);


+ 6
- 2
mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h View File

@@ -167,8 +167,11 @@ class OperatorInfo {
void set_input_value(const std::vector<ValuePtr> &input_value) { input_value_ = input_value; }
const std::vector<ValuePtr> &input_value() const { return input_value_; }
void set_outputs_dtype(const TypePtr &dtype) { outputs_dtype_ = dtype; }
void set_cnode(const CNodePtr &cnode) { cnode_ = cnode; }
CNodePtr cnode();
// Stores 'cnode' in 'cnode_' and also appends it to 'cnodes_', so every CNode passed
// to this setter stays recorded in 'cnodes_'.
void set_cnode(const CNodePtr &cnode) {
cnode_ = cnode;
cnodes_.push_back(cnode);
}
std::vector<CNodePtr> cnodes();
bool is_alive() const { return is_alive_; }
void SetNotAlive() { is_alive_ = false; }
StrategyPtr strategy() const { return strategy_; }
@@ -299,6 +302,7 @@ class OperatorInfo {
std::vector<bool> split_flag_list_;
std::string refkey_parameter_name_;
CNodePtr cnode_;
std::vector<CNodePtr> cnodes_;
int64_t used_devices_ = -1;
// the repeated_calc_num_ will be inserted to the last dimension of dev matrix in default
bool repeated_num_in_dev_matrix_right_ = true;


+ 27
- 1
mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc View File

@@ -233,6 +233,7 @@ bool IsOperatorsInTwoSeparateLoops(const CNodePtr &a_cnode, const CNodePtr &b_cn

// 'configured_stra_ops_' includes all operators that are configured sharding strategies.
std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> configured_stra_ops_;
std::set<OperatorInfoPtr> ignore_candidate_;
void InitCostGraph() {
if (entire_costgraph == nullptr) {
entire_costgraph = std::make_shared<CostGraph>();
@@ -241,6 +242,7 @@ void InitCostGraph() {
CostModelContext::GetInstance()->PrintCostModel();
entire_costgraph->Init();
configured_stra_ops_.clear();
ignore_candidate_.clear();
}

void SetStrategyToOperator(const OperatorInfoPtr &operator_info, const PrimitivePtr &prim,
@@ -297,6 +299,13 @@ void ApplyApproximationForNode(const OperatorInfoPtr &operator_info) {
}
}

// Inserts 'operator_info' into 'ignore_candidate_' when the name of 'prim' equals CAST.
// Operators created for any other primitive are not recorded.
void AddOperatorToIgnoreCandidates(const PrimitivePtr &prim, const OperatorInfoPtr &operator_info) {
  const bool is_cast = (prim->name() == CAST);
  if (!is_cast) {
    return;
  }
  // CAST is the only primitive collected as an ignore candidate.
  ignore_candidate_.insert(operator_info);
}

OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, bool is_last_nodes,
StrategyMap *stra_map) {
MS_EXCEPTION_IF_NULL(prim);
@@ -344,6 +353,8 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
operator_info->set_input_value(input_value);
operator_info->set_outputs_dtype(cnode->Type());
operator_info->set_cnode(cnode);

AddOperatorToIgnoreCandidates(prim, operator_info);
// key of strategy map
std::string strategy_key_name = "";
auto param_names = NodeParameterName(cnode, -1, 0);
@@ -968,6 +979,17 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
}
}

// For every operator held in 'ignore_candidate_', sets the OperatorInfo user data to
// nullptr on each CNode returned by the operator's cnodes().
// A null CNode triggers MS_EXCEPTION_IF_NULL; otherwise the function returns SUCCESS.
Status IgnoreOperatorsInCostGraph() {
  for (const auto &ignored_op : ignore_candidate_) {
    const auto attached_nodes = ignored_op->cnodes();
    for (const auto &node : attached_nodes) {
      MS_EXCEPTION_IF_NULL(node);
      node->set_user_data<OperatorInfo>(nullptr);
    }
  }
  return SUCCESS;
}

Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
// There are 4 meta-steps to determine the parallelization strategy for the ANF graph.
// Step 1: Traverse the ANF graph, and create NODEs for costgraph:
@@ -1035,7 +1057,6 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
(ParallelContext::GetInstance()->sharding_propagation());
if (use_sp) {
entire_costgraph->StrategyPropagate(configured_stra_ops_);
configured_stra_ops_.clear();
} else if (GetStrategy(entire_costgraph) != SUCCESS) {
MS_LOG(ERROR) << "Strategy search for cost-graph fails";
return FAILED;
@@ -1054,7 +1075,12 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
MS_LOG(INFO) << op->name() << " : The strategy is:";
PrintStrategy(s_strategy);
}
// Remove some operatorInfo from the CNODEs
IgnoreOperatorsInCostGraph();

ops_in_a_loop_.clear();
configured_stra_ops_.clear();
ignore_candidate_.clear();

return SUCCESS;
}


+ 2
- 0
mindspore/ccsrc/frontend/parallel/step_auto_parallel.h View File

@@ -47,6 +47,8 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes);

void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes);

Status IgnoreOperatorsInCostGraph();

Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);

Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);


+ 1
- 3
tests/ut/python/parallel/test_auto_parallel_cast.py View File

@@ -84,9 +84,7 @@ def test_double_star_graph():
net.set_train()
_cell_graph_executor.compile(net, x, y, z, w, phase='train')
strategies = _cell_graph_executor._get_shard_strategy(net)
expected_strategies = {'Default/network-Net/Cast-op1': [[8, 1]],
'Default/network-Net/Cast-op3': [[1, 8]],
'Default/network-Net/MatMul-op2': [[8, 1], [1, 1]],
expected_strategies = {'Default/network-Net/MatMul-op2': [[8, 1], [1, 1]],
'Default/network-Net/MatMul-op4': [[1, 1], [1, 8]],
'Default/network-Net/MatMul-op0': [[1, 8], [8, 1]]}
assert strategies == expected_strategies

Loading…
Cancel
Save