|
|
@@ -176,21 +176,102 @@ Strategys PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops, |
|
|
s[axis] = 1; |
|
|
s[axis] = 1; |
|
|
strategies.push_back(s); |
|
|
strategies.push_back(s); |
|
|
|
|
|
|
|
|
auto pos = ops[iter_ops]->name().find("Info"); |
|
|
|
|
|
auto name = ops[iter_ops]->name().substr(0, pos); |
|
|
|
|
|
if (name == "GatherV2") { |
|
|
|
|
|
return strategies; |
|
|
|
|
|
|
|
|
return strategies; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Strategys PrepareGatherV2P(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s) { |
|
|
|
|
|
Strategys strategies; |
|
|
|
|
|
|
|
|
|
|
|
auto output_shape = ops[iter_ops]->outputs_tensor_info()[0].shape(); |
|
|
|
|
|
Dimensions index(output_shape.size() - 1, 0); |
|
|
|
|
|
for (size_t i = 0; i < index.size(); i++) { |
|
|
|
|
|
index[i] = i; |
|
|
} |
|
|
} |
|
|
|
|
|
std::sort(index.begin(), index.end(), |
|
|
|
|
|
[&output_shape](const int &a, const int &b) { return (output_shape[a + 1] > output_shape[b + 1]); }); |
|
|
|
|
|
std::transform(std::begin(index), std::end(index), std::begin(index), [](int x) { return x + 1; }); |
|
|
|
|
|
index.insert(index.begin(), 0); |
|
|
|
|
|
|
|
|
Dimensions s_indices; |
|
|
|
|
|
for (size_t i = 0; i < ops[iter_ops]->inputs_tensor_info()[1].shape().size(); i++) { |
|
|
|
|
|
s_indices.push_back(1); |
|
|
|
|
|
|
|
|
Dimensions strategie(output_shape.size(), 1); |
|
|
|
|
|
size_t num_device = g_device_manager->DeviceNum(); |
|
|
|
|
|
size_t cut = 1; |
|
|
|
|
|
for (size_t i = 0; i < index.size(); i++) { |
|
|
|
|
|
while (output_shape[index[i]] % 2 == 0 && output_shape[index[i]] > 0 && cut < num_device) { |
|
|
|
|
|
output_shape[index[i]] /= 2; |
|
|
|
|
|
cut *= 2; |
|
|
|
|
|
strategie[index[i]] *= 2; |
|
|
|
|
|
} |
|
|
|
|
|
if (cut == num_device) { |
|
|
|
|
|
break; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
auto axis_input = GetValue<int>(ops[iter_ops]->input_value().at(2)); |
|
|
|
|
|
if (axis_input < 0) { |
|
|
|
|
|
axis_input += SizeToInt(ops[iter_ops]->inputs_tensor_info()[0].shape().size()); |
|
|
|
|
|
} |
|
|
|
|
|
int32_t axis = axis_input; |
|
|
|
|
|
if (axis >= SizeToInt(s.size())) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failure: GatherV2' axis out of range."; |
|
|
|
|
|
} |
|
|
|
|
|
if (axis == 0) { |
|
|
|
|
|
s.clear(); |
|
|
|
|
|
s.push_back(1); |
|
|
|
|
|
for (size_t i = 1; i < ops[iter_ops]->inputs_tensor_info()[0].shape().size(); i++) { |
|
|
|
|
|
s.push_back(strategie[ops[iter_ops]->inputs_tensor_info()[1].shape().size() - 1 + i]); |
|
|
|
|
|
} |
|
|
|
|
|
strategies.push_back(s); |
|
|
|
|
|
s.clear(); |
|
|
|
|
|
for (size_t i = 0; i < ops[iter_ops]->inputs_tensor_info()[1].shape().size(); i++) { |
|
|
|
|
|
s.push_back(strategie[i]); |
|
|
|
|
|
} |
|
|
|
|
|
strategies.push_back(s); |
|
|
|
|
|
} else if (axis == 1) { |
|
|
|
|
|
s.clear(); |
|
|
|
|
|
s.push_back(strategie[0]); |
|
|
|
|
|
s.push_back(1); |
|
|
|
|
|
strategies.push_back(s); |
|
|
|
|
|
s.clear(); |
|
|
|
|
|
for (size_t i = 0; i < ops[iter_ops]->inputs_tensor_info()[1].shape().size(); i++) { |
|
|
|
|
|
s.push_back(strategie[ops[iter_ops]->inputs_tensor_info()[0].shape().size() - 1 + i]); |
|
|
|
|
|
} |
|
|
|
|
|
strategies.push_back(s); |
|
|
|
|
|
} else { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failure: GatherV2's axis is neither 0 nor 1."; |
|
|
} |
|
|
} |
|
|
strategies.push_back(s_indices); |
|
|
|
|
|
|
|
|
|
|
|
return strategies; |
|
|
return strategies; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Dimensions PrepareGatherV2POutputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, |
|
|
|
|
|
const size_t incoming_op_index) { |
|
|
|
|
|
auto output_shape = ops[incoming_op_index]->outputs_tensor_info()[0].shape(); |
|
|
|
|
|
Dimensions index(output_shape.size() - 1, 0); |
|
|
|
|
|
for (size_t i = 0; i < index.size(); i++) { |
|
|
|
|
|
index[i] = i; |
|
|
|
|
|
} |
|
|
|
|
|
std::sort(index.begin(), index.end(), |
|
|
|
|
|
[&output_shape](const int &a, const int &b) { return (output_shape[a + 1] > output_shape[b + 1]); }); |
|
|
|
|
|
std::transform(std::begin(index), std::end(index), std::begin(index), [](int x) { return x + 1; }); |
|
|
|
|
|
index.insert(index.begin(), 0); |
|
|
|
|
|
|
|
|
|
|
|
Dimensions strategie(output_shape.size(), 1); |
|
|
|
|
|
size_t num_device = g_device_manager->DeviceNum(); |
|
|
|
|
|
size_t cut = 1; |
|
|
|
|
|
for (size_t i = 0; i < index.size(); i++) { |
|
|
|
|
|
while (output_shape[index[i]] % 2 == 0 && output_shape[index[i]] > 0 && cut < num_device) { |
|
|
|
|
|
output_shape[index[i]] /= 2; |
|
|
|
|
|
cut *= 2; |
|
|
|
|
|
strategie[index[i]] *= 2; |
|
|
|
|
|
} |
|
|
|
|
|
if (cut == num_device) { |
|
|
|
|
|
break; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
return strategie; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
Strategys PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, |
|
|
Strategys PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, |
|
|
Dimensions s) { |
|
|
Dimensions s) { |
|
|
int32_t axis = 0; |
|
|
int32_t axis = 0; |
|
|
@@ -401,10 +482,20 @@ Dimensions CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> &grap |
|
|
Dimensions PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, |
|
|
Dimensions PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, |
|
|
const size_t incoming_op_index) { |
|
|
const size_t incoming_op_index) { |
|
|
Dimensions s; |
|
|
Dimensions s; |
|
|
if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == GATHERV2 || |
|
|
|
|
|
ops[incoming_op_index]->type() == TRANSPOSE) { |
|
|
|
|
|
|
|
|
if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == TRANSPOSE) { |
|
|
return s; |
|
|
return s; |
|
|
} |
|
|
} |
|
|
|
|
|
if (ops[incoming_op_index]->type() == GATHERV2) { |
|
|
|
|
|
auto pos = ops[incoming_op_index]->name().find("Info"); |
|
|
|
|
|
auto name = ops[incoming_op_index]->name().substr(0, pos); |
|
|
|
|
|
if (name == "GatherV2") { |
|
|
|
|
|
return s; |
|
|
|
|
|
} else if (name == "GatherV2P") { |
|
|
|
|
|
return PrepareGatherV2POutputStrategy(ops, incoming_op_index); |
|
|
|
|
|
} else { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failure: Unknown type of GatherV2." << std::endl; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
auto strategy = ops[incoming_op_index]->selected_strategy(); |
|
|
auto strategy = ops[incoming_op_index]->selected_strategy(); |
|
|
if (strategy->GetInputNumber() == 0) { |
|
|
if (strategy->GetInputNumber() == 0) { |
|
|
return s; |
|
|
return s; |
|
|
@@ -495,10 +586,13 @@ Dimensions GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, con |
|
|
if (input_value.back()->isa<ValueTuple>()) { |
|
|
if (input_value.back()->isa<ValueTuple>()) { |
|
|
auto attr_axis = GetValue<std::vector<int>>(input_value.back()); |
|
|
auto attr_axis = GetValue<std::vector<int>>(input_value.back()); |
|
|
if (attr_axis.empty()) { |
|
|
if (attr_axis.empty()) { |
|
|
MS_LOG(EXCEPTION) << "Failure: This output is a 0-D tensor." << std::endl; |
|
|
|
|
|
} |
|
|
|
|
|
for (auto &axis : attr_axis) { |
|
|
|
|
|
axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); |
|
|
|
|
|
|
|
|
for (size_t i = 0; i < input_dim; i++) { |
|
|
|
|
|
dim_list.push_back(SizeToInt(i)); |
|
|
|
|
|
} |
|
|
|
|
|
} else { |
|
|
|
|
|
for (auto &axis : attr_axis) { |
|
|
|
|
|
axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
} else if (input_value.back()->isa<Int32Imm>()) { |
|
|
} else if (input_value.back()->isa<Int32Imm>()) { |
|
|
int axis = GetValue<int>(input_value.back()); |
|
|
int axis = GetValue<int>(input_value.back()); |
|
|
@@ -625,7 +719,15 @@ Strategys GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<Opera |
|
|
return PrepareBiasAdd(s_ptr); |
|
|
return PrepareBiasAdd(s_ptr); |
|
|
} |
|
|
} |
|
|
if (ops[iter_ops]->type() == GATHERV2) { |
|
|
if (ops[iter_ops]->type() == GATHERV2) { |
|
|
return PrepareGatherV2(ops, iter_ops, basic_stra); |
|
|
|
|
|
|
|
|
auto pos = ops[iter_ops]->name().find("Info"); |
|
|
|
|
|
auto name = ops[iter_ops]->name().substr(0, pos); |
|
|
|
|
|
if (name == "GatherV2") { |
|
|
|
|
|
return PrepareGatherV2(ops, iter_ops, basic_stra); |
|
|
|
|
|
} else if (name == "GatherV2P") { |
|
|
|
|
|
return PrepareGatherV2P(ops, iter_ops, basic_stra); |
|
|
|
|
|
} else { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failure: Unknown type of GatherV2." << std::endl; |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
if (ops[iter_ops]->type() == L2_NORMALIZE) { |
|
|
if (ops[iter_ops]->type() == L2_NORMALIZE) { |
|
|
return PrepareL2Normalize(ops, iter_ops, basic_stra); |
|
|
return PrepareL2Normalize(ops, iter_ops, basic_stra); |
|
|
|