Browse Source

!7947 [AutoParallel] handle 3D tensor in rec

Merge pull request !7947 from Chong/BERT
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
ab0f9f218b
2 changed files with 29 additions and 4 deletions
  1. +28
    -3
      mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
  2. +1
    -1
      mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc

+ 28
- 3
mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc View File

@@ -326,6 +326,14 @@ Strategys MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h));
s.push_back(
static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
} else if (output_size == 3) {
// Experimental support for 3D data.
s.push_back(
static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_c));
s.push_back(
static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h));
s.push_back(
static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
} else if (output_size == 2) {
s.push_back(
static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h));
@@ -366,7 +374,8 @@ Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
Dimensions s;
size_t input_size = origin_strategy->GetInputDim()[iter_op_inputs].size();
for (size_t dim = 0; dim < input_size; dim++) {
if (input_size == 1 || input_size == 2 || input_size == 4) {
// Experimental support for 3D data (input_size == 3).
if (input_size >= 1 && input_size <= 4) {
if (dim == 0) {
// Currently GPU version does not support partitioning 'FusedBatchNormEx' in its param tensors.
if (ops[iter_ops]->type() == "FusedBatchNormEx" && iter_op_inputs != 0) {
@@ -385,17 +394,27 @@ Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
}
strategies.push_back(s);
}
// Set default strategy.
graph->nodes[iter_graph].tensor_parm.tensor_str.str_n = 1.0;
graph->nodes[iter_graph].tensor_parm.tensor_str.str_c = 1.0;
graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = 1.0;
graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0;

// Update data parallel strategy.
if (ops[iter_ops]->outputs_tensor_info().size() == 0) {
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " output tensor info is empty.";
}
if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 1) {
graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0 / std::min(max_device_num, target_tensor_batch);
} else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 2) {
graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = 1.0 / std::min(max_device_num, target_tensor_batch);
} else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 3) {
// Experimental support for 3D data.
graph->nodes[iter_graph].tensor_parm.tensor_str.str_c = 1.0 / std::min(max_device_num, target_tensor_batch);
} else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 4) {
graph->nodes[iter_graph].tensor_parm.tensor_str.str_n = 1.0 / std::min(max_device_num, target_tensor_batch);
} else {
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " output tensor shape is unexpected.";
}

return strategies;
@@ -416,7 +435,8 @@ Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector
return PrepareMatMul(graph, ops, iter_graph, iter_ops);
} else if (type == ONEHOT) {
return PrepareOneHot(graph, ops, iter_graph, iter_ops);
} else if ((type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) || (type == "_VirtualDataset") || (type == "Dropout")) {
} else if ((type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) || (type == "_VirtualDataset") ||
(type == "FusedBatchNormEx") || (type == "Dropout")) {
return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
} else {
return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
@@ -468,6 +488,11 @@ Dimensions CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> &grap
} else if (input_stra_dim == 2) {
s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_h);
s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w);
} else if (input_stra_dim == 3) {
// Experimental support for 3D data.
s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_c);
s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_h);
s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w);
} else if (input_stra_dim == 4) {
s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_n);
s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_c);


+ 1
- 1
mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc View File

@@ -48,7 +48,7 @@ Graph::NodeType MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>>
auto idx = DictOpType.find(op_type);
if (idx == DictOpType.end()) {
NewOp.apply.op_type = OperatorType::kRecUnkownType;
MS_LOG(INFO) << "Unknown operator type: " << op_type;
MS_LOG(INFO) << ops[iter_ops]->name() << ": Unknown operator type " << op_type;
} else {
NewOp.apply.op_type = DictOpType.at(op_type);
}


Loading…
Cancel
Save