diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc index 6b5bb97208..8c99df8345 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc @@ -78,8 +78,8 @@ std::vector> PrepareVirtualDataset(const std::vector> PrepareBiasAdd(const std::vector> &ops, - const size_t iter_ops, std::vector s) { +std::vector> PrepareScalarInputOperator(const std::vector> &ops, + const size_t iter_ops, std::vector s) { std::vector> strategies; auto dev_num = g_device_manager->DeviceNum(); @@ -190,12 +190,16 @@ std::vector> MakeDataParallelStrategy(const std::vector s; size_t input_size = origin_strategy->GetInputDim()[iter_op_inputs].size(); for (size_t dim = 0; dim < input_size; dim++) { - if (dim == 0 && input_size == 4) { - size_t max_device_num = g_device_manager->DeviceNum(); - size_t target_tensor_batch = ops[iter_ops]->outputs_tensor_info()[0].shape()[0]; - s.push_back(std::min(max_device_num, target_tensor_batch)); + if (input_size == 1 || input_size == 2 || input_size == 4) { + if (dim == 0) { + size_t max_device_num = g_device_manager->DeviceNum(); + size_t target_tensor_batch = ops[iter_ops]->outputs_tensor_info()[0].shape()[0]; + s.push_back(std::min(max_device_num, target_tensor_batch)); + } else { + s.push_back(1); + } } else { - s.push_back(1); + MS_LOG(ERROR) << "Tensor's shape is unknown."; } } @@ -239,6 +243,8 @@ void GeneratePartitionedOperatorStrategy(const std::shared_ptr graph, std::vector> strategies; size_t iter_graph = index_list->at(iter_ops); if (iter_graph == SIZE_MAX) { + StrategyPtr sp = std::make_shared(0, strategies); + ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost()); continue; } strategies = PrepareStrategy(graph, ops, iter_graph, iter_ops); @@ -389,7 +395,7 @@ std::vector ModifyStrategyIfReduceIncoming(const std::vector s_Reduce; std::vector axis_list; for (size_t i = 0; i < s.size(); i++) { - axis_list.push_back(i + 1); + axis_list.push_back(i); } auto dim_list = GetDimList(ops, incoming_op_index); for (auto axis : dim_list) { @@ -400,7 +406,7 @@ std::vector ModifyStrategyIfReduceIncoming(const std::vector CopyIncomingOperatorInputStrategy(const std::vectortype() == REDUCE_MIN || ops[incoming_op_index]->type() == REDUCE_MEAN) { s = ModifyStrategyIfReduceIncoming(ops, incoming_op_index, s); } - } else { - no_stra_op_list->push_back(iter_ops); } return s; } @@ -428,12 +432,18 @@ std::vector> GenerateStrategiesFromStrategy(const std::vect const size_t iter_ops, std::vector s) { std::vector s_empty = {}; std::vector> stra; + if (s.size() == 0) { + for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size(); + iter_op_inputs++) { + stra.push_back(s); + } return stra; } + MS_EXCEPTION_IF_NULL(ops[iter_ops]); - if (ops[iter_ops]->type() == BIAS_ADD) { - return PrepareBiasAdd(ops, iter_ops, s); + if (ops[iter_ops]->type() == BIAS_ADD || ops[iter_ops]->type() == PRELU) { + return PrepareScalarInputOperator(ops, iter_ops, s); } if (ops[iter_ops]->type() == ONEHOT) { return PrepareOneHot(s); @@ -504,10 +514,14 @@ void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr grap } else { s = CopyIncomingOperatorInputStrategy(ops, incoming_op_index, iter_ops, no_stra_op_list); } - } else { + } + + if (s.size() == 0) { no_stra_op_list->push_back(iter_ops); + } else { + stra = GenerateStrategiesFromStrategy(ops, iter_ops, s); } - stra = GenerateStrategiesFromStrategy(ops, iter_ops, s); + StrategyPtr sp = std::make_shared(0, stra); ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost()); } @@ -541,7 +555,7 @@ std::vector ModifyStrategyIfReduceOutgoing(const std::vector> PrepareMatMul(const std::shared_ptr &gr const size_t iter_graph, const size_t iter_ops); std::vector> PrepareVirtualDataset(const std::vector> &ops, const size_t iter_ops); -std::vector> PrepareBiasAdd(const std::vector> &ops, - const size_t iter_ops, std::vector s); +std::vector> PrepareScalarInputOperator(const std::vector> &ops, + const size_t iter_ops, std::vector s); std::vector> PrepareOneHot(std::vector s); std::vector> MakeRecSearchStrategy(const std::shared_ptr &graph, const std::vector> &ops, diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h index f3b0fbe247..e6398b9556 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h @@ -55,6 +55,7 @@ const std::map DictOpType{ {"HSigmoid", OperatorType::kRecReLU}, {GELU, OperatorType::kRecReLU}, {TANH, OperatorType::kRecReLU}, + {PRELU, OperatorType::kRecReLU}, {TENSOR_ADD, OperatorType::kRecElmWiseOp}, {SUB, OperatorType::kRecElmWiseOp},