From 39790ccf66f578c0ec35a99dfd31590b2461a2b6 Mon Sep 17 00:00:00 2001 From: hongxing Date: Thu, 11 Jun 2020 17:45:37 +0200 Subject: [PATCH] Optimize code --- .../rec_core/rec_generate_strategy.cc | 15 +---- .../rec_core/rec_generate_strategy.h | 3 - .../auto_parallel/rec_core/rec_graph.h | 1 + .../auto_parallel/rec_core/rec_parse_graph.cc | 59 +++++++++++++------ .../auto_parallel/rec_core/rec_parse_graph.h | 1 + 5 files changed, 46 insertions(+), 33 deletions(-) diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc index 4bc183b1a2..6ae5a0d5eb 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc @@ -146,16 +146,6 @@ std::vector> PrepareBatchNorm(const std::shared_ptr return strategies; } -std::vector> PrepareSoftmaxWithLogits(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops) { - std::vector> strategies = MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops); - graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = graph->nodes[iter_graph].tensor_parm.tensor_str.str_h; - graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = graph->nodes[iter_graph].tensor_parm.tensor_str.str_c; - graph->nodes[iter_graph].tensor_parm.tensor_str.str_c = graph->nodes[iter_graph].tensor_parm.tensor_str.str_n; - return strategies; -} - std::vector> PrepareBiasAdd(const std::shared_ptr> &s) { std::vector> strategies; strategies.push_back(*s); @@ -299,9 +289,8 @@ std::vector> PrepareStrategy(const std::shared_ptr & return PreparePReLU(graph, ops, iter_graph, iter_ops); } else if (type == BATCH_NORM) { return PrepareBatchNorm(graph, ops, iter_graph, iter_ops); - } else if (type == SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) { - return PrepareSoftmaxWithLogits(graph, ops, iter_graph, iter_ops); - } else if (type == SOFTMAX || type == LOG_SOFTMAX || type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) { + } else if (type == SOFTMAX || type == LOG_SOFTMAX || type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS || + type == SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) { return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops); } else { return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops); diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h index 2b76c59728..1e5d4d95d0 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h @@ -40,9 +40,6 @@ std::vector> PreparePReLU(const std::shared_ptr &gra std::vector> PrepareBatchNorm(const std::shared_ptr &graph, const std::vector> &ops, const size_t iter_graph, const size_t iter_ops); -std::vector> PrepareSoftmaxWithLogits(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops); std::vector> PrepareBiasAdd(const std::shared_ptr> &s); std::vector> PrepareOneHot(const std::shared_ptr> &s); std::vector> PrepareGatherV2(const std::shared_ptr> &s); diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h index 879e22cb1f..d578bd82ef 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h @@ -69,6 +69,7 @@ class Graph { std::vector node_in; // Nodes that point from this node std::vector node_out; + std::vector node_in_aux; // Node Type Info: Application or Constant. Defined in enum . InfoType info; // Operator info. Defined in struct . diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc index 979f987225..2aa9bddcc1 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc @@ -171,21 +171,41 @@ void Eliminate_Aux(const size_t node_index, const std::shared_ptr graph, eli.push_back(graph->nodes[node_index].node_out[i]); } eli_list->push_back(eli); - for (auto input_index : graph->nodes[node_index].node_in) { - auto it = find(graph->nodes[input_index].node_out.begin(), graph->nodes[input_index].node_out.end(), node_index); - if (it != graph->nodes[input_index].node_out.end()) { - graph->nodes[input_index].node_out.erase(it); - for (auto output_index : graph->nodes[node_index].node_out) { - graph->nodes[input_index].node_out.push_back(output_index); - } + + for (size_t i = 0; i < graph->nodes[node_index].node_in.size(); i++) { + auto *incoming_outputs = &graph->nodes[graph->nodes[node_index].node_in[i]].node_out; + auto it = find(incoming_outputs->begin(), incoming_outputs->end(), node_index); + if (it != incoming_outputs->end()) { + it = incoming_outputs->erase(it); + incoming_outputs->insert(it, graph->nodes[node_index].node_out.begin(), graph->nodes[node_index].node_out.end()); + } + } + + for (size_t i = 0; i < graph->nodes[node_index].node_in_aux.size(); i++) { + auto *aux_incoming_outputs = &graph->nodes[graph->nodes[node_index].node_in_aux[i]].node_out; + auto it = find(aux_incoming_outputs->begin(), aux_incoming_outputs->end(), node_index); + if (it != aux_incoming_outputs->end()) { + it = aux_incoming_outputs->erase(it); + aux_incoming_outputs->insert(it, graph->nodes[node_index].node_out.begin(), + graph->nodes[node_index].node_out.end()); } } - for (auto output_index : graph->nodes[node_index].node_out) { - auto it = find(graph->nodes[output_index].node_in.begin(), graph->nodes[output_index].node_in.end(), node_index); - if (it != graph->nodes[output_index].node_in.end()) { - graph->nodes[output_index].node_in.erase(it); - for (auto input_index : graph->nodes[node_index].node_in) { - graph->nodes[output_index].node_in.push_back(input_index); + + for (size_t i = 0; i < graph->nodes[node_index].node_out.size(); i++) { + auto *outgoing_inputs = &graph->nodes[graph->nodes[node_index].node_out[i]].node_in; + auto it = find(outgoing_inputs->begin(), outgoing_inputs->end(), node_index); + if (it != outgoing_inputs->end()) { + if (graph->nodes[node_index].node_in.size() > 0) { + outgoing_inputs->at(std::distance(outgoing_inputs->begin(), it)) = graph->nodes[node_index].node_in[0]; + for (size_t j = 1; j < graph->nodes[node_index].node_in.size(); j++) { + graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.push_back(graph->nodes[node_index].node_in[j]); + } + for (size_t j = 1; j < graph->nodes[node_index].node_in_aux.size(); j++) { + graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.push_back( + graph->nodes[node_index].node_in_aux[j]); + } + } else { + outgoing_inputs->erase(it); } } } @@ -206,10 +226,12 @@ std::shared_ptr EliminateGraph(const std::shared_ptr graph, Eliminate_Aux(node_index, graph, eli_list); } } + index_list->reserve(graph->nodes.size()); for (size_t i = 0; i < (size_t)graph->nodes.size(); i++) { index_list->push_back(i); } + for (size_t i = 0; i < (size_t)eli_list->size(); i++) { if (eli_list->at(i)[0] >= index_list->size()) { MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range."; @@ -219,6 +241,7 @@ std::shared_ptr EliminateGraph(const std::shared_ptr graph, index_list->at(j)--; } } + std::shared_ptr new_graph(new Graph); for (size_t i = 0; i < graph->nodes.size(); i++) { if (index_list->at(i) > SIZE_MAX / 2) { @@ -226,11 +249,13 @@ std::shared_ptr EliminateGraph(const std::shared_ptr graph, } new_graph->nodes.push_back(graph->nodes[i]); - for (size_t j = 0; j < new_graph->nodes[index_list->at(i)].node_in.size(); j++) { - new_graph->nodes[index_list->at(i)].node_in[j] = index_list->at(new_graph->nodes[index_list->at(i)].node_in[j]); + auto *node_in = &new_graph->nodes[index_list->at(i)].node_in; + for (size_t j = 0; j < node_in->size(); j++) { + node_in->at(j) = index_list->at(node_in->at(j)); } - for (size_t j = 0; j < new_graph->nodes[index_list->at(i)].node_out.size(); j++) { - new_graph->nodes[index_list->at(i)].node_out[j] = index_list->at(new_graph->nodes[index_list->at(i)].node_out[j]); + auto *node_out = &new_graph->nodes[index_list->at(i)].node_out; + for (size_t j = 0; j < node_out->size(); j++) { + node_out->at(j) = index_list->at(node_out->at(j)); } } return new_graph; diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h index 1b51e4d9b0..f39546dffc 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h @@ -59,6 +59,7 @@ const std::map DictOpType{ {PRELU, OperatorType::kRecPReLU}, + {L2_NORMALIZE, OperatorType::kRecElmWiseOp}, {TENSOR_ADD, OperatorType::kRecElmWiseOp}, {SUB, OperatorType::kRecElmWiseOp}, {MUL, OperatorType::kRecElmWiseOp},