Browse Source

!3980 [AutoParallel] add GatherV2P strategy analysis for W&D

Merge pull request !3980 from Chong/wd
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
9940c723d5
2 changed files with 120 additions and 15 deletions
  1. +117
    -15
      mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
  2. +3
    -0
      mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h

+ 117
- 15
mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc View File

@@ -176,21 +176,102 @@ Strategys PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
s[axis] = 1;
strategies.push_back(s);

auto pos = ops[iter_ops]->name().find("Info");
auto name = ops[iter_ops]->name().substr(0, pos);
if (name == "GatherV2") {
return strategies;
return strategies;
}

// Builds the strategies of a GatherV2P operator by first choosing how to cut its
// output shape across the available devices and then mapping those cuts back onto
// the operator's inputs.
//
// ops       operator list; ops[iter_ops] is the GatherV2P operator being handled.
// iter_ops  index of that operator in `ops`.
// s         incoming strategy; only its size is read (axis range check) before it is
//           cleared and reused as a scratch buffer.
// Returns the list of per-input strategies pushed into `strategies`.
Strategys PrepareGatherV2P(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s) {
Strategys strategies;

// Build the order in which output dimensions are tried for cutting:
// dimensions 1..rank-1 sorted by decreasing extent, with dimension 0 put in front.
auto output_shape = ops[iter_ops]->outputs_tensor_info()[0].shape();
// NOTE(review): `size() - 1` is unsigned arithmetic; an empty output shape would wrap around here.
Dimensions index(output_shape.size() - 1, 0);
for (size_t i = 0; i < index.size(); i++) {
index[i] = i;
}
// The comparator looks at output_shape[a + 1] because `index` still holds 0-based
// offsets into the dimensions that follow dimension 0.
std::sort(index.begin(), index.end(),
[&output_shape](const int &a, const int &b) { return (output_shape[a + 1] > output_shape[b + 1]); });
// Shift the offsets to real dimension numbers (1..rank-1), then prepend dimension 0.
std::transform(std::begin(index), std::end(index), std::begin(index), [](int x) { return x + 1; });
index.insert(index.begin(), 0);

// NOTE(review): the next three lines look like removed-side lines of the diff (the old
// tail of PrepareGatherV2) interleaved into this view; the `for` opened here has no
// closing brace of its own in this span. Confirm against the real file.
Dimensions s_indices;
for (size_t i = 0; i < ops[iter_ops]->inputs_tensor_info()[1].shape().size(); i++) {
s_indices.push_back(1);
// `strategie` holds one cut factor per output dimension, all starting at 1.
Dimensions strategie(output_shape.size(), 1);
size_t num_device = g_device_manager->DeviceNum();
// `cut` is the product of all factors chosen so far.
size_t cut = 1;
for (size_t i = 0; i < index.size(); i++) {
// Halve the current dimension while it stays even and devices remain to be used.
while (output_shape[index[i]] % 2 == 0 && output_shape[index[i]] > 0 && cut < num_device) {
output_shape[index[i]] /= 2;
cut *= 2;
strategie[index[i]] *= 2;
}
// All devices are used; no further dimension needs to be cut.
if (cut == num_device) {
break;
}
}

// Read the gather axis from the third input value and normalise a negative axis
// against the rank of the first input.
auto axis_input = GetValue<int>(ops[iter_ops]->input_value().at(2));
if (axis_input < 0) {
axis_input += SizeToInt(ops[iter_ops]->inputs_tensor_info()[0].shape().size());
}
int32_t axis = axis_input;
if (axis >= SizeToInt(s.size())) {
MS_LOG(EXCEPTION) << "Failure: GatherV2' axis out of range.";
}
if (axis == 0) {
// First input: the gathered dimension stays uncut (1); its remaining dimensions take
// the output cuts located after the dimensions contributed by the second input.
s.clear();
s.push_back(1);
for (size_t i = 1; i < ops[iter_ops]->inputs_tensor_info()[0].shape().size(); i++) {
s.push_back(strategie[ops[iter_ops]->inputs_tensor_info()[1].shape().size() - 1 + i]);
}
strategies.push_back(s);
// Second input: takes the leading output cuts, one per dimension of that input.
s.clear();
for (size_t i = 0; i < ops[iter_ops]->inputs_tensor_info()[1].shape().size(); i++) {
s.push_back(strategie[i]);
}
strategies.push_back(s);
} else if (axis == 1) {
// First input: dimension 0 takes the cut of output dimension 0, the gathered
// dimension stays uncut (1). Exactly two entries are written here.
s.clear();
s.push_back(strategie[0]);
s.push_back(1);
strategies.push_back(s);
// Second input: takes the output cuts starting at offset (rank of first input - 1).
s.clear();
for (size_t i = 0; i < ops[iter_ops]->inputs_tensor_info()[1].shape().size(); i++) {
s.push_back(strategie[ops[iter_ops]->inputs_tensor_info()[0].shape().size() - 1 + i]);
}
strategies.push_back(s);
} else {
MS_LOG(EXCEPTION) << "Failure: GatherV2's axis is neither 0 nor 1.";
}
// NOTE(review): this push of `s_indices` also looks like a removed-side diff line
// belonging to the old PrepareGatherV2 tail. Confirm against the real file.
strategies.push_back(s_indices);

return strategies;
}

// Computes how the first output of a GatherV2P operator is cut across devices, so the
// result can be handed to the operator that consumes that output.
//
// The search mirrors the one PrepareGatherV2P runs on the output shape: dimension 0 is
// tried first, then the remaining dimensions in order of decreasing extent. Each
// dimension is halved while it stays even and the total number of cuts is still below
// the device count. The sort call and comparator are kept identical to
// PrepareGatherV2P on purpose, so both functions pick the same cuts for the same
// operator.
//
// ops                operator list of the graph.
// incoming_op_index  index in `ops` of the GatherV2P operator.
// Returns one cut factor per output dimension (1 or a power of two), or an empty
// Dimensions when the output has no dimensions.
Dimensions PrepareGatherV2POutputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                          const size_t incoming_op_index) {
  auto output_shape = ops[incoming_op_index]->outputs_tensor_info()[0].shape();
  // Nothing to cut on an output without dimensions. Returning here also keeps
  // `size() - 1` below from wrapping around as an unsigned value and requesting
  // an enormous vector.
  if (output_shape.empty()) {
    return Dimensions();
  }

  // Offsets 0..rank-2 stand for dimensions 1..rank-1 while sorting.
  Dimensions index(output_shape.size() - 1, 0);
  for (size_t i = 0; i < index.size(); i++) {
    index[i] = i;
  }
  // Largest trailing dimension first; `+ 1` converts an offset to its dimension number.
  std::sort(index.begin(), index.end(),
            [&output_shape](const int &a, const int &b) { return (output_shape[a + 1] > output_shape[b + 1]); });
  // Turn the offsets into real dimension numbers, then put dimension 0 in front.
  std::transform(std::begin(index), std::end(index), std::begin(index), [](int x) { return x + 1; });
  index.insert(index.begin(), 0);

  // One cut factor per output dimension, all starting uncut.
  Dimensions strategie(output_shape.size(), 1);
  size_t num_device = g_device_manager->DeviceNum();
  // Product of all factors chosen so far.
  size_t cut = 1;
  for (size_t i = 0; i < index.size(); i++) {
    // Halve this dimension while it is still even and devices remain to be used.
    while (output_shape[index[i]] % 2 == 0 && output_shape[index[i]] > 0 && cut < num_device) {
      output_shape[index[i]] /= 2;
      cut *= 2;
      strategie[index[i]] *= 2;
    }
    // Every device is used; the remaining dimensions stay uncut.
    if (cut == num_device) {
      break;
    }
  }

  return strategie;
}

