|
|
|
@@ -333,6 +333,28 @@ bool StrategyFound(std::unordered_map<std::string, ValuePtr> attrs) { |
|
|
|
return !((iter == attrs.end()) || (iter->second->type_name() == NONE)); |
|
|
|
} |
|
|
|
|
|
|
|
bool HasStrategy(const FuncGraphPtr &root) { |
|
|
|
AnfNodePtr ret = root->get_return(); |
|
|
|
MS_EXCEPTION_IF_NULL(ret); |
|
|
|
std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret); |
|
|
|
|
|
|
|
for (auto &node : all_nodes) { |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); |
|
|
|
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); |
|
|
|
auto attrs = prim->attrs(); |
|
|
|
if (StrategyFound(attrs)) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
bool IsCommunicationOp(const PrimitivePtr &prim) { |
|
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
|
return (COMMUNICATION_OPS.find(prim->name()) != COMMUNICATION_OPS.end()); |
|
|
|
@@ -2225,6 +2247,14 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) |
|
|
|
// control whether use model_parallel mode |
|
|
|
if (!root->has_flag(AUTO_PARALLEL) || ((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) || |
|
|
|
(root->has_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY))) { |
|
|
|
if (!root->has_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY)) { |
|
|
|
if (HasStrategy(root)) { |
|
|
|
MS_LOG(INFO) << "strategies ignored in " << parallel_mode |
|
|
|
<< ", set_strategy() only valid in [semi_]auto_parallel."; |
|
|
|
} |
|
|
|
root->flags()[CHECK_SET_STRATEGY_VALID_ONCE_ONLY] = true; |
|
|
|
} |
|
|
|
|
|
|
|
return changes; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -2282,6 +2312,9 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) |
|
|
|
root->flags()[SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY] = true; |
|
|
|
res->results()[pipeline::kStepParallelGraph] = root; |
|
|
|
|
|
|
|
// in auto parallel mode, no need to check if stategies set |
|
|
|
root->flags()[CHECK_SET_STRATEGY_VALID_ONCE_ONLY] = true; |
|
|
|
|
|
|
|
(void)gettimeofday(&end_time, nullptr); |
|
|
|
uint64_t time = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec); |
|
|
|
time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec); |
|
|
|
|