|
|
|
@@ -326,6 +326,14 @@ Strategys MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph, |
|
|
|
static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); |
|
|
|
s.push_back( |
|
|
|
static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); |
|
|
|
} else if (output_size == 3) { |
|
|
|
// Experimental support for 3D data. |
|
|
|
s.push_back( |
|
|
|
static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_c)); |
|
|
|
s.push_back( |
|
|
|
static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); |
|
|
|
s.push_back( |
|
|
|
static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); |
|
|
|
} else if (output_size == 2) { |
|
|
|
s.push_back( |
|
|
|
static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); |
|
|
|
@@ -366,7 +374,8 @@ Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph, |
|
|
|
Dimensions s; |
|
|
|
size_t input_size = origin_strategy->GetInputDim()[iter_op_inputs].size(); |
|
|
|
for (size_t dim = 0; dim < input_size; dim++) { |
|
|
|
if (input_size == 1 || input_size == 2 || input_size == 4) { |
|
|
|
// Experimental support for 3D data (input_size == 3). |
|
|
|
if (input_size >= 1 && input_size <= 4) { |
|
|
|
if (dim == 0) { |
|
|
|
// Currently GPU version does not support partitioning ‘FusedBatchNormEx’ in its param tensors. |
|
|
|
if (ops[iter_ops]->type() == "FusedBatchNormEx" && iter_op_inputs != 0) { |
|
|
|
@@ -385,17 +394,27 @@ Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph, |
|
|
|
} |
|
|
|
strategies.push_back(s); |
|
|
|
} |
|
|
|
|
|
|
|
// Set default strategy. |
|
|
|
graph->nodes[iter_graph].tensor_parm.tensor_str.str_n = 1.0; |
|
|
|
graph->nodes[iter_graph].tensor_parm.tensor_str.str_c = 1.0; |
|
|
|
graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = 1.0; |
|
|
|
graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0; |
|
|
|
|
|
|
|
// Update data parallel strategy. |
|
|
|
if (ops[iter_ops]->outputs_tensor_info().size() == 0) { |
|
|
|
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " output tensor info is empty."; |
|
|
|
} |
|
|
|
if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 1) { |
|
|
|
graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0 / std::min(max_device_num, target_tensor_batch); |
|
|
|
} else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 2) { |
|
|
|
graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = 1.0 / std::min(max_device_num, target_tensor_batch); |
|
|
|
} else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 3) { |
|
|
|
// Experimental support for 3D data. |
|
|
|
graph->nodes[iter_graph].tensor_parm.tensor_str.str_c = 1.0 / std::min(max_device_num, target_tensor_batch); |
|
|
|
} else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 4) { |
|
|
|
graph->nodes[iter_graph].tensor_parm.tensor_str.str_n = 1.0 / std::min(max_device_num, target_tensor_batch); |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " output tensor shape is unexpected."; |
|
|
|
} |
|
|
|
|
|
|
|
return strategies; |
|
|
|
@@ -416,7 +435,8 @@ Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector |
|
|
|
return PrepareMatMul(graph, ops, iter_graph, iter_ops); |
|
|
|
} else if (type == ONEHOT) { |
|
|
|
return PrepareOneHot(graph, ops, iter_graph, iter_ops); |
|
|
|
} else if ((type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) || (type == "_VirtualDataset") || (type == "Dropout")) { |
|
|
|
} else if ((type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) || (type == "_VirtualDataset") || |
|
|
|
(type == "FusedBatchNormEx") || (type == "Dropout")) { |
|
|
|
return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops); |
|
|
|
} else { |
|
|
|
return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops); |
|
|
|
@@ -468,6 +488,11 @@ Dimensions CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> &grap |
|
|
|
} else if (input_stra_dim == 2) { |
|
|
|
s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_h); |
|
|
|
s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w); |
|
|
|
} else if (input_stra_dim == 3) { |
|
|
|
// Experimental support for 3D data. |
|
|
|
s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_c); |
|
|
|
s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_h); |
|
|
|
s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w); |
|
|
|
} else if (input_stra_dim == 4) { |
|
|
|
s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_n); |
|
|
|
s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_c); |
|
|
|
|