
!7091 [AutoParallel] add support for FusedBatchNormEx

Merge pull request !7091 from Chong/FusedBatchNormEx
tags/v1.1.0
Committed by mindspore-ci-bot (Gitee)
parent commit d6032dfbbb
1 changed file with 9 additions and 2 deletions
mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc  (+9, -2)

@@ -368,14 +368,19 @@ Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
     for (size_t dim = 0; dim < input_size; dim++) {
       if (input_size == 1 || input_size == 2 || input_size == 4) {
         if (dim == 0) {
-          s.push_back(std::min(max_device_num, target_tensor_batch));
+          // Currently the GPU version does not support partitioning 'FusedBatchNormEx' in its param tensors.
+          if (ops[iter_ops]->type() == "FusedBatchNormEx" && iter_op_inputs != 0) {
+            s.push_back(1);
+          } else {
+            s.push_back(std::min(max_device_num, target_tensor_batch));
+          }
         } else {
           s.push_back(1);
         }
       } else if (input_size == 0) {
         s = {};
       } else {
-        MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's shape is unknown.";
+        MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor shape " << input_size << " is unexpected.";
       }
     }
     strategies.push_back(s);
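
To make the effect of this hunk concrete, below is a minimal standalone sketch of the per-input strategies the modified loop now produces for a FusedBatchNormEx operator. The 8-device setup, the batch size, the input ranks, and the helper name MakeInputStrategy are illustrative assumptions, not code from the repository; only the dim-0 / param-tensor rule mirrors the diff above.

// Hypothetical, self-contained illustration of the partitioning rule added above.
// Assumption: 8 devices, batch size 32, and FusedBatchNormEx inputs
//   input 0: data tensor [N, C, H, W]; inputs 1..4: param tensors (scale, offset, mean, variance), each [C].
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

using Dimensions = std::vector<int64_t>;

// Mirrors the rule from the diff: only the batch dimension of the data tensor
// (input 0) is split across devices; every dimension of a param tensor stays whole.
Dimensions MakeInputStrategy(const std::string &op_type, size_t input_index,
                             size_t input_rank, int64_t max_device_num,
                             int64_t target_tensor_batch) {
  Dimensions s;
  for (size_t dim = 0; dim < input_rank; ++dim) {
    if (dim == 0 && !(op_type == "FusedBatchNormEx" && input_index != 0)) {
      s.push_back(std::min(max_device_num, target_tensor_batch));
    } else {
      s.push_back(1);
    }
  }
  return s;
}

int main() {
  const int64_t devices = 8, batch = 32;
  // Data tensor [N, C, H, W] -> [8, 1, 1, 1]; each param tensor [C] -> [1].
  std::vector<size_t> ranks = {4, 1, 1, 1, 1};
  for (size_t i = 0; i < ranks.size(); ++i) {
    Dimensions s = MakeInputStrategy("FusedBatchNormEx", i, ranks[i], devices, batch);
    std::cout << "input " << i << ": [";
    for (size_t d = 0; d < s.size(); ++d) std::cout << s[d] << (d + 1 < s.size() ? ", " : "");
    std::cout << "]\n";
  }
  return 0;
}
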
@@ -416,6 +421,8 @@ Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector
     return PrepareMatMul(graph, ops, iter_graph, iter_ops);
   } else if (type == ONEHOT) {
     return PrepareOneHot(graph, ops, iter_graph, iter_ops);
+  } else if (type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
+    return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
   } else {
     return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
   }
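
The second hunk only extends the operator-type dispatch: operators with hand-written preparers keep their explicit branches, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS now falls back to the plain data-parallel path, and everything else still goes to the recursive strategy search. The sketch below models that routing with a lookup table as a rough picture of the design choice; the operator names and returned labels are illustrative stand-ins, and the real PrepareStrategy returns Strategys objects, not strings.

// Hypothetical sketch of the dispatch pattern the hunk above extends.
#include <functional>
#include <iostream>
#include <map>
#include <string>

int main() {
  using Handler = std::function<std::string()>;
  // Explicitly routed operator types; anything not listed falls back to the recursive search.
  std::map<std::string, Handler> dispatch = {
      {"MatMul", [] { return "PrepareMatMul"; }},
      {"OneHot", [] { return "PrepareOneHot"; }},
      {"SparseSoftmaxCrossEntropyWithLogits", [] { return "MakeDataParallelStrategy"; }},
  };
  for (const std::string &type :
       {"MatMul", "SparseSoftmaxCrossEntropyWithLogits", "ReLU"}) {
    auto it = dispatch.find(type);
    std::cout << type << " -> "
              << (it != dispatch.end() ? it->second() : std::string("MakeRecSearchStrategy"))
              << "\n";
  }
  return 0;
}
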

