|
|
@@ -948,10 +948,12 @@ OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr& op) { |
|
|
return target_op; |
|
|
return target_op; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void CostGraph::CreateTriangleEliminationSubCostListForIdentity( |
|
|
|
|
|
StrategyPtr elimi_op_stra, StrategyPtr left_op_stra, StrategyPtr right_op_stra, const CostPtr& right_op_cost, |
|
|
|
|
|
const CostPtrList& elimi_op_clist, const CostPtrList& left_edge_clist, const CostPtr& right_edge_cost, |
|
|
|
|
|
const CostPtrList& left_node_clist_origin, CostPtrList* left_node_clist_new) { |
|
|
|
|
|
|
|
|
void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, StrategyPtr left_op_stra, |
|
|
|
|
|
StrategyPtr right_op_stra, const CostPtr& right_op_cost, |
|
|
|
|
|
const CostPtrList& elimi_op_clist, |
|
|
|
|
|
const CostPtrList& left_edge_clist, const CostPtr& right_edge_cost, |
|
|
|
|
|
const CostPtrList& left_node_clist_origin, |
|
|
|
|
|
CostPtrList* left_node_clist_new) { |
|
|
MS_EXCEPTION_IF_NULL(right_edge_cost); |
|
|
MS_EXCEPTION_IF_NULL(right_edge_cost); |
|
|
MS_EXCEPTION_IF_NULL(right_op_cost); |
|
|
MS_EXCEPTION_IF_NULL(right_op_cost); |
|
|
MS_EXCEPTION_IF_NULL(left_node_clist_new); |
|
|
MS_EXCEPTION_IF_NULL(left_node_clist_new); |
|
|
@@ -985,93 +987,20 @@ void CostGraph::CreateTriangleEliminationSubCostListForIdentity( |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void CostGraph::CreateTriangleEliminationSubCostListForOthers( |
|
|
|
|
|
StrategyPtr elimi_op_stra, StrategyPtr left_node_stra, StrategyPtr right_node_stra, const CostPtr& right_op_cost, |
|
|
|
|
|
const CostPtrList& elimi_op_clist, const CostPtrList& left_edge_clist, const CostPtr& right_edge_cost, |
|
|
|
|
|
const CostPtrList& left_node_clist_origin, CostPtrList* left_node_clist_new) { |
|
|
|
|
|
CostPtr elimi_op_determined = nullptr, left_edge_determined = nullptr, init_ele = nullptr; |
|
|
|
|
|
std::function<CostPtr(CostPtr, const CostPtr&)> LocalCompare = [&](CostPtr init, const CostPtr& cost_x) { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cost_x); |
|
|
|
|
|
if ((init == nullptr) || (cost_x->memory_cost_ < DEVICE_MEMORY_CAPACITY)) { |
|
|
|
|
|
init = cost_x; |
|
|
|
|
|
} |
|
|
|
|
|
return init; |
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
// Find a feasible elimi_op_clist |
|
|
|
|
|
elimi_op_determined = std::accumulate(elimi_op_clist.begin(), elimi_op_clist.end(), init_ele, LocalCompare); |
|
|
|
|
|
init_ele = nullptr; |
|
|
|
|
|
// Find a feasible left_edge_cost |
|
|
|
|
|
left_edge_determined = std::accumulate(left_edge_clist.begin(), left_edge_clist.end(), init_ele, LocalCompare); |
|
|
|
|
|
if ((elimi_op_determined == nullptr) || (left_edge_determined == nullptr)) { |
|
|
|
|
|
return; |
|
|
|
|
|
} |
|
|
|
|
|
if ((elimi_op_determined->memory_cost_ >= DEVICE_MEMORY_CAPACITY) || |
|
|
|
|
|
(left_edge_determined->memory_cost_ >= DEVICE_MEMORY_CAPACITY)) { |
|
|
|
|
|
return; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
for (auto& left_node_cost : left_node_clist_origin) { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(left_node_cost); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(right_op_cost); |
|
|
|
|
|
double new_memory_cost = left_node_cost->memory_cost_ + elimi_op_determined->memory_cost_ + |
|
|
|
|
|
left_edge_determined->memory_cost_ + right_edge_cost->memory_cost_ + |
|
|
|
|
|
right_op_cost->memory_cost_; |
|
|
|
|
|
double commu_cost = left_node_cost->communication_cost_ + elimi_op_determined->communication_cost_ + |
|
|
|
|
|
left_edge_determined->communication_cost_ + right_edge_cost->communication_cost_ + |
|
|
|
|
|
right_op_cost->communication_cost_; |
|
|
|
|
|
double commu_without = |
|
|
|
|
|
left_node_cost->communication_without_parameter_ + elimi_op_determined->communication_without_parameter_ + |
|
|
|
|
|
left_edge_determined->communication_without_parameter_ + right_edge_cost->communication_without_parameter_ + |
|
|
|
|
|
right_op_cost->communication_without_parameter_; |
|
|
|
|
|
auto decision = std::make_shared<TriangleEliminationDecision>(elimi_op_stra, elimi_op_determined, |
|
|
|
|
|
left_edge_determined, right_edge_cost, left_node_stra, |
|
|
|
|
|
left_node_cost, right_node_stra, right_op_cost); |
|
|
|
|
|
|
|
|
|
|
|
auto new_cost = std::make_shared<Cost>(new_memory_cost, commu_cost, decision); |
|
|
|
|
|
new_cost->communication_without_parameter_ = commu_without; |
|
|
|
|
|
new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without); |
|
|
|
|
|
left_node_clist_new->emplace_back(std::move(new_cost)); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void CostGraph::CreateTriangleEliminationCostList(const OperatorInfoPtr& elimi_op, const CostPtrList& right_node_clist, |
|
|
void CostGraph::CreateTriangleEliminationCostList(const OperatorInfoPtr& elimi_op, const CostPtrList& right_node_clist, |
|
|
const CostPtrList& right_edge_clist, const StrategyPtr& elimi_op_stra, |
|
|
const CostPtrList& right_edge_clist, const StrategyPtr& elimi_op_stra, |
|
|
const StrategyPtr& left_node_stra, const StrategyPtr& right_node_stra, |
|
|
const StrategyPtr& left_node_stra, const StrategyPtr& right_node_stra, |
|
|
const CostPtrList& elimi_op_clist, const CostPtrList& left_edge_clist, |
|
|
const CostPtrList& elimi_op_clist, const CostPtrList& left_edge_clist, |
|
|
const CostPtrList& left_node_clist_origin, |
|
|
const CostPtrList& left_node_clist_origin, |
|
|
CostPtrList* left_node_clist_new) { |
|
|
CostPtrList* left_node_clist_new) { |
|
|
// The reason for separately dealing with when the 'elimi_op' is 'TMPIDENTITY_INFO' or others is that |
|
|
|
|
|
// when 'elimi_op' is TMPIDENTITY_INFO, the computation is limited, while 'elimi_op' is others, the computation |
|
|
|
|
|
// may be huge |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(elimi_op); |
|
|
MS_EXCEPTION_IF_NULL(elimi_op); |
|
|
if (elimi_op->name().find(TMPIDENTITY_INFO_NAME) != std::string::npos) { |
|
|
|
|
|
for (auto& right_node_cost : right_node_clist) { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(right_node_cost); |
|
|
|
|
|
for (auto& right_edge_cost : right_edge_clist) { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(right_edge_cost); |
|
|
|
|
|
if ((right_node_cost->memory_cost_ < DEVICE_MEMORY_CAPACITY) && |
|
|
|
|
|
(right_edge_cost->memory_cost_ < DEVICE_MEMORY_CAPACITY)) { |
|
|
|
|
|
// Exact computation for TMPIDENTITY_INFO_NAME case |
|
|
|
|
|
CreateTriangleEliminationSubCostListForIdentity(elimi_op_stra, left_node_stra, right_node_stra, |
|
|
|
|
|
right_node_cost, elimi_op_clist, left_edge_clist, |
|
|
|
|
|
right_edge_cost, left_node_clist_origin, left_node_clist_new); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} else { |
|
|
|
|
|
for (auto& right_node_cost : right_node_clist) { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(right_node_cost); |
|
|
|
|
|
for (auto& right_edge_cost : right_edge_clist) { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(right_edge_cost); |
|
|
|
|
|
if ((right_node_cost->memory_cost_ < DEVICE_MEMORY_CAPACITY) && |
|
|
|
|
|
(right_edge_cost->memory_cost_ < DEVICE_MEMORY_CAPACITY)) { |
|
|
|
|
|
// Approximate computation for other case |
|
|
|
|
|
CreateTriangleEliminationSubCostListForOthers(elimi_op_stra, left_node_stra, right_node_stra, right_node_cost, |
|
|
|
|
|
elimi_op_clist, left_edge_clist, right_edge_cost, |
|
|
|
|
|
left_node_clist_origin, left_node_clist_new); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
for (auto& right_node_cost : right_node_clist) { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(right_node_cost); |
|
|
|
|
|
for (auto& right_edge_cost : right_edge_clist) { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(right_edge_cost); |
|
|
|
|
|
CreateTriangleEliminationSubCostList(elimi_op_stra, left_node_stra, right_node_stra, right_node_cost, |
|
|
|
|
|
elimi_op_clist, left_edge_clist, right_edge_cost, left_node_clist_origin, |
|
|
|
|
|
left_node_clist_new); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|