|
|
@@ -28,10 +28,10 @@ |
|
|
|
|
|
|
|
|
namespace mindspore { |
|
|
namespace mindspore { |
|
|
namespace parallel { |
|
|
namespace parallel { |
|
|
void GenerateStrategy(std::shared_ptr<Graph> graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, |
|
|
|
|
|
const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list, |
|
|
|
|
|
|
|
|
void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, |
|
|
|
|
|
const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list, |
|
|
const std::vector<std::vector<std::string>> &input_tensor_names, |
|
|
const std::vector<std::vector<std::string>> &input_tensor_names, |
|
|
const std::shared_ptr<std::vector<size_t>> index_list) { |
|
|
|
|
|
|
|
|
const std::shared_ptr<std::vector<size_t>> &index_list) { |
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
MS_EXCEPTION_IF_NULL(eli_list); |
|
|
MS_EXCEPTION_IF_NULL(eli_list); |
|
|
MS_EXCEPTION_IF_NULL(index_list); |
|
|
MS_EXCEPTION_IF_NULL(index_list); |
|
|
@@ -140,10 +140,24 @@ std::vector<std::vector<int32_t>> PrepareOneHot(const std::shared_ptr<Graph> &gr |
|
|
const std::vector<std::shared_ptr<OperatorInfo>> &ops, |
|
|
const std::vector<std::shared_ptr<OperatorInfo>> &ops, |
|
|
const size_t iter_graph, const size_t iter_ops) { |
|
|
const size_t iter_graph, const size_t iter_ops) { |
|
|
std::vector<std::vector<int32_t>> strategies = MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops); |
|
|
std::vector<std::vector<int32_t>> strategies = MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops); |
|
|
strategies[0][0] = strategies[0][1]; |
|
|
|
|
|
strategies[0][1] = 1; |
|
|
|
|
|
graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = graph->nodes[iter_graph].tensor_parm.tensor_str.str_w; |
|
|
|
|
|
graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int32_t axis = -1; |
|
|
|
|
|
auto iter = ops[iter_ops]->attrs().find(AXIS); |
|
|
|
|
|
if (iter != ops[iter_ops]->attrs().end()) { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(iter->second); |
|
|
|
|
|
if (iter->second->isa<Int32Imm>()) { |
|
|
|
|
|
axis = iter->second->cast<Int32ImmPtr>()->value(); |
|
|
|
|
|
} else { |
|
|
|
|
|
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value of axis is not int."; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
if (axis == -1) { |
|
|
|
|
|
strategies[0][0] = strategies[0][1]; |
|
|
|
|
|
strategies[0][1] = 1; |
|
|
|
|
|
graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = graph->nodes[iter_graph].tensor_parm.tensor_str.str_w; |
|
|
|
|
|
graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
std::vector<int32_t> s_empty = {}; |
|
|
std::vector<int32_t> s_empty = {}; |
|
|
strategies.push_back(s_empty); |
|
|
strategies.push_back(s_empty); |
|
|
strategies.push_back(s_empty); |
|
|
strategies.push_back(s_empty); |
|
|
@@ -221,7 +235,7 @@ std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Gr |
|
|
} else if (output_size == 0) { |
|
|
} else if (output_size == 0) { |
|
|
s = {}; |
|
|
s = {}; |
|
|
} else { |
|
|
} else { |
|
|
MS_LOG(ERROR) << "Tensor's output size is unexcepted."; |
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's output size is unexcepted."; |
|
|
} |
|
|
} |
|
|
strategies.push_back(s); |
|
|
strategies.push_back(s); |
|
|
} |
|
|
} |
|
|
@@ -241,7 +255,7 @@ std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::shared_ptr |
|
|
StrategyPtr origin_strategy = ops[iter_ops]->strategy(); |
|
|
StrategyPtr origin_strategy = ops[iter_ops]->strategy(); |
|
|
std::vector<std::vector<int32_t>> strategies; |
|
|
std::vector<std::vector<int32_t>> strategies; |
|
|
size_t max_device_num = g_device_manager->DeviceNum(); |
|
|
size_t max_device_num = g_device_manager->DeviceNum(); |
|
|
size_t target_tensor_batch = ops[iter_ops]->outputs_tensor_info()[0].shape()[0]; |
|
|
|
|
|
|
|
|
size_t target_tensor_batch = ops[iter_ops]->inputs_tensor_info()[0].shape()[0]; |
|
|
for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { |
|
|
for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { |
|
|
if (iter_op_inputs >= origin_strategy->GetInputDim().size()) { |
|
|
if (iter_op_inputs >= origin_strategy->GetInputDim().size()) { |
|
|
MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range."; |
|
|
MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range."; |
|
|
@@ -256,8 +270,10 @@ std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::shared_ptr |
|
|
} else { |
|
|
} else { |
|
|
s.push_back(1); |
|
|
s.push_back(1); |
|
|
} |
|
|
} |
|
|
|
|
|
} else if (input_size == 0) { |
|
|
|
|
|
s = {}; |
|
|
} else { |
|
|
} else { |
|
|
MS_LOG(ERROR) << "Tensor's shape is unknown."; |
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's shape is unknown."; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
strategies.push_back(s); |
|
|
strategies.push_back(s); |
|
|
@@ -304,13 +320,13 @@ std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> & |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> graph, |
|
|
|
|
|
|
|
|
void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> &graph, |
|
|
const std::vector<std::shared_ptr<OperatorInfo>> &ops, |
|
|
const std::vector<std::shared_ptr<OperatorInfo>> &ops, |
|
|
const std::shared_ptr<std::vector<size_t>> index_list) { |
|
|
|
|
|
|
|
|
const std::shared_ptr<std::vector<size_t>> &index_list) { |
|
|
for (size_t iter_ops = 0; iter_ops < (size_t)index_list->size(); iter_ops++) { |
|
|
for (size_t iter_ops = 0; iter_ops < (size_t)index_list->size(); iter_ops++) { |
|
|
std::vector<std::vector<int32_t>> strategies; |
|
|
std::vector<std::vector<int32_t>> strategies; |
|
|
size_t iter_graph = index_list->at(iter_ops); |
|
|
size_t iter_graph = index_list->at(iter_ops); |
|
|
if (iter_graph != SIZE_MAX) { |
|
|
|
|
|
|
|
|
if (iter_graph != SIZE_MAX && ops[iter_ops]->type() != GET_NEXT) { |
|
|
strategies = PrepareStrategy(graph, ops, iter_graph, iter_ops); |
|
|
strategies = PrepareStrategy(graph, ops, iter_graph, iter_ops); |
|
|
} |
|
|
} |
|
|
StrategyPtr sp = std::make_shared<Strategy>(0, strategies); |
|
|
StrategyPtr sp = std::make_shared<Strategy>(0, strategies); |
|
|
@@ -335,7 +351,7 @@ size_t FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> & |
|
|
return incoming_op_index; |
|
|
return incoming_op_index; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> graph, |
|
|
|
|
|
|
|
|
std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> &graph, |
|
|
const std::vector<std::shared_ptr<OperatorInfo>> &ops, |
|
|
const std::vector<std::shared_ptr<OperatorInfo>> &ops, |
|
|
const size_t iter_ops, const size_t iter_graph) { |
|
|
const size_t iter_ops, const size_t iter_graph) { |
|
|
std::vector<int32_t> s; |
|
|
std::vector<int32_t> s; |
|
|
@@ -354,8 +370,10 @@ std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Gr |
|
|
s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_c); |
|
|
s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_c); |
|
|
s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_h); |
|
|
s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_h); |
|
|
s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w); |
|
|
s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w); |
|
|
|
|
|
} else if (input_stra_dim == 0) { |
|
|
|
|
|
s = {}; |
|
|
} else { |
|
|
} else { |
|
|
MS_LOG(ERROR) << "Tensor's shape is unknown."; |
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's shape is unknown."; |
|
|
} |
|
|
} |
|
|
break; |
|
|
break; |
|
|
} |
|
|
} |
|
|
@@ -365,7 +383,8 @@ std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Gr |
|
|
std::vector<int32_t> PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, |
|
|
std::vector<int32_t> PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, |
|
|
const size_t incoming_op_index) { |
|
|
const size_t incoming_op_index) { |
|
|
std::vector<int32_t> s; |
|
|
std::vector<int32_t> s; |
|
|
if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == GATHERV2) { |
|
|
|
|
|
|
|
|
if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == GATHERV2 || |
|
|
|
|
|
ops[incoming_op_index]->type() == TRANSPOSE) { |
|
|
return s; |
|
|
return s; |
|
|
} |
|
|
} |
|
|
auto strategy = ops[incoming_op_index]->selected_strategy(); |
|
|
auto strategy = ops[incoming_op_index]->selected_strategy(); |
|
|
@@ -433,13 +452,23 @@ std::vector<int32_t> ModifyStrategyIfSqueezeIncoming(const std::vector<std::shar |
|
|
return s_Squeeze; |
|
|
return s_Squeeze; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
bool GetKeepDims(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops) { |
|
|
|
|
|
bool keepdims = false; |
|
|
|
|
|
auto keep_dims_iter = ops[iter_ops]->attrs().find(KEEP_DIMS); |
|
|
|
|
|
if (keep_dims_iter == ops[iter_ops]->attrs().end()) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Don't have attr keep_dims."; |
|
|
|
|
|
} |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(keep_dims_iter->second); |
|
|
|
|
|
if (!keep_dims_iter->second->isa<BoolImm>()) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Keep_dims is not a bool."; |
|
|
|
|
|
} |
|
|
|
|
|
keepdims = keep_dims_iter->second->cast<BoolImmPtr>()->value(); |
|
|
|
|
|
return keepdims; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
std::vector<int32_t> GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops) { |
|
|
std::vector<int32_t> GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops) { |
|
|
std::vector<int32_t> dim_list; |
|
|
std::vector<int32_t> dim_list; |
|
|
bool keep_dims; |
|
|
|
|
|
if (!ops[iter_ops]->attrs().find(KEEP_DIMS)->second->isa<BoolImm>()) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failure: Parameter keep_dims is not a boolean value." << std::endl; |
|
|
|
|
|
} |
|
|
|
|
|
keep_dims = ops[iter_ops]->attrs().find(KEEP_DIMS)->second->cast<BoolImmPtr>()->value(); |
|
|
|
|
|
|
|
|
bool keep_dims = GetKeepDims(ops, iter_ops); |
|
|
if (keep_dims != false) { |
|
|
if (keep_dims != false) { |
|
|
return dim_list; |
|
|
return dim_list; |
|
|
} |
|
|
} |
|
|
@@ -485,6 +514,62 @@ std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::share |
|
|
return s_Reduce; |
|
|
return s_Reduce; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
std::vector<int32_t> GetDimListFromAttrs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops) { |
|
|
|
|
|
std::vector<int32_t> dim_list; |
|
|
|
|
|
auto iter = ops[iter_ops]->attrs().find(AXIS); |
|
|
|
|
|
if (iter == ops[iter_ops]->attrs().end()) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Don't have attr axis."; |
|
|
|
|
|
} |
|
|
|
|
|
auto input_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(iter->second); |
|
|
|
|
|
if (iter->second->isa<ValueTuple>()) { |
|
|
|
|
|
auto attr_axis = GetValue<std::vector<int>>(iter->second); |
|
|
|
|
|
if (attr_axis.empty()) { |
|
|
|
|
|
for (size_t i = 0; i < input_dim; ++i) { |
|
|
|
|
|
dim_list.push_back(SizeToInt(i)); |
|
|
|
|
|
} |
|
|
|
|
|
} else { |
|
|
|
|
|
for (auto &axis : attr_axis) { |
|
|
|
|
|
axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} else if (iter->second->isa<Int32Imm>()) { |
|
|
|
|
|
int axis = GetValue<int>(iter->second); |
|
|
|
|
|
axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); |
|
|
|
|
|
} else { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "Axis type is invalid."; |
|
|
|
|
|
} |
|
|
|
|
|
return dim_list; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
std::vector<int32_t> ModifyStrategyIfArgIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops, |
|
|
|
|
|
const size_t incoming_op_index, std::vector<int32_t> s) { |
|
|
|
|
|
bool keepdims = GetKeepDims(ops, incoming_op_index); |
|
|
|
|
|
if (keepdims) { |
|
|
|
|
|
return s; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
std::vector<int32_t> s_Arg; |
|
|
|
|
|
std::vector<int32_t> axis_list; |
|
|
|
|
|
for (size_t i = 0; i < s.size(); i++) { |
|
|
|
|
|
axis_list.push_back(i); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
auto dim_list = GetDimListFromAttrs(ops, incoming_op_index); |
|
|
|
|
|
for (auto axis : dim_list) { |
|
|
|
|
|
auto it = find(axis_list.begin(), axis_list.end(), axis); |
|
|
|
|
|
if (it == axis_list.end()) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failure: Can not find dimension indexes in Axis." << std::endl; |
|
|
|
|
|
} |
|
|
|
|
|
axis_list.erase(it); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < (size_t)axis_list.size(); i++) { |
|
|
|
|
|
s_Arg.push_back(s[axis_list[i]]); |
|
|
|
|
|
} |
|
|
|
|
|
return s_Arg; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, |
|
|
std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, |
|
|
const size_t iter_ops, const size_t incoming_op_index) { |
|
|
const size_t iter_ops, const size_t incoming_op_index) { |
|
|
std::vector<int32_t> s; |
|
|
std::vector<int32_t> s; |
|
|
@@ -497,6 +582,9 @@ std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::sh |
|
|
ops[incoming_op_index]->type() == REDUCE_MIN || ops[incoming_op_index]->type() == REDUCE_MEAN) { |
|
|
ops[incoming_op_index]->type() == REDUCE_MIN || ops[incoming_op_index]->type() == REDUCE_MEAN) { |
|
|
s = ModifyStrategyIfReduceIncoming(ops, incoming_op_index, s); |
|
|
s = ModifyStrategyIfReduceIncoming(ops, incoming_op_index, s); |
|
|
} |
|
|
} |
|
|
|
|
|
if (ops[incoming_op_index]->type() == ARGMAXWITHVALUE || ops[incoming_op_index]->type() == ARGMINWITHVALUE) { |
|
|
|
|
|
s = ModifyStrategyIfArgIncoming(ops, incoming_op_index, s); |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
return s; |
|
|
return s; |
|
|
} |
|
|
} |
|
|
@@ -551,11 +639,11 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect |
|
|
return stra; |
|
|
return stra; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> graph, |
|
|
|
|
|
|
|
|
void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> &graph, |
|
|
const std::vector<std::shared_ptr<OperatorInfo>> &ops, |
|
|
const std::vector<std::shared_ptr<OperatorInfo>> &ops, |
|
|
const std::vector<std::vector<std::string>> &input_tensor_names, |
|
|
const std::vector<std::vector<std::string>> &input_tensor_names, |
|
|
const std::shared_ptr<std::vector<size_t>> index_list, |
|
|
|
|
|
const std::shared_ptr<std::vector<size_t>> no_stra_op_list) { |
|
|
|
|
|
|
|
|
const std::shared_ptr<std::vector<size_t>> &index_list, |
|
|
|
|
|
const std::shared_ptr<std::vector<size_t>> &no_stra_op_list) { |
|
|
if (no_stra_op_list->size() == 0) { |
|
|
if (no_stra_op_list->size() == 0) { |
|
|
return; |
|
|
return; |
|
|
} |
|
|
} |
|
|
@@ -624,7 +712,8 @@ std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::sh |
|
|
std::vector<int32_t> s; |
|
|
std::vector<int32_t> s; |
|
|
if (ops[iter_ops]->type() == REDUCE_MAX || ops[iter_ops]->type() == REDUCE_MIN || |
|
|
if (ops[iter_ops]->type() == REDUCE_MAX || ops[iter_ops]->type() == REDUCE_MIN || |
|
|
ops[iter_ops]->type() == REDUCE_SUM || ops[iter_ops]->type() == REDUCE_MEAN || ops[iter_ops]->type() == RESHAPE || |
|
|
ops[iter_ops]->type() == REDUCE_SUM || ops[iter_ops]->type() == REDUCE_MEAN || ops[iter_ops]->type() == RESHAPE || |
|
|
ops[iter_ops]->type() == GATHERV2) { |
|
|
|
|
|
|
|
|
ops[iter_ops]->type() == GATHERV2 || ops[iter_ops]->type() == TRANSPOSE || |
|
|
|
|
|
ops[iter_ops]->type() == ARGMAXWITHVALUE || ops[iter_ops]->type() == ARGMINWITHVALUE) { |
|
|
return s; |
|
|
return s; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
@@ -656,7 +745,7 @@ std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::sh |
|
|
|
|
|
|
|
|
void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_ptr<OperatorInfo>> &ops, |
|
|
void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_ptr<OperatorInfo>> &ops, |
|
|
const std::vector<std::vector<std::string>> &input_tensor_names, |
|
|
const std::vector<std::vector<std::string>> &input_tensor_names, |
|
|
const std::shared_ptr<std::vector<size_t>> no_stra_op_list) { |
|
|
|
|
|
|
|
|
const std::shared_ptr<std::vector<size_t>> &no_stra_op_list) { |
|
|
if (no_stra_op_list->size() == 0) { |
|
|
if (no_stra_op_list->size() == 0) { |
|
|
return; |
|
|
return; |
|
|
} |
|
|
} |
|
|
@@ -686,16 +775,16 @@ void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_pt |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void GenerateRemainingOperatorStrategy(const std::shared_ptr<Graph> graph, |
|
|
|
|
|
|
|
|
void GenerateRemainingOperatorStrategy(const std::shared_ptr<Graph> &graph, |
|
|
const std::vector<std::shared_ptr<OperatorInfo>> &ops, |
|
|
const std::vector<std::shared_ptr<OperatorInfo>> &ops, |
|
|
const std::vector<std::vector<std::string>> &input_tensor_names, |
|
|
const std::vector<std::vector<std::string>> &input_tensor_names, |
|
|
const std::shared_ptr<std::vector<size_t>> index_list, |
|
|
|
|
|
const std::shared_ptr<std::vector<size_t>> no_stra_op_list) { |
|
|
|
|
|
|
|
|
const std::shared_ptr<std::vector<size_t>> &index_list, |
|
|
|
|
|
const std::shared_ptr<std::vector<size_t>> &no_stra_op_list) { |
|
|
if (no_stra_op_list->size() == 0) { |
|
|
if (no_stra_op_list->size() == 0) { |
|
|
return; |
|
|
return; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
size_t no_stra_op_list_size; |
|
|
|
|
|
|
|
|
size_t no_stra_op_list_size = no_stra_op_list->size(); |
|
|
do { |
|
|
do { |
|
|
no_stra_op_list_size = no_stra_op_list->size(); |
|
|
no_stra_op_list_size = no_stra_op_list->size(); |
|
|
GenerateEliminatedOperatorStrategyForward(graph, ops, input_tensor_names, index_list, no_stra_op_list); |
|
|
GenerateEliminatedOperatorStrategyForward(graph, ops, input_tensor_names, index_list, no_stra_op_list); |
|
|
|