
Support Dropout strategy; Add parser's input/output tensor num check

tags/v1.1.0
sheng committed 5 years ago · commit 83e627dd5e
2 files changed, 9 insertions(+), 1 deletion(-)
  1. +1 -1  mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
  2. +8 -0  mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc

+1 -1  mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc

@@ -416,7 +416,7 @@ Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector
     return PrepareMatMul(graph, ops, iter_graph, iter_ops);
   } else if (type == ONEHOT) {
     return PrepareOneHot(graph, ops, iter_graph, iter_ops);
-  } else if ((type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) || (type == "_VirtualDataset")) {
+  } else if ((type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) || (type == "_VirtualDataset") || (type == "Dropout")) {
     return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
   } else {
     return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);

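The hunk above extends the type-string dispatch in PrepareStrategy so that Dropout is treated like SparseSoftmaxCrossEntropyWithLogits and _VirtualDataset: it receives a plain data-parallel strategy instead of going through the recursive strategy search. Below is a minimal, self-contained sketch of that dispatch pattern; the Strategy type and the helper bodies are placeholders, not the actual MindSpore signatures.

// Sketch only: operators are routed by their type string, and "Dropout"
// now joins the branch that returns a data-parallel strategy.
#include <iostream>
#include <string>
#include <vector>

using Strategy = std::vector<int>;  // per-dimension split factors (placeholder)

Strategy MakeDataParallelStrategy() { return {8, 1, 1, 1}; }  // split the batch dim only
Strategy MakeRecSearchStrategy() { return {2, 2, 1, 1}; }     // stand-in for the searched result

Strategy PrepareStrategy(const std::string &type) {
  // Ops in this branch skip the recursive search and are pinned to a
  // batch-dimension (data-parallel) split.
  if (type == "SparseSoftmaxCrossEntropyWithLogits" || type == "_VirtualDataset" || type == "Dropout") {
    return MakeDataParallelStrategy();
  }
  return MakeRecSearchStrategy();
}

int main() {
  for (const std::string op : {"Dropout", "MatMul"}) {
    Strategy s = PrepareStrategy(op);
    std::cout << op << " ->";
    for (int d : s) std::cout << ' ' << d;
    std::cout << '\n';
  }
  return 0;
}

Routing an operator through MakeDataParallelStrategy appears to pin it to a batch-dimension split rather than letting the search propose arbitrary splits, which is why only a short list of op types takes this branch.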

+8 -0  mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc

@@ -53,6 +53,10 @@ Graph::NodeType MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>>
     NewOp.apply.op_type = DictOpType.at(op_type);
   }
+  if (ops[iter_ops]->outputs_tensor_info().size() == 0) {
+    MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " output tensor info is empty.";
+  }
   if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 4) {
     NewOp.tensor_parm = MakeTensor(
       ops[iter_ops]->outputs_tensor_info()[0].shape()[0], ops[iter_ops]->outputs_tensor_info()[0].shape()[1],

@@ -74,6 +78,10 @@ Graph::NodeType MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>>
 OperatorRec CompleteOperatorInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
                                    Graph::NodeType NewTensor) {
+  if (ops[iter_ops]->inputs_tensor_info().size() > MAX_INPUT_NUM) {
+    MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " input tensor num exceeds limit.";
+  }
   for (size_t iter_input_tensors = 0; iter_input_tensors < ops[iter_ops]->inputs_tensor_info().size();
        iter_input_tensors++) {
     if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 4) {

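The checks added in the two hunks above follow the same defensive pattern: validate the tensor-info containers before the parser indexes into outputs_tensor_info()[0] or iterates over more inputs than MAX_INPUT_NUM allows, and fail with MS_LOG(EXCEPTION) naming the offending operator. Below is a standalone sketch of that pattern; the structs, the exception type, and the MAX_INPUT_NUM value are hypothetical stand-ins for the MindSpore types.

// Sketch only: reject an operator whose output tensor info is empty or whose
// input tensor count exceeds a fixed limit, before either container is read.
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

constexpr size_t MAX_INPUT_NUM = 5;  // illustrative cap on parsed input tensors

struct TensorInfo { std::vector<int64_t> shape; };
struct OperatorInfo {
  std::string name;
  std::vector<TensorInfo> inputs;
  std::vector<TensorInfo> outputs;
};

void CheckOperatorTensors(const OperatorInfo &op) {
  // Indexing outputs[0] without this check would be undefined behaviour for
  // an operator that reports no output tensor info.
  if (op.outputs.empty()) {
    throw std::runtime_error(op.name + " output tensor info is empty.");
  }
  if (op.inputs.size() > MAX_INPUT_NUM) {
    throw std::runtime_error(op.name + " input tensor num exceeds limit.");
  }
}

int main() {
  OperatorInfo ok{"Dropout", {{{32, 128}}}, {{{32, 128}}}};
  CheckOperatorTensors(ok);  // passes silently

  OperatorInfo bad{"BrokenOp", {}, {}};
  try {
    CheckOperatorTensors(bad);
  } catch (const std::runtime_error &e) {
    std::cout << e.what() << '\n';  // prints: BrokenOp output tensor info is empty.
  }
  return 0;
}

Failing early here keeps the out-of-bounds access from surfacing later as a crash deep inside strategy generation, and the operator name in the message points directly at the graph node that produced the bad tensor info.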
