|
|
|
@@ -412,16 +412,11 @@ Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector |
|
|
|
MS_EXCEPTION_IF_NULL(ops[iter_ops]); |
|
|
|
|
|
|
|
auto type = ops[iter_ops]->type(); |
|
|
|
auto idx = DictOpType.find(type); |
|
|
|
if (idx == DictOpType.end()) { |
|
|
|
return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops); |
|
|
|
} |
|
|
|
|
|
|
|
if (type == MATMUL) { |
|
|
|
return PrepareMatMul(graph, ops, iter_graph, iter_ops); |
|
|
|
} else if (type == ONEHOT) { |
|
|
|
return PrepareOneHot(graph, ops, iter_graph, iter_ops); |
|
|
|
} else if (type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) { |
|
|
|
} else if ((type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) || (type == "_VirtualDataset")) { |
|
|
|
return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops); |
|
|
|
} else { |
|
|
|
return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops); |
|
|
|
|