
fix strategy generation bugs and support a new op (PReLU)

tags/v0.5.0-beta
hongxing 5 years ago
commit 26d05be808
3 changed files with 33 additions and 18 deletions
  1. mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc (+30 -16)
  2. mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h (+2 -2)
  3. mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h (+1 -0)

mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc (+30 -16)

@@ -78,8 +78,8 @@ std::vector<std::vector<int32_t>> PrepareVirtualDataset(const std::vector<std::s
   return strategies;
 }
 
-std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                 const size_t iter_ops, std::vector<int32_t> s) {
+std::vector<std::vector<int32_t>> PrepareScalarInputOperator(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                             const size_t iter_ops, std::vector<int32_t> s) {
   std::vector<std::vector<int32_t>> strategies;
 
   auto dev_num = g_device_manager->DeviceNum();
@@ -190,12 +190,16 @@ std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::vector<std
     std::vector<int32_t> s;
     size_t input_size = origin_strategy->GetInputDim()[iter_op_inputs].size();
     for (size_t dim = 0; dim < input_size; dim++) {
-      if (dim == 0 && input_size == 4) {
-        size_t max_device_num = g_device_manager->DeviceNum();
-        size_t target_tensor_batch = ops[iter_ops]->outputs_tensor_info()[0].shape()[0];
-        s.push_back(std::min(max_device_num, target_tensor_batch));
+      if (input_size == 1 || input_size == 2 || input_size == 4) {
+        if (dim == 0) {
+          size_t max_device_num = g_device_manager->DeviceNum();
+          size_t target_tensor_batch = ops[iter_ops]->outputs_tensor_info()[0].shape()[0];
+          s.push_back(std::min(max_device_num, target_tensor_batch));
+        } else {
+          s.push_back(1);
+        }
       } else {
-        s.push_back(1);
+        MS_LOG(ERROR) << "Tensor's shape is unknown.";
       }
     }
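
The reworked loop above now accepts rank-1 and rank-2 tensors in addition to rank-4, shards only the batch dimension, and reports any other rank as an unknown shape. A minimal standalone sketch of that logic (simplified names and plain std types, not the real MindSpore OperatorInfo API):

  #include <algorithm>
  #include <cstdint>
  #include <iostream>
  #include <vector>

  // Build per-dimension shard counts for one input tensor: split dim 0 across
  // at most min(device_num, batch) devices, leave every other dimension whole.
  std::vector<int32_t> MakeDataParallelDims(const std::vector<int32_t> &shape, size_t device_num) {
    std::vector<int32_t> s;
    size_t input_size = shape.size();
    for (size_t dim = 0; dim < input_size; dim++) {
      if (input_size == 1 || input_size == 2 || input_size == 4) {
        if (dim == 0) {
          size_t target_tensor_batch = static_cast<size_t>(shape[0]);
          s.push_back(static_cast<int32_t>(std::min(device_num, target_tensor_batch)));
        } else {
          s.push_back(1);  // non-batch dimensions stay unpartitioned
        }
      } else {
        std::cerr << "Tensor's shape is unknown.\n";  // MS_LOG(ERROR) in the real code
      }
    }
    return s;
  }

  int main() {
    for (int32_t d : MakeDataParallelDims({32, 3, 224, 224}, 8)) std::cout << d << ' ';  // 8 1 1 1
    std::cout << '\n';
  }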


@@ -239,6 +243,8 @@ void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> graph,
     std::vector<std::vector<int32_t>> strategies;
     size_t iter_graph = index_list->at(iter_ops);
     if (iter_graph == SIZE_MAX) {
+      StrategyPtr sp = std::make_shared<Strategy>(0, strategies);
+      ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost());
       continue;
     }
     strategies = PrepareStrategy(graph, ops, iter_graph, iter_ops);
@@ -389,7 +395,7 @@ std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::share
   std::vector<int32_t> s_Reduce;
   std::vector<int32_t> axis_list;
   for (size_t i = 0; i < s.size(); i++) {
-    axis_list.push_back(i + 1);
+    axis_list.push_back(i);
   }
   auto dim_list = GetDimList(ops, incoming_op_index);
   for (auto axis : dim_list) {
@@ -400,7 +406,7 @@ std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::share
     axis_list.erase(it);
   }
   for (size_t i = 0; i < (size_t)axis_list.size(); i++) {
-    s_Reduce.push_back(s[axis_list[i] - 1]);
+    s_Reduce.push_back(s[axis_list[i]]);
   }
   return s_Reduce;
 }
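
The two hunks above fix the same off-by-one: axis_list was built 1-based, so it never matched the 0-based reduced axes coming back from GetDimList, and the wrong strategy entries survived. A standalone sketch of the corrected 0-based bookkeeping (invented helper name; dim_list stands in for GetDimList's result):

  #include <algorithm>
  #include <cstdint>
  #include <iostream>
  #include <vector>

  // Keep the strategy entries of every axis that survives the reduction.
  std::vector<int32_t> DropReducedAxes(const std::vector<int32_t> &s, const std::vector<int32_t> &dim_list) {
    std::vector<int32_t> axis_list;
    for (size_t i = 0; i < s.size(); i++) {
      axis_list.push_back(static_cast<int32_t>(i));  // 0-based, matching both s and dim_list
    }
    for (auto axis : dim_list) {
      auto it = std::find(axis_list.begin(), axis_list.end(), axis);
      if (it != axis_list.end()) axis_list.erase(it);
    }
    std::vector<int32_t> s_reduce;
    for (auto axis : axis_list) {
      s_reduce.push_back(s[axis]);  // direct 0-based indexing, no "- 1" offset
    }
    return s_reduce;
  }

  int main() {
    // A reduce over axis 1 of a rank-3 tensor with strategy {8, 1, 2}:
    for (int32_t d : DropReducedAxes({8, 1, 2}, {1})) std::cout << d << ' ';  // 8 2
    std::cout << '\n';
  }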
@@ -418,8 +424,6 @@ std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::sh
         ops[incoming_op_index]->type() == REDUCE_MIN || ops[incoming_op_index]->type() == REDUCE_MEAN) {
       s = ModifyStrategyIfReduceIncoming(ops, incoming_op_index, s);
     }
-  } else {
-    no_stra_op_list->push_back(iter_ops);
   }
   return s;
 }
@@ -428,12 +432,18 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect
                                                                  const size_t iter_ops, std::vector<int32_t> s) {
   std::vector<int32_t> s_empty = {};
   std::vector<std::vector<int32_t>> stra;
+
   if (s.size() == 0) {
+    for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size();
+         iter_op_inputs++) {
+      stra.push_back(s);
+    }
     return stra;
   }
+
   MS_EXCEPTION_IF_NULL(ops[iter_ops]);
-  if (ops[iter_ops]->type() == BIAS_ADD) {
-    return PrepareBiasAdd(ops, iter_ops, s);
+  if (ops[iter_ops]->type() == BIAS_ADD || ops[iter_ops]->type() == PRELU) {
+    return PrepareScalarInputOperator(ops, iter_ops, s);
   }
   if (ops[iter_ops]->type() == ONEHOT) {
     return PrepareOneHot(s);
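
Two changes meet in this hunk: an empty incoming strategy now yields one empty entry per operator input (so the strategy count still matches the input count downstream), and PReLU is routed through the renamed PrepareScalarInputOperator alongside BiasAdd. A sketch of the empty-strategy branch alone (hypothetical simplified signature; num_inputs stands in for ops[iter_ops]->inputs_tensor_info().size()):

  #include <cstddef>
  #include <cstdint>
  #include <iostream>
  #include <vector>

  std::vector<std::vector<int32_t>> FillEmptyStrategies(size_t num_inputs, const std::vector<int32_t> &s) {
    std::vector<std::vector<int32_t>> stra;
    if (s.empty()) {
      for (size_t i = 0; i < num_inputs; i++) {
        stra.push_back(s);  // one empty entry per input, rather than no entries at all
      }
      return stra;
    }
    // ... the non-empty case dispatches to operator-specific Prepare* helpers ...
    return stra;
  }

  int main() {
    std::cout << FillEmptyStrategies(3, {}).size() << '\n';  // 3
  }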
@@ -504,10 +514,14 @@ void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> grap
     } else {
       s = CopyIncomingOperatorInputStrategy(ops, incoming_op_index, iter_ops, no_stra_op_list);
     }
-  } else {
+  }
+
+  if (s.size() == 0) {
     no_stra_op_list->push_back(iter_ops);
+  } else {
+    stra = GenerateStrategiesFromStrategy(ops, iter_ops, s);
   }
-  stra = GenerateStrategiesFromStrategy(ops, iter_ops, s);
   StrategyPtr sp = std::make_shared<Strategy>(0, stra);
   ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost());
 }
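
The restructured branch defers operators that still have no inherited strategy: they are queued on no_stra_op_list instead of running GenerateStrategiesFromStrategy on an empty vector. A simplified sketch of that control flow (invented names; GenerateFromStrategy stands in for the real helper):

  #include <cstddef>
  #include <cstdint>
  #include <iostream>
  #include <vector>

  using Strategies = std::vector<std::vector<int32_t>>;

  Strategies GenerateFromStrategy(const std::vector<int32_t> &s) { return {s, s}; }

  int main() {
    // Per-operator incoming strategies; the second operator has none yet.
    std::vector<std::vector<int32_t>> inherited = {{8, 1, 1, 1}, {}};
    std::vector<size_t> no_stra_op_list;
    for (size_t iter_ops = 0; iter_ops < inherited.size(); iter_ops++) {
      Strategies stra;
      if (inherited[iter_ops].empty()) {
        no_stra_op_list.push_back(iter_ops);  // defer to a later pass
      } else {
        stra = GenerateFromStrategy(inherited[iter_ops]);
      }
      // SetSelectedStrategyAndCost(...) still runs for every operator.
    }
    std::cout << no_stra_op_list.size() << '\n';  // 1: the second op was deferred
  }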
@@ -541,7 +555,7 @@ std::vector<int32_t> ModifyStrategyIfReduceOutgoing(const std::vector<std::share
   size_t s_index = 0;
   size_t dim_list_index = 0;
   for (size_t i = 0; i < (size_t)(s.size() + dim_list.size()); i++) {
-    if ((i + 1) == (size_t)dim_list[dim_list_index]) {
+    if (i == (size_t)dim_list[dim_list_index]) {
       s_Reduce.push_back(1);
       dim_list_index++;
     } else {


mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h (+2 -2)

@@ -36,8 +36,8 @@ std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &gr
                                                 const size_t iter_graph, const size_t iter_ops);
 std::vector<std::vector<int32_t>> PrepareVirtualDataset(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                         const size_t iter_ops);
-std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                 const size_t iter_ops, std::vector<int32_t> s);
+std::vector<std::vector<int32_t>> PrepareScalarInputOperator(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                             const size_t iter_ops, std::vector<int32_t> s);
 std::vector<std::vector<int32_t>> PrepareOneHot(std::vector<int32_t> s);
 std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
                                                         const std::vector<std::shared_ptr<OperatorInfo>> &ops,


mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h (+1 -0)

@@ -55,6 +55,7 @@ const std::map<std::string, OperatorType> DictOpType{
{"HSigmoid", OperatorType::kRecReLU}, {"HSigmoid", OperatorType::kRecReLU},
{GELU, OperatorType::kRecReLU}, {GELU, OperatorType::kRecReLU},
{TANH, OperatorType::kRecReLU}, {TANH, OperatorType::kRecReLU},
{PRELU, OperatorType::kRecReLU},


{TENSOR_ADD, OperatorType::kRecElmWiseOp}, {TENSOR_ADD, OperatorType::kRecElmWiseOp},
{SUB, OperatorType::kRecElmWiseOp}, {SUB, OperatorType::kRecElmWiseOp},
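
This one-line DictOpType entry is what lets the rec_core strategy search classify PReLU at all: it is mapped to the same kRecReLU category as the other activation ops. A minimal sketch of how such a lookup table behaves (simplified string keys and a local enum; the real map is keyed by op-name constants like PRELU):

  #include <iostream>
  #include <map>
  #include <string>

  enum class OperatorType { kRecReLU, kRecElmWiseOp, kRecUnknownType };

  const std::map<std::string, OperatorType> DictOpType{
      {"Gelu", OperatorType::kRecReLU},
      {"Tanh", OperatorType::kRecReLU},
      {"PReLU", OperatorType::kRecReLU},  // the entry this commit adds
      {"TensorAdd", OperatorType::kRecElmWiseOp},
  };

  OperatorType LookUpOpType(const std::string &name) {
    auto it = DictOpType.find(name);
    return it == DictOpType.end() ? OperatorType::kRecUnknownType : it->second;
  }

  int main() {
    std::cout << (LookUpOpType("PReLU") == OperatorType::kRecReLU) << '\n';  // 1
  }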