Strategys PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
Dimensions s) {
int32_t axis = 0;
@@ -401,10 +482,20 @@ Dimensions CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> &grap
Dimensions PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t incoming_op_index) {
Dimensions s;
if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == GATHERV2 ||
ops[incoming_op_index]->type() == TRANSPOSE) {
if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == TRANSPOSE) {
return s;
}
if (ops[incoming_op_index]->type() == GATHERV2) {
auto pos = ops[incoming_op_index]->name().find("Info");
auto name = ops[incoming_op_index]->name().substr(0, pos);
if (name == "GatherV2") {
return s;
} else if (name == "GatherV2P") {
return PrepareGatherV2POutputStrategy(ops, incoming_op_index);
} else {
MS_LOG(EXCEPTION) << "Failure: Unknown type of GatherV2." << std::endl;
}
}
auto strategy = ops[incoming_op_index]->selected_strategy();
if (strategy->GetInputNumber() == 0) {
return s;
@@ -495,10 +586,13 @@ Dimensions GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, con
if (input_value.back()->isa<ValueTuple>()) {
auto attr_axis = GetValue<std::vector<int>>(input_value.back());
if (attr_axis.empty()) {
MS_LOG(EXCEPTION) << "Failure: This output is a 0-D tensor." << std::endl;
}
for (auto &axis : attr_axis) {
axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis);
for (size_t i = 0; i < input_dim; i++) {
dim_list.push_back(SizeToInt(i));
}
} else {
for (auto &axis : attr_axis) {
axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis);
}
}
} else if (input_value.back()->isa<Int32Imm>()) {
int axis = GetValue<int>(input_value.back());
@@ -625,7 +719,15 @@ Strategys GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<Opera
return PrepareBiasAdd(s_ptr);
}
if (ops[iter_ops]->type() == GATHERV2) {
return PrepareGatherV2(ops, iter_ops, basic_stra);
auto pos = ops[iter_ops]->name().find("Info");
auto name = ops[iter_ops]->name().substr(0, pos);
if (name == "GatherV2") {
return PrepareGatherV2(ops, iter_ops, basic_stra);
} else if (name == "GatherV2P") {
return PrepareGatherV2P(ops, iter_ops, basic_stra);
} else {
MS_LOG(EXCEPTION) << "Failure: Unknown type of GatherV2." << std::endl;
}
}
if (ops[iter_ops]->type() == L2_NORMALIZE) {
return PrepareL2Normalize(ops, iter_ops, basic_stra);


+ 3
- 0
mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h View File

@@ -37,6 +37,9 @@ Strategys PrepareBiasAdd(const std::shared_ptr<Dimensions> &s);
Strategys PrepareOneHot(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops);
Strategys PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s);
// Builds the per-input strategies of the GatherV2P operator at ops[iter_ops] from the
// cuts chosen on its output shape; only gather axis 0 or 1 is handled.
Strategys PrepareGatherV2P(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s);
// Returns the per-dimension cuts chosen for the first output of the GatherV2P operator
// at ops[incoming_op_index].
Dimensions PrepareGatherV2POutputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t incoming_op_index);
Strategys PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
Dimensions s);
Strategys MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,


Loading…
Cancel
Save