|
|
|
@@ -368,14 +368,19 @@ Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph, |
|
|
|
for (size_t dim = 0; dim < input_size; dim++) { |
|
|
|
if (input_size == 1 || input_size == 2 || input_size == 4) { |
|
|
|
if (dim == 0) { |
|
|
|
s.push_back(std::min(max_device_num, target_tensor_batch)); |
|
|
|
// Currently GPU version does not support partitioning ‘FusedBatchNormEx’ in its param tensors. |
|
|
|
if (ops[iter_ops]->type() == "FusedBatchNormEx" && iter_op_inputs != 0) { |
|
|
|
s.push_back(1); |
|
|
|
} else { |
|
|
|
s.push_back(std::min(max_device_num, target_tensor_batch)); |
|
|
|
} |
|
|
|
} else { |
|
|
|
s.push_back(1); |
|
|
|
} |
|
|
|
} else if (input_size == 0) { |
|
|
|
s = {}; |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's shape is unknown."; |
|
|
|
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor shape " << input_size << " is unexpected."; |
|
|
|
} |
|
|
|
} |
|
|
|
strategies.push_back(s); |
|
|
|
@@ -416,6 +421,8 @@ Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector |
|
|
|
return PrepareMatMul(graph, ops, iter_graph, iter_ops); |
|
|
|
} else if (type == ONEHOT) { |
|
|
|
return PrepareOneHot(graph, ops, iter_graph, iter_ops); |
|
|
|
} else if (type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) { |
|
|
|
return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops); |
|
|
|
} else { |
|
|
|
return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops); |
|
|
|
} |
|
|
|
|