@@ -368,14 +368,19 @@ Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
     for (size_t dim = 0; dim < input_size; dim++) {
       if (input_size == 1 || input_size == 2 || input_size == 4) {
         if (dim == 0) {
-          s.push_back(std::min(max_device_num, target_tensor_batch));
+          // Currently the GPU version does not support partitioning 'FusedBatchNormEx' in its param tensors.
+          if (ops[iter_ops]->type() == "FusedBatchNormEx" && iter_op_inputs != 0) {
+            s.push_back(1);
+          } else {
+            s.push_back(std::min(max_device_num, target_tensor_batch));
+          }
         } else {
           s.push_back(1);
         }
       } else if (input_size == 0) {
         s = {};
       } else {
-        MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's shape is unknown.";
+        MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor shape " << input_size << " is unexpected.";
       }
     }
     strategies.push_back(s);
@@ -416,6 +421,8 @@ Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector
     return PrepareMatMul(graph, ops, iter_graph, iter_ops);
   } else if (type == ONEHOT) {
     return PrepareOneHot(graph, ops, iter_graph, iter_ops);
+  } else if (type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
+    return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
   } else {
     return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
   }
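
To make the per-dimension rule in the first hunk concrete, here is a minimal standalone sketch of the data-parallel shaping logic, assuming the behavior shown in the patch. It is not MindSpore code: the helper name DataParallelDims, the 8-device / batch-32 numbers, and the simplified handling of input_size are illustrative assumptions.

// Standalone sketch (assumption: not the MindSpore implementation) of the
// per-dimension choice made in MakeDataParallelStrategy after this patch.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Returns one split factor per tensor dimension: the batch dimension (dim 0)
// is cut across devices, every other dimension stays whole, and the param
// tensors of 'FusedBatchNormEx' (every input after the first) are never cut.
std::vector<int> DataParallelDims(const std::string &op_type, size_t input_index,
                                  size_t input_size, int max_device_num,
                                  int target_tensor_batch) {
  std::vector<int> s;
  for (size_t dim = 0; dim < input_size; dim++) {
    if (dim == 0 && !(op_type == "FusedBatchNormEx" && input_index != 0)) {
      s.push_back(std::min(max_device_num, target_tensor_batch));
    } else {
      s.push_back(1);
    }
  }
  return s;
}

int main() {
  // Illustrative numbers: 8 devices, global batch size 32.
  auto activation = DataParallelDims("FusedBatchNormEx", 0, 4, 8, 32);  // 4-D NCHW input
  auto scale = DataParallelDims("FusedBatchNormEx", 1, 1, 8, 32);       // 1-D param tensor
  for (int d : activation) std::cout << d << ' ';  // prints: 8 1 1 1
  std::cout << '\n';
  for (int d : scale) std::cout << d << ' ';       // prints: 1
  std::cout << '\n';
  return 0;
}

The second hunk then routes SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS through this same batch-only data-parallel split rather than the recursive strategy search.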