diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc index 50f9a93ad6..8ea7152b9c 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc @@ -82,12 +82,15 @@ Graph::NodeType MakeNewOperator(const std::vector> OperatorRec CompleteOperatorInputs(const std::vector> &ops, const size_t iter_ops, Graph::NodeType NewTensor) { - if (ops[iter_ops]->inputs_tensor_info().size() > MAX_INPUT_NUM) { + size_t input_tensor_size = ops[iter_ops]->inputs_tensor_info().size(); + if (ops[iter_ops]->type() == PACK) { + input_tensor_size = 1; + } + if (input_tensor_size > MAX_INPUT_NUM) { MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " input tensor num exceeds limit."; } - for (size_t iter_input_tensors = 0; iter_input_tensors < ops[iter_ops]->inputs_tensor_info().size(); - iter_input_tensors++) { + for (size_t iter_input_tensors = 0; iter_input_tensors < input_tensor_size; iter_input_tensors++) { if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 4) { NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0], diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h index 4f06299e65..c852fd3517 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h @@ -144,7 +144,8 @@ const std::map DictOpType{ {ASSIGN_ADD, OperatorType::kRecElmWiseOp}, {ASSIGN_SUB, OperatorType::kRecElmWiseOp}, {"AssignAdd", OperatorType::kRecElmWiseOp}, - {DROPOUT_DO_MASK, OperatorType::kRecElmWiseOp}}; + {DROPOUT_DO_MASK, OperatorType::kRecElmWiseOp}, + {PACK, OperatorType::kRecElmWiseOp}}; const TensorParam MakeTensor(int64_t n, int64_t c, int64_t h, int64_t w);