@@ -28,52 +28,56 @@
namespace mindspore {
namespace parallel {
void GenerateStrategy(std::shared_ptr<Graph> graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops) {
void GenerateStrategy(std::shared_ptr<Graph> graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                      const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list,
                      const std::vector<std::vector<std::string>> &input_tensor_names,
                      const std::shared_ptr<std::vector<size_t>> index_list) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(eli_list);
  MS_EXCEPTION_IF_NULL(index_list);
  GeneratePartitionedOperatorStrategy(graph, ops, index_list);
  std::shared_ptr<std::vector<size_t>> no_stra_op_list(new std::vector<size_t>);
  GenerateEliminatedOperatorStrategyForward(graph, ops, eli_list, input_tensor_names, index_list, no_stra_op_list);
  GenerateEliminatedOperatorStrategyBackward(ops, input_tensor_names, no_stra_op_list);
}
  for (size_t iter_ops = 0; iter_ops < ops.size(); iter_ops++) {
    std::vector<std::vector<int32_t>> stra;
    for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
      stra.push_back(PrepareStrategy(graph, ops, iter_ops, iter_op_inputs));
    }
    // OneHot's scalar parameters are removed by entire_costgraph, so they have to be restored here.
    if (ops[iter_ops]->type() == ONEHOT) {
      std::vector<int32_t> s_Onehot = {};
      stra.push_back(s_Onehot);
      stra.push_back(s_Onehot);
    }
    StrategyPtr sp = std::make_shared<Strategy>(0, stra);
    ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost());
  }
}
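
Note: the hunk above replaces the old single-pass loop (kept just above for reference) with a three-phase pipeline: strategies are first generated for the operators the partitioner kept, then the two elimination passes fill in the rest, with no_stra_op_list carrying the operators the forward pass could not resolve over to the backward pass. A standalone sketch (plain std types, not the MindSpore API) of that handoff:

// Sketch of the no_stra_op_list handoff between the two elimination passes.
// The operator indices below are assumed values for illustration.
#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

int main() {
  auto no_stra_op_list = std::make_shared<std::vector<size_t>>();
  // Forward pass: suppose operators 3 and 7 had no usable incoming strategy.
  no_stra_op_list->push_back(3);
  no_stra_op_list->push_back(7);
  // Backward pass: revisit exactly that list, back to front, as the code above does.
  for (int iter_list = static_cast<int>(no_stra_op_list->size()) - 1; iter_list >= 0; iter_list--) {
    std::cout << "resolving operator " << no_stra_op_list->at(iter_list) << "\n";  // 7, then 3
  }
  return 0;
}
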
std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &graph,
                                                const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                const size_t iter_graph, const size_t iter_ops) {
  std::vector<std::vector<int32_t>> strategies;
  for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
    std::vector<int32_t> s;
    auto attrs = ops[iter_ops]->attrs();
    bool transpose_a = attrs[TRANSPOSE_A]->cast<BoolImmPtr>()->value();
    bool transpose_b = attrs[TRANSPOSE_B]->cast<BoolImmPtr>()->value();
    if (transpose_a && (iter_op_inputs == 0)) {
      s.push_back(
        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
      s.push_back(
        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h));
    } else if (transpose_b && (iter_op_inputs == 1)) {
      s.push_back(
        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
      s.push_back(
        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h));
    } else {
      s.push_back(
        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h));
      s.push_back(
        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
    }
    strategies.push_back(s);
  }
  return strategies;
}
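
The transpose handling above keeps strategies in logical (row, column) order: when an input enters the MatMul transposed, its stored (h, w) stride fractions correspond to swapped logical dimensions, so the two entries are emitted in reverse. A standalone sketch with assumed stride values (str_h, str_w are the fractions of a dimension each device keeps):

// Sketch of the (h, w) swap for transposed MatMul inputs; values are assumed.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int32_t> StrategyFor(float str_h, float str_w, bool transposed) {
  std::vector<int32_t> s;
  if (transposed) {
    s.push_back(static_cast<int32_t>(1.0 / str_w));  // logical rows come from stored columns
    s.push_back(static_cast<int32_t>(1.0 / str_h));
  } else {
    s.push_back(static_cast<int32_t>(1.0 / str_h));
    s.push_back(static_cast<int32_t>(1.0 / str_w));
  }
  return s;
}

int main() {
  // Each device keeps 1/4 of the height and 1/2 of the width.
  for (auto d : StrategyFor(0.25f, 0.5f, false)) std::cout << d << " ";  // prints: 4 2
  std::cout << "\n";
  for (auto d : StrategyFor(0.25f, 0.5f, true)) std::cout << d << " ";   // prints: 2 4
  std::cout << "\n";
  return 0;
}
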
std::vector<int32_t> PrepareMatMul(const std::shared_ptr<Graph> &graph,
                                   const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_nodes,
                                   const size_t iter_op_inputs) {
  std::vector<int32_t> s;
  auto attrs = ops[iter_nodes]->attrs();
  bool transpose_a = attrs[TRANSPOSE_A]->cast<BoolImmPtr>()->value();
  bool transpose_b = attrs[TRANSPOSE_B]->cast<BoolImmPtr>()->value();
  if (transpose_a && (iter_op_inputs == 0)) {
    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_w));
    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_h));
  } else if (transpose_b && (iter_op_inputs == 1)) {
    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_w));
    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_h));
  } else {
    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_h));
    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_w));
  }
  return s;
}
std::vector<std::vector<int32_t>> PrepareVirtualDataset(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                        const size_t iter_ops) {
  std::vector<std::vector<int32_t>> strategies = MakeDataParallelStrategy(ops, iter_ops);
  strategies[1][0] = strategies[0][0];
  return strategies;
}
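
The overwrite of strategies[1][0] forces the virtual dataset's inputs to share one batch split. A minimal standalone sketch, with assumed shapes (data split 8-way on the batch dimension, label initially unsplit):

// Sketch of aligning the second input's batch split with the first; values assumed.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<std::vector<int32_t>> strategies = {{8, 1, 1, 1}, {1}};
  strategies[1][0] = strategies[0][0];  // label now uses the same 8-way batch split
  std::cout << strategies[1][0] << "\n";  // prints: 8
  return 0;
}
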
std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                 const size_t iter_ops, std::vector<int32_t> s) {
  std::vector<std::vector<int32_t>> strategies;
@@ -99,9 +103,9 @@ std::vector<std::vector<int32_t>> PrepareOneHot(std::vector<int32_t> s) {
  return strategies;
}
std::vector<int32_t> MakeRecSearchStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                           const std::shared_ptr<Graph> &graph, const size_t iter_ops,
                                           const size_t iter_op_inputs) {
std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
                                                        const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                        const size_t iter_graph, const size_t iter_ops) {
  if (ops.empty()) {
    MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
  }
@@ -111,35 +115,46 @@ std::vector<int32_t> MakeRecSearchStrategy(const std::vector<std::shared_ptr<Ope
  StrategyPtr origin_strategy = ops[iter_ops]->strategy();
  if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
    MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
  }
  // size_t output_size = ops[iter_ops]->outputs_tensor_info()[0].shape().size();
  size_t output_size = origin_strategy->GetInputDim()[iter_op_inputs].size();
  std::vector<int32_t> s = {};
  if (output_size == 4) {
    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_n));
    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_c));
    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_h));
    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_w));
  } else if (output_size == 2) {
    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_h));
    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_w));
  } else if (output_size == 1) {
    s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_w));
  } else if (output_size == 0) {
    return s;
  } else {
    MS_LOG(ERROR) << "Tensor's output size is unexpected.";
  }
  return s;
}
  std::vector<std::vector<int32_t>> strategies;
  for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
    if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
      MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
    }
    // size_t output_size = ops[iter_ops]->outputs_tensor_info()[0].shape().size();
    size_t output_size = origin_strategy->GetInputDim()[iter_op_inputs].size();
    std::vector<int32_t> s;
    if (output_size == 4) {
      s.push_back(
        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_n));
      s.push_back(
        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_c));
      s.push_back(
        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h));
      s.push_back(
        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
    } else if (output_size == 2) {
      s.push_back(
        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h));
      s.push_back(
        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
    } else if (output_size == 1) {
      s.push_back(
        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
    } else if (output_size == 0) {
      s = {};
    } else {
      MS_LOG(ERROR) << "Tensor's output size is unexpected.";
    }
    strategies.push_back(s);
  }
  return strategies;
}
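
Every branch above converts a stride fraction into a cut count with 1.0 / str: the partitioner records, per dimension, the fraction each device keeps, while a strategy records how many slices the dimension is cut into, so the two are reciprocals. A standalone sketch with assumed NCHW strides:

// Sketch of the reciprocal conversion from stride fractions to a strategy; values assumed.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Batch cut 4 ways (str_n = 0.25), channels cut 2 ways (str_c = 0.5),
  // height and width left whole (str = 1.0).
  float str_n = 0.25f, str_c = 0.5f, str_h = 1.0f, str_w = 1.0f;
  std::vector<int32_t> s;
  s.push_back(static_cast<int32_t>(1.0 / str_n));
  s.push_back(static_cast<int32_t>(1.0 / str_c));
  s.push_back(static_cast<int32_t>(1.0 / str_h));
  s.push_back(static_cast<int32_t>(1.0 / str_w));
  for (auto d : s) std::cout << d << " ";  // prints: 4 2 1 1
  std::cout << "\n";
  return 0;
}
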
std::vector<int32_t> MakeDataParallelStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                              const size_t iter_ops, const size_t iter_op_inputs) {
std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                           const size_t iter_ops) {
  if (ops.empty()) {
    MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
  }
@@ -149,28 +164,32 @@ std::vector<int32_t> MakeDataParallelStrategy(const std::vector<std::shared_ptr<
  StrategyPtr origin_strategy = ops[iter_ops]->strategy();
  if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
    MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
  }
  std::vector<int32_t> s;
  size_t input_size = origin_strategy->GetInputDim()[iter_op_inputs].size();
  for (size_t dim = 0; dim < input_size; dim++) {
    if (dim == 0 && input_size == 4) {
      size_t max_device_num = g_device_manager->DeviceNum();
      size_t target_tensor_batch = ops[iter_ops]->outputs_tensor_info()[0].shape()[0];
      s.push_back(std::min(max_device_num, target_tensor_batch));
    } else {
      s.push_back(1);
    }
  }
  return s;
}
  std::vector<std::vector<int32_t>> strategies;
  for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
    if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
      MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
    }
    std::vector<int32_t> s;
    size_t input_size = origin_strategy->GetInputDim()[iter_op_inputs].size();
    for (size_t dim = 0; dim < input_size; dim++) {
      if (dim == 0 && input_size == 4) {
        size_t max_device_num = g_device_manager->DeviceNum();
        size_t target_tensor_batch = ops[iter_ops]->outputs_tensor_info()[0].shape()[0];
        s.push_back(std::min(max_device_num, target_tensor_batch));
      } else {
        s.push_back(1);
      }
    }
    strategies.push_back(s);
  }
  return strategies;
}
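
The loop above implements plain data parallelism: only dimension 0 of a 4D input is split, by whichever is smaller, the device count or the batch size; every other dimension stays whole. A standalone sketch with assumed shapes and device count:

// Sketch of the batch-dimension rule; shapes and device count are assumed.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int32_t> DataParallel(size_t device_num, const std::vector<size_t> &shape) {
  std::vector<int32_t> s;
  for (size_t dim = 0; dim < shape.size(); dim++) {
    if (dim == 0 && shape.size() == 4) {
      s.push_back(static_cast<int32_t>(std::min(device_num, shape[0])));  // cap by batch size
    } else {
      s.push_back(1);  // leave non-batch dimensions whole
    }
  }
  return s;
}

int main() {
  for (auto d : DataParallel(8, {32, 3, 224, 224})) std::cout << d << " ";  // prints: 8 1 1 1
  std::cout << "\n";
  for (auto d : DataParallel(8, {4, 3, 224, 224})) std::cout << d << " ";   // prints: 4 1 1 1
  std::cout << "\n";
  return 0;
}
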
std::vector<int32_t> PrepareStrategy(const std::shared_ptr<Graph> &graph,
                                     const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
                                     const size_t iter_op_inputs) {
std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> &graph,
                                                  const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                  const size_t iter_graph, const size_t iter_ops) {
  if (ops.empty()) {
    MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
  }
@@ -179,19 +198,35 @@ std::vector<int32_t> PrepareStrategy(const std::shared_ptr<Graph> &graph,
  }
  auto type = ops[iter_ops]->type();
  if (type == VIRTUAL_DATA_SET) {
    return PrepareVirtualDataset(ops, iter_ops);
  }
  auto idx = DictOpType.find(type);
  if (idx == DictOpType.end()) {
    return MakeDataParallelStrategy(ops, iter_ops, iter_op_inputs);
    return MakeDataParallelStrategy(ops, iter_ops);
  }
  if (type == MATMUL) {
    return PrepareMatMul(graph, ops, iter_ops, iter_op_inputs);
    return PrepareMatMul(graph, ops, iter_graph, iter_ops);
  } else if (type == RESHAPE) {
    return MakeDataParallelStrategy(ops, iter_ops, iter_op_inputs);
  } else if (type == DIV || type == SUB || type == MUL) {
    return MakeDataParallelStrategy(ops, iter_ops, iter_op_inputs);
    return MakeDataParallelStrategy(ops, iter_ops);
  } else {
    return MakeRecSearchStrategy(ops, graph, iter_ops, iter_op_inputs);
    return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
  }
}
void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> graph,
                                         const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                         const std::shared_ptr<std::vector<size_t>> index_list) {
  for (size_t iter_ops = 0; iter_ops < (size_t)index_list->size(); iter_ops++) {
    std::vector<std::vector<int32_t>> strategies;
    size_t iter_graph = index_list->at(iter_ops);
    if (iter_graph == SIZE_MAX) {
      continue;
    }
    strategies = PrepareStrategy(graph, ops, iter_graph, iter_ops);
    StrategyPtr sp = std::make_shared<Strategy>(0, strategies);
    ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost());
  }
}
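
index_list is the bridge between the operator list and the partitioned graph: entry i holds the graph node for operator i, with SIZE_MAX marking operators the partitioner eliminated, which the elimination passes handle instead. A standalone sketch of that contract (the mapping values are assumed):

// Sketch of the SIZE_MAX sentinel in index_list; the mapping is assumed.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<size_t> index_list = {0, SIZE_MAX, 1, 2, SIZE_MAX};
  for (size_t iter_ops = 0; iter_ops < index_list.size(); iter_ops++) {
    if (index_list[iter_ops] == SIZE_MAX) {
      continue;  // eliminated operator: no partitioned node to read a strategy from
    }
    std::cout << "op " << iter_ops << " -> graph node " << index_list[iter_ops] << "\n";
  }
  return 0;
}
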
@@ -353,6 +388,25 @@ std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::share
  return s_Reduce;
}
std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                       const int incoming_op_index, const size_t iter_ops,
                                                       const std::shared_ptr<std::vector<size_t>> no_stra_op_list) {
  std::vector<int32_t> s;
  s = PrepareIncomingOperatorInputStrategy(ops, incoming_op_index);
  if (s.size() != 0) {
    if (ops[incoming_op_index]->type() == SQUEEZE) {
      s = ModifyStrategyIfSqueezeIncoming(ops, incoming_op_index, s);
    }
    if (ops[incoming_op_index]->type() == REDUCE_SUM || ops[incoming_op_index]->type() == REDUCE_MAX ||
        ops[incoming_op_index]->type() == REDUCE_MIN || ops[incoming_op_index]->type() == REDUCE_MEAN) {
      s = ModifyStrategyIfReduceIncoming(ops, incoming_op_index, s);
    }
  } else {
    no_stra_op_list->push_back(iter_ops);
  }
  return s;
}
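
The empty-strategy branch above is the forward pass's deferral point: an eliminated operator first tries to inherit its producer's input strategy, and only if the producer has none is it parked on no_stra_op_list for the backward pass to derive a strategy from its consumers. A standalone sketch of that fallback (all values assumed):

// Sketch of inherit-or-defer; producer strategies and indices are assumed.
#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

std::vector<int> CopyIncoming(const std::vector<int> &producer_strategy, size_t iter_ops,
                              const std::shared_ptr<std::vector<size_t>> &no_stra_op_list) {
  if (!producer_strategy.empty()) {
    return producer_strategy;  // inherit the producer's split directly
  }
  no_stra_op_list->push_back(iter_ops);  // defer to the backward pass
  return {};
}

int main() {
  auto pending = std::make_shared<std::vector<size_t>>();
  auto s1 = CopyIncoming({4, 1}, 5, pending);  // inherits {4, 1}
  auto s2 = CopyIncoming({}, 6, pending);      // defers; pending now holds 6
  std::cout << s1.size() << " " << s2.size() << " " << pending->at(0) << "\n";  // prints: 2 0 6
  return 0;
}
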
std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                                 const size_t iter_ops, std::vector<int32_t> s) {
  std::vector<int32_t> s_empty = {};
@@ -389,6 +443,33 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect
  return stra;
}
void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> graph,
                                               const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                               const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list,
                                               const std::vector<std::vector<std::string>> &input_tensor_names,
                                               const std::shared_ptr<std::vector<size_t>> index_list,
                                               const std::shared_ptr<std::vector<size_t>> no_stra_op_list) {
  for (int eli_index = eli_list->size() - 1; eli_index >= 0; eli_index--) {
    size_t iter_ops = eli_list->at(eli_index)[0];
    std::vector<std::vector<int32_t>> stra;
    std::vector<int32_t> s;
    int incoming_op_index = FindIndexOfOperatorIncoming(input_tensor_names, iter_ops);
    if (incoming_op_index != -1) {
      auto iter_graph = index_list->at(incoming_op_index);
      if (iter_graph != SIZE_MAX) {
        s = CopyIncomingOperatorOutputStrategy(graph, ops, iter_ops, iter_graph);
      } else {
        s = CopyIncomingOperatorInputStrategy(ops, incoming_op_index, iter_ops, no_stra_op_list);
      }
    } else {
      no_stra_op_list->push_back(iter_ops);
    }
    stra = GenerateStrategiesFromStrategy(ops, iter_ops, s);
    StrategyPtr sp = std::make_shared<Strategy>(0, stra);
    ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost());
  }
}
std::vector<int32_t> ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                     const size_t iter_ops, std::vector<int32_t> s) {
  std::vector<int32_t> s_Squeeze;
@@ -427,5 +508,47 @@ std::vector<int32_t> ModifyStrategyIfReduceOutgoing(const std::vector<std::share
  }
  return s_Reduce;
}
std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                       const std::vector<std::vector<std::string>> &input_tensor_names,
                                                       const size_t iter_ops) {
  std::vector<int32_t> s;
  bool found = false;
  for (size_t i = 0; i < (size_t)input_tensor_names.size(); i++) {
    for (size_t j = 1; j < (size_t)input_tensor_names[i].size(); j++) {
      if (input_tensor_names[i][j] == input_tensor_names[iter_ops][0]) {
        for (size_t k = 0; k < ops[i]->selected_strategy()->GetInputDim()[j - 1].size(); ++k) {
          s.push_back(ops[i]->selected_strategy()->GetInputDim()[j - 1][k]);
        }
        found = true;
        break;
      }
    }
    if (found) break;
  }
  return s;
}
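
The search above relies on the layout of input_tensor_names: row i lists operator i's output name in column 0, followed by its input names, so an operator finds its first consumer by matching its own column 0 against later columns of every row. A standalone sketch over an assumed three-operator table:

// Sketch of the first-consumer search; the name table is assumed.
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<std::vector<std::string>> input_tensor_names = {
      {"t0"},              // op 0 produces t0 and has no listed inputs
      {"t1", "t0"},        // op 1 produces t1 and consumes t0
      {"t2", "t1", "t0"},  // op 2 produces t2 and consumes t1, t0
  };
  size_t iter_ops = 0;  // find the first consumer of op 0's output, "t0"
  bool found = false;
  for (size_t i = 0; i < input_tensor_names.size() && !found; i++) {
    for (size_t j = 1; j < input_tensor_names[i].size(); j++) {
      if (input_tensor_names[i][j] == input_tensor_names[iter_ops][0]) {
        std::cout << "first consumer: op " << i << ", input slot " << j - 1 << "\n";  // op 1, slot 0
        found = true;
        break;
      }
    }
  }
  return 0;
}
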
void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                const std::vector<std::vector<std::string>> &input_tensor_names,
                                                const std::shared_ptr<std::vector<size_t>> no_stra_op_list) {
  MS_EXCEPTION_IF_NULL(no_stra_op_list);
  for (int iter_list = no_stra_op_list->size() - 1; iter_list >= 0; iter_list--) {
    auto iter_ops = no_stra_op_list->at(iter_list);
    std::vector<std::vector<int32_t>> stra;
    std::vector<int32_t> s = CopyOutgoingOperatorInputStrategy(ops, input_tensor_names, iter_ops);
    if (ops[iter_ops]->type() == SQUEEZE) {
      s = ModifyStrategyIfSqueezeOutgoing(ops, iter_ops, s);
    }
    if (ops[iter_ops]->type() == REDUCE_SUM || ops[iter_ops]->type() == REDUCE_MAX ||
        ops[iter_ops]->type() == REDUCE_MIN || ops[iter_ops]->type() == REDUCE_MEAN) {
      s = ModifyStrategyIfReduceOutgoing(ops, iter_ops, s);
    }
    stra = GenerateStrategiesFromStrategy(ops, iter_ops, s);
    StrategyPtr sp = std::make_shared<Strategy>(0, stra);
    ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost());
  }
}
}  // namespace parallel
}  // namespace mindspore