Browse Source

!2030 [AutoParallel] use replacement instead of recreation for edges in rec prog parse

Merge pull request !2030 from Chong/ReID
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
ff0590315c
5 changed files with 46 additions and 33 deletions
  1. +2
    -13
      mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
  2. +0
    -3
      mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h
  3. +1
    -0
      mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h
  4. +42
    -17
      mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
  5. +1
    -0
      mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h

+ 2
- 13
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc View File

@@ -146,16 +146,6 @@ std::vector<std::vector<int32_t>> PrepareBatchNorm(const std::shared_ptr<Graph>
   return strategies;
 }
 
-std::vector<std::vector<int32_t>> PrepareSoftmaxWithLogits(const std::shared_ptr<Graph> &graph,
-                                                           const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                           const size_t iter_graph, const size_t iter_ops) {
-  std::vector<std::vector<int32_t>> strategies = MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
-  graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = graph->nodes[iter_graph].tensor_parm.tensor_str.str_h;
-  graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = graph->nodes[iter_graph].tensor_parm.tensor_str.str_c;
-  graph->nodes[iter_graph].tensor_parm.tensor_str.str_c = graph->nodes[iter_graph].tensor_parm.tensor_str.str_n;
-  return strategies;
-}
-
 std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::shared_ptr<std::vector<int32_t>> &s) {
   std::vector<std::vector<int32_t>> strategies;
   strategies.push_back(*s);
@@ -299,9 +289,8 @@ std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> &
     return PreparePReLU(graph, ops, iter_graph, iter_ops);
   } else if (type == BATCH_NORM) {
     return PrepareBatchNorm(graph, ops, iter_graph, iter_ops);
-  } else if (type == SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
-    return PrepareSoftmaxWithLogits(graph, ops, iter_graph, iter_ops);
-  } else if (type == SOFTMAX || type == LOG_SOFTMAX || type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
+  } else if (type == SOFTMAX || type == LOG_SOFTMAX || type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS ||
+             type == SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
     return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
   } else {
     return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);


+ 0
- 3
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h View File

@@ -40,9 +40,6 @@ std::vector<std::vector<int32_t>> PreparePReLU(const std::shared_ptr<Graph> &gra
 std::vector<std::vector<int32_t>> PrepareBatchNorm(const std::shared_ptr<Graph> &graph,
                                                    const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                    const size_t iter_graph, const size_t iter_ops);
-std::vector<std::vector<int32_t>> PrepareSoftmaxWithLogits(const std::shared_ptr<Graph> &graph,
-                                                           const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                           const size_t iter_graph, const size_t iter_ops);
 std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::shared_ptr<std::vector<int32_t>> &s);
 std::vector<std::vector<int32_t>> PrepareOneHot(const std::shared_ptr<std::vector<int32_t>> &s);
 std::vector<std::vector<int32_t>> PrepareGatherV2(const std::shared_ptr<std::vector<int32_t>> &s);


+ 1
- 0
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h View File

@@ -69,6 +69,7 @@ class Graph {
   std::vector<size_t> node_in;
   // Nodes that point from this node
   std::vector<size_t> node_out;
+  std::vector<size_t> node_in_aux;
   // Node Type Info: Application or Constant. Defined in enum <InfoType> .
   InfoType info;
   // Operator info. Defined in struct <OperatorRec> .


+ 42
- 17
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc View File

@@ -171,21 +171,41 @@ void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> graph,
     eli.push_back(graph->nodes[node_index].node_out[i]);
   }
   eli_list->push_back(eli);
-  for (auto input_index : graph->nodes[node_index].node_in) {
-    auto it = find(graph->nodes[input_index].node_out.begin(), graph->nodes[input_index].node_out.end(), node_index);
-    if (it != graph->nodes[input_index].node_out.end()) {
-      graph->nodes[input_index].node_out.erase(it);
-      for (auto output_index : graph->nodes[node_index].node_out) {
-        graph->nodes[input_index].node_out.push_back(output_index);
-      }
+  for (size_t i = 0; i < graph->nodes[node_index].node_in.size(); i++) {
+    auto *incoming_outputs = &graph->nodes[graph->nodes[node_index].node_in[i]].node_out;
+    auto it = find(incoming_outputs->begin(), incoming_outputs->end(), node_index);
+    if (it != incoming_outputs->end()) {
+      it = incoming_outputs->erase(it);
+      incoming_outputs->insert(it, graph->nodes[node_index].node_out.begin(), graph->nodes[node_index].node_out.end());
+    }
+  }
+  for (size_t i = 0; i < graph->nodes[node_index].node_in_aux.size(); i++) {
+    auto *aux_incoming_outputs = &graph->nodes[graph->nodes[node_index].node_in_aux[i]].node_out;
+    auto it = find(aux_incoming_outputs->begin(), aux_incoming_outputs->end(), node_index);
+    if (it != aux_incoming_outputs->end()) {
+      it = aux_incoming_outputs->erase(it);
+      aux_incoming_outputs->insert(it, graph->nodes[node_index].node_out.begin(),
+                                   graph->nodes[node_index].node_out.end());
     }
   }
-  for (auto output_index : graph->nodes[node_index].node_out) {
-    auto it = find(graph->nodes[output_index].node_in.begin(), graph->nodes[output_index].node_in.end(), node_index);
-    if (it != graph->nodes[output_index].node_in.end()) {
-      graph->nodes[output_index].node_in.erase(it);
-      for (auto input_index : graph->nodes[node_index].node_in) {
-        graph->nodes[output_index].node_in.push_back(input_index);
+  for (size_t i = 0; i < graph->nodes[node_index].node_out.size(); i++) {
+    auto *outgoing_inputs = &graph->nodes[graph->nodes[node_index].node_out[i]].node_in;
+    auto it = find(outgoing_inputs->begin(), outgoing_inputs->end(), node_index);
+    if (it != outgoing_inputs->end()) {
+      if (graph->nodes[node_index].node_in.size() > 0) {
+        outgoing_inputs->at(std::distance(outgoing_inputs->begin(), it)) = graph->nodes[node_index].node_in[0];
+        for (size_t j = 1; j < graph->nodes[node_index].node_in.size(); j++) {
+          graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.push_back(graph->nodes[node_index].node_in[j]);
+        }
+        for (size_t j = 1; j < graph->nodes[node_index].node_in_aux.size(); j++) {
+          graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.push_back(
+            graph->nodes[node_index].node_in_aux[j]);
+        }
+      } else {
+        outgoing_inputs->erase(it);
+      }
     }
   }
@@ -206,10 +226,12 @@ std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph,
      Eliminate_Aux(node_index, graph, eli_list);
    }
  }
  index_list->reserve(graph->nodes.size());
  for (size_t i = 0; i < (size_t)graph->nodes.size(); i++) {
    index_list->push_back(i);
  }
  for (size_t i = 0; i < (size_t)eli_list->size(); i++) {
    if (eli_list->at(i)[0] >= index_list->size()) {
      MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
@@ -219,6 +241,7 @@ std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph,
      index_list->at(j)--;
    }
  }
  std::shared_ptr<Graph> new_graph(new Graph);
  for (size_t i = 0; i < graph->nodes.size(); i++) {
    if (index_list->at(i) > SIZE_MAX / 2) {
@@ -226,11 +249,13 @@ std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph,
     }
     new_graph->nodes.push_back(graph->nodes[i]);
-    for (size_t j = 0; j < new_graph->nodes[index_list->at(i)].node_in.size(); j++) {
-      new_graph->nodes[index_list->at(i)].node_in[j] = index_list->at(new_graph->nodes[index_list->at(i)].node_in[j]);
+    auto *node_in = &new_graph->nodes[index_list->at(i)].node_in;
+    for (size_t j = 0; j < node_in->size(); j++) {
+      node_in->at(j) = index_list->at(node_in->at(j));
     }
-    for (size_t j = 0; j < new_graph->nodes[index_list->at(i)].node_out.size(); j++) {
-      new_graph->nodes[index_list->at(i)].node_out[j] = index_list->at(new_graph->nodes[index_list->at(i)].node_out[j]);
+    auto *node_out = &new_graph->nodes[index_list->at(i)].node_out;
+    for (size_t j = 0; j < node_out->size(); j++) {
+      node_out->at(j) = index_list->at(node_out->at(j));
     }
   }
   return new_graph;


+ 1
- 0
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h View File

@@ -59,6 +59,7 @@ const std::map<std::string, OperatorType> DictOpType{
 
   {PRELU, OperatorType::kRecPReLU},
 
+  {L2_NORMALIZE, OperatorType::kRecElmWiseOp},
   {TENSOR_ADD, OperatorType::kRecElmWiseOp},
   {SUB, OperatorType::kRecElmWiseOp},
   {MUL, OperatorType::kRecElmWiseOp},


Loading…
Cancel
Save