|
|
|
@@ -2294,6 +2294,7 @@ std::vector<uint32_t> GenerateBucketSizeList(const KernelGraphPtr &graph, const |
|
|
|
} |
|
|
|
|
|
|
|
void PreProcessOnSplitIndex(const KernelGraphPtr &graph, vector<uint32_t> *split_index) { |
|
|
|
MS_EXCEPTION_IF_NULL(split_index); |
|
|
|
if (split_index->empty()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
@@ -2313,6 +2314,15 @@ void PreProcessOnSplitIndex(const KernelGraphPtr &graph, vector<uint32_t> *split |
|
|
|
void SessionBasic::InitAllBucket(const KernelGraphPtr &graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_LOG(INFO) << "Init Bucket start, graph_id:" << graph->graph_id(); |
|
|
|
auto ms_context = MsContext::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(ms_context); |
|
|
|
const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode); |
|
|
|
auto parallel_context = parallel::ParallelContext::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(parallel_context); |
|
|
|
auto parallel_mode = parallel_context->parallel_mode(); |
|
|
|
if (!pynative_mode || parallel_mode != parallel::DATA_PARALLEL) { |
|
|
|
return; |
|
|
|
} |
|
|
|
SetGraphBpropAttr(graph); |
|
|
|
|
|
|
|
if (!graph->is_bprop()) { |
|
|
|
|