Browse Source

add new op pack

tags/v1.1.0
sheng 5 years ago
parent
commit
40a8064968
2 changed files with 8 additions and 4 deletions
  1. +6
    -3
      mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc
  2. +2
    -1
      mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h

+ 6
- 3
mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc View File

@@ -82,12 +82,15 @@ Graph::NodeType MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>>
OperatorRec CompleteOperatorInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
Graph::NodeType NewTensor) {
if (ops[iter_ops]->inputs_tensor_info().size() > MAX_INPUT_NUM) {
size_t input_tensor_size = ops[iter_ops]->inputs_tensor_info().size();
if (ops[iter_ops]->type() == PACK) {
input_tensor_size = 1;
}
if (input_tensor_size > MAX_INPUT_NUM) {
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " input tensor num exceeds limit.";
}
for (size_t iter_input_tensors = 0; iter_input_tensors < ops[iter_ops]->inputs_tensor_info().size();
iter_input_tensors++) {
for (size_t iter_input_tensors = 0; iter_input_tensors < input_tensor_size; iter_input_tensors++) {
if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 4) {
NewTensor.apply.arguments[iter_input_tensors] =
MakeTensor(ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0],


+ 2
- 1
mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h View File

@@ -144,7 +144,8 @@ const std::map<std::string, OperatorType> DictOpType{
{ASSIGN_ADD, OperatorType::kRecElmWiseOp},
{ASSIGN_SUB, OperatorType::kRecElmWiseOp},
{"AssignAdd", OperatorType::kRecElmWiseOp},
{DROPOUT_DO_MASK, OperatorType::kRecElmWiseOp}};
{DROPOUT_DO_MASK, OperatorType::kRecElmWiseOp},
{PACK, OperatorType::kRecElmWiseOp}};

const TensorParam MakeTensor(int64_t n, int64_t c, int64_t h, int64_t w);



Loading…
Cancel
Save