From c68dc39d6c5f4b01e4bc09dc6f41a2245757b1a4 Mon Sep 17 00:00:00 2001 From: hongxing Date: Tue, 4 Aug 2020 21:39:56 +0200 Subject: [PATCH] support GatherV2P and fix Reduce bug --- .../rec_core/rec_generate_strategy.cc | 132 ++++++++++++++++-- .../rec_core/rec_generate_strategy.h | 3 + 2 files changed, 120 insertions(+), 15 deletions(-) diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc index a2b1e0b397..e51bd579f3 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc @@ -176,21 +176,102 @@ Strategys PrepareGatherV2(const std::vector> &ops, s[axis] = 1; strategies.push_back(s); - auto pos = ops[iter_ops]->name().find("Info"); - auto name = ops[iter_ops]->name().substr(0, pos); - if (name == "GatherV2") { - return strategies; + return strategies; +} + +/* Builds the two input strategies for the GatherV2P operator at ops[iter_ops]. Output dimensions are visited with dimension 0 first and the remaining ones by descending extent; even extents are halved until the number of cuts reaches g_device_manager->DeviceNum(). The resulting per-output-dimension cut counts are then mapped onto input 0 and input 1 according to the axis read from input_value().at(2). Only axis 0 and axis 1 are handled; any other value raises through MS_LOG(EXCEPTION). The incoming s is used only for the axis range check and is then overwritten. */ Strategys PrepareGatherV2P(const std::vector> &ops, const size_t iter_ops, Dimensions s) { + Strategys strategies; + + auto output_shape = ops[iter_ops]->outputs_tensor_info()[0].shape(); + /* NOTE(review): if output_shape is empty, size() - 1 presumably wraps around (unsigned) - confirm that rank >= 1 is guaranteed here. */ Dimensions index(output_shape.size() - 1, 0); + for (size_t i = 0; i < index.size(); i++) { + index[i] = i; } + /* Rank dimensions 1..N-1 by descending extent (index holds positions shifted by one, hence the a + 1 / b + 1), shift back, then prepend dimension 0. */ std::sort(index.begin(), index.end(), + [&output_shape](const int &a, const int &b) { return (output_shape[a + 1] > output_shape[b + 1]); }); + std::transform(std::begin(index), std::end(index), std::begin(index), [](int x) { return x + 1; }); + index.insert(index.begin(), 0); - Dimensions s_indices; - for (size_t i = 0; i < ops[iter_ops]->inputs_tensor_info()[1].shape().size(); i++) { - s_indices.push_back(1); + Dimensions strategie(output_shape.size(), 1); + size_t num_device = g_device_manager->DeviceNum(); + size_t cut = 1; + /* Halve even extents in the order above, doubling the cut count each time, until it reaches num_device. */ for (size_t i = 0; i < index.size(); i++) { + while (output_shape[index[i]] % 2 == 0 && output_shape[index[i]] > 0 && cut < num_device) { 
output_shape[index[i]] /= 2; + cut *= 2; + strategie[index[i]] *= 2; + } + if (cut == num_device) { + break; + } + } + + /* A negative axis is offset by the rank of input 0. */ auto axis_input = GetValue(ops[iter_ops]->input_value().at(2)); + if (axis_input < 0) { + axis_input += SizeToInt(ops[iter_ops]->inputs_tensor_info()[0].shape().size()); + } + int32_t axis = axis_input; + if (axis >= SizeToInt(s.size())) { + MS_LOG(EXCEPTION) << "Failure: GatherV2' axis out of range."; + } + /* axis 0: input 0 keeps dimension 0 uncut and its remaining dimensions take the cuts of the output dimensions that follow the first rank(input 1) ones; input 1 takes the cuts of those first rank(input 1) output dimensions. */ if (axis == 0) { + s.clear(); + s.push_back(1); + for (size_t i = 1; i < ops[iter_ops]->inputs_tensor_info()[0].shape().size(); i++) { + s.push_back(strategie[ops[iter_ops]->inputs_tensor_info()[1].shape().size() - 1 + i]); + } + strategies.push_back(s); + s.clear(); + for (size_t i = 0; i < ops[iter_ops]->inputs_tensor_info()[1].shape().size(); i++) { + s.push_back(strategie[i]); + } + strategies.push_back(s); + } else if (axis == 1) { /* axis 1: input 0 gets exactly two entries (the cut of output dimension 0, then 1); input 1 takes the cuts of the output dimensions starting at index rank(input 0) - 1. NOTE(review): this presumably assumes input 0 is 2-D - confirm. */ + s.clear(); + s.push_back(strategie[0]); + s.push_back(1); + strategies.push_back(s); + s.clear(); + for (size_t i = 0; i < ops[iter_ops]->inputs_tensor_info()[1].shape().size(); i++) { + s.push_back(strategie[ops[iter_ops]->inputs_tensor_info()[0].shape().size() - 1 + i]); + } + strategies.push_back(s); + } else { + MS_LOG(EXCEPTION) << "Failure: GatherV2's axis is neither 0 nor 1."; } - strategies.push_back(s_indices); return strategies; } +/* Returns the per-output-dimension cut counts for the operator at ops[incoming_op_index], using the same dimension ordering and halving loop as PrepareGatherV2P. NOTE(review): this body repeats the first part of PrepareGatherV2P statement for statement; consider having that function call this one. */ Dimensions PrepareGatherV2POutputStrategy(const std::vector> &ops, + const size_t incoming_op_index) { + auto output_shape = ops[incoming_op_index]->outputs_tensor_info()[0].shape(); + Dimensions index(output_shape.size() - 1, 0); + for (size_t i = 0; i < index.size(); i++) { + index[i] = i; + } + std::sort(index.begin(), index.end(), + [&output_shape](const int &a, const int &b) { return (output_shape[a + 1] > output_shape[b + 1]); }); + std::transform(std::begin(index), std::end(index), std::begin(index), [](int x) { return x + 1; }); + index.insert(index.begin(), 0); + + Dimensions strategie(output_shape.size(), 1); + size_t num_device = g_device_manager->DeviceNum(); 
+ size_t cut = 1; + for (size_t i = 0; i < index.size(); i++) { + while (output_shape[index[i]] % 2 == 0 && output_shape[index[i]] > 0 && cut < num_device) { + output_shape[index[i]] /= 2; + cut *= 2; + strategie[index[i]] *= 2; + } + if (cut == num_device) { + break; + } + } + + return strategie; +} + Strategys PrepareL2Normalize(const std::vector> &ops, const size_t iter_ops, Dimensions s) { int32_t axis = 0; @@ -401,10 +482,20 @@ Dimensions CopyIncomingOperatorOutputStrategy(const std::shared_ptr &grap Dimensions PrepareIncomingOperatorInputStrategy(const std::vector> &ops, const size_t incoming_op_index) { Dimensions s; - if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == GATHERV2 || - ops[incoming_op_index]->type() == TRANSPOSE) { + if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == TRANSPOSE) { return s; } + /* The GATHERV2 type covers both GatherV2 and GatherV2P; they are told apart by the operator name truncated at the first occurrence of Info. GatherV2 still yields an empty strategy, GatherV2P yields the computed output cuts. */ if (ops[incoming_op_index]->type() == GATHERV2) { + auto pos = ops[incoming_op_index]->name().find("Info"); + auto name = ops[incoming_op_index]->name().substr(0, pos); + if (name == "GatherV2") { + return s; + } else if (name == "GatherV2P") { + return PrepareGatherV2POutputStrategy(ops, incoming_op_index); + } else { + MS_LOG(EXCEPTION) << "Failure: Unknown type of GatherV2." << std::endl; + } + } auto strategy = ops[incoming_op_index]->selected_strategy(); if (strategy->GetInputNumber() == 0) { return s; @@ -495,10 +586,13 @@ Dimensions GetDimList(const std::vector> &ops, con if (input_value.back()->isa()) { auto attr_axis = GetValue>(input_value.back()); if (attr_axis.empty()) { - MS_LOG(EXCEPTION) << "Failure: This output is a 0-D tensor." << std::endl; - } - for (auto &axis : attr_axis) { - axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); + /* An empty axis tuple now selects every dimension 0..input_dim-1 instead of raising. */ for (size_t i = 0; i < input_dim; i++) { + dim_list.push_back(SizeToInt(i)); + } + } else { + for (auto &axis : attr_axis) { + axis < 0 ? 
dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); + } } } else if (input_value.back()->isa()) { int axis = GetValue(input_value.back()); @@ -625,7 +719,15 @@ Strategys GenerateStrategiesFromStrategy(const std::vectortype() == GATHERV2) { - return PrepareGatherV2(ops, iter_ops, basic_stra); + /* Dispatch on the operator name truncated at the first occurrence of Info: GatherV2 and GatherV2P use different strategy builders, and any other name raises. */ auto pos = ops[iter_ops]->name().find("Info"); + auto name = ops[iter_ops]->name().substr(0, pos); + if (name == "GatherV2") { + return PrepareGatherV2(ops, iter_ops, basic_stra); + } else if (name == "GatherV2P") { + return PrepareGatherV2P(ops, iter_ops, basic_stra); + } else { + MS_LOG(EXCEPTION) << "Failure: Unknown type of GatherV2." << std::endl; + } } if (ops[iter_ops]->type() == L2_NORMALIZE) { return PrepareL2Normalize(ops, iter_ops, basic_stra); diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h index ab8aa01e99..2263deb588 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h @@ -37,6 +37,9 @@ Strategys PrepareBiasAdd(const std::shared_ptr &s); Strategys PrepareOneHot(const std::shared_ptr &graph, const std::vector> &ops, const size_t iter_graph, const size_t iter_ops); Strategys PrepareGatherV2(const std::vector> &ops, const size_t iter_ops, Dimensions s); +Strategys PrepareGatherV2P(const std::vector> &ops, const size_t iter_ops, Dimensions s); +Dimensions PrepareGatherV2POutputStrategy(const std::vector> &ops, + const size_t incoming_op_index); Strategys PrepareL2Normalize(const std::vector> &ops, const size_t iter_ops, Dimensions s); Strategys MakeRecSearchStrategy(const std::shared_ptr &graph,