Merge pull request !5352 from fary86/refactor_context_interfacetags/v1.0.0
| @@ -25,7 +25,7 @@ bool HcomAllBroadCastKernel::Launch(const std::vector<AddressPtr> &inputs, | |||
| const std::vector<AddressPtr> & /*outputs*/, void *stream_ptr) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (context_ptr->enable_task_sink()) { | |||
| if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) { | |||
| return true; | |||
| } | |||
| if (inputs.empty() || hccl_data_type_list_.empty()) { | |||
| @@ -24,7 +24,7 @@ bool HcomAllGatherKernel::Launch(const std::vector<AddressPtr> &inputs, const st | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (context_ptr->enable_task_sink()) { | |||
| if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) { | |||
| return true; | |||
| } | |||
| if (inputs.empty() || hccl_data_type_list_.empty()) { | |||
| @@ -24,7 +24,7 @@ bool HcomAllReduceKernel::Launch(const std::vector<AddressPtr> &inputs, const st | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (context_ptr->enable_task_sink()) { | |||
| if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) { | |||
| return true; | |||
| } | |||
| if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) { | |||
| @@ -25,7 +25,7 @@ bool HcomAllReduceScatterKernel::Launch(const std::vector<AddressPtr> &inputs, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (context_ptr->enable_task_sink()) { | |||
| if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) { | |||
| return true; | |||
| } | |||
| if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) { | |||
| @@ -101,7 +101,8 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (context_ptr->enable_graph_kernel() && IsPrimitiveCNode(kernel_node, prim::kPrimBatchMatMul)) { | |||
| if (context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL) && | |||
| IsPrimitiveCNode(kernel_node, prim::kPrimBatchMatMul)) { | |||
| kernel_type = KernelType::AKG_KERNEL; | |||
| } | |||
| @@ -328,7 +328,7 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType im | |||
| } | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| bool is_gpu = (context->device_target() == kGPUDevice); | |||
| bool is_gpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice); | |||
| if (is_gpu && (imply_type == kTBE || imply_type == kAICPU)) { | |||
| MS_LOG(ERROR) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type) | |||
| << ", current op num: " << op_info_.size(); | |||
| @@ -249,8 +249,8 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap | |||
| void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||
| bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG); | |||
| auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH); | |||
| if (save_graphs_path.empty()) { | |||
| save_graphs_path = "."; | |||
| } | |||
| @@ -262,7 +262,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap | |||
| } | |||
| auto optimizer = std::make_shared<GraphOptimizer>(); | |||
| auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm"); | |||
| if (context_ptr->execution_mode() == kPynativeMode) { | |||
| if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) { | |||
| ir_fusion_pm->AddPass(std::make_shared<BnSplit>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>()); | |||
| } else { | |||
| @@ -276,7 +276,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap | |||
| AddAscendIRFusionRulesPass(ir_fusion_pm.get()); | |||
| AddAscendIRFusionPass(ir_fusion_pm.get()); | |||
| if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) { | |||
| if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK) && context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK) && | |||
| ConfigManager::GetInstance().iter_num() > 1) { | |||
| ir_fusion_pm->AddPass(std::make_shared<InsertMemcpyAsyncForGetNext>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>()); | |||
| @@ -296,12 +297,12 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap | |||
| void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (!context_ptr->ir_fusion_flag()) { | |||
| if (!context_ptr->get_param<bool>(MS_CTX_IR_FUSION_FLAG)) { | |||
| MS_LOG(INFO) << "IRFusion is not enable, skip"; | |||
| return; | |||
| } | |||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||
| bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG); | |||
| auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH); | |||
| if (save_graphs_path.empty()) { | |||
| save_graphs_path = "."; | |||
| } | |||
| @@ -331,8 +332,8 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne | |||
| void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||
| bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG); | |||
| auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH); | |||
| if (save_graphs_path.empty()) { | |||
| save_graphs_path = "."; | |||
| } | |||
| @@ -367,7 +368,8 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern | |||
| auto other2_pm = std::make_shared<PassManager>("other2_pm"); | |||
| other2_pm->AddPass(std::make_shared<GetitemTuple>()); | |||
| other2_pm->AddPass(std::make_shared<CommonSubexpressionElimination>()); | |||
| if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) { | |||
| if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK) && context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK) && | |||
| ConfigManager::GetInstance().iter_num() > 1) { | |||
| other2_pm->AddPass(std::make_shared<GetnextMemcpyElimination>()); | |||
| } | |||
| other2_pm->AddPass(std::make_shared<CheckConsistency>()); | |||
| @@ -388,11 +390,11 @@ void AscendBackendGraphKernelOpt(const std::shared_ptr<session::KernelGraph> &ke | |||
| bool is_before_kernel_select) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (!(context_ptr->enable_graph_kernel())) { | |||
| if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) { | |||
| return; | |||
| } | |||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||
| bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG); | |||
| auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH); | |||
| if (save_graphs_path.empty()) { | |||
| save_graphs_path = "."; | |||
| } | |||
| @@ -418,11 +420,11 @@ void AscendBackendFuseBasicOpt(const std::shared_ptr<session::KernelGraph> &kern | |||
| bool is_before_kernel_select) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (!(context_ptr->enable_graph_kernel())) { | |||
| if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) { | |||
| return; | |||
| } | |||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||
| bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG); | |||
| auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH); | |||
| if (save_graphs_path.empty()) { | |||
| save_graphs_path = "."; | |||
| } | |||
| @@ -447,11 +449,11 @@ void AscendBackendFuseBasicOpt(const std::shared_ptr<session::KernelGraph> &kern | |||
| void AscendBackendAddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (!(context_ptr->enable_graph_kernel())) { | |||
| if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) { | |||
| return; | |||
| } | |||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||
| bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG); | |||
| auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH); | |||
| if (save_graphs_path.empty()) { | |||
| save_graphs_path = "."; | |||
| } | |||
| @@ -473,12 +475,12 @@ void AscendBackendAddAtomicClean(const std::shared_ptr<session::KernelGraph> &ke | |||
| void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (!context_ptr->ir_fusion_flag()) { | |||
| if (!context_ptr->get_param<bool>(MS_CTX_IR_FUSION_FLAG)) { | |||
| MS_LOG(INFO) << "UBFusion is not enable, skip"; | |||
| return; | |||
| } | |||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||
| bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG); | |||
| auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH); | |||
| if (save_graphs_path.empty()) { | |||
| save_graphs_path = "."; | |||
| } | |||
| @@ -53,7 +53,8 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An | |||
| } | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| if (ms_context->execution_mode() == kPynativeMode && !ms_context->enable_pynative_hook()) { | |||
| if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode && | |||
| !ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK)) { | |||
| if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) { | |||
| return new_node; | |||
| } | |||
| @@ -44,7 +44,7 @@ const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, con | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| if (ms_context->execution_mode() == kPynativeMode) { | |||
| if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) { | |||
| return RectifyKernelInfoInPynativeProcess(node); | |||
| } | |||
| if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutGenMask->name()) { | |||
| @@ -33,8 +33,8 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern | |||
| MS_LOG(INFO) << "start common opt graph:" << kernel_graph->graph_id(); | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||
| bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG); | |||
| auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH); | |||
| if (save_graphs_path.empty()) { | |||
| save_graphs_path = "."; | |||
| } | |||
| @@ -392,7 +392,8 @@ tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) { | |||
| bool IsNopNode(const AnfNodePtr &node) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (context_ptr->device_target() != kAscendDevice && context_ptr->device_target() != kGPUDevice) { | |||
| if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kAscendDevice && | |||
| context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) { | |||
| return false; | |||
| } | |||
| static std::unordered_set<std::string> nop_nodes = {prim::kPrimReshape->name(), kExpandDimsOpName, | |||
| @@ -40,8 +40,8 @@ bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr> | |||
| } | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||
| bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG); | |||
| auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH); | |||
| if (save_graphs_path.empty()) { | |||
| save_graphs_path = "."; | |||
| } | |||
| @@ -114,7 +114,7 @@ const AnfNodePtr ConstToAttrStridedSliceGradPass::Process(const FuncGraphPtr &gr | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| if (ms_context->device_target() == kAscendDevice) { | |||
| if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) { | |||
| if (!CheckAttrs(strided_slice_grad)) { | |||
| MS_LOG(INFO) << "Check strided_slice_grad's attrs failed, graph not changed"; | |||
| return nullptr; | |||
| @@ -359,11 +359,11 @@ void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||
| auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH); | |||
| if (save_graphs_path.empty()) { | |||
| save_graphs_path = "."; | |||
| } | |||
| if (context_ptr->save_graphs_flag()) { | |||
| if (context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { | |||
| std::string file_path = save_graphs_path + "/after_erase_label_and_parameter.ir"; | |||
| DumpIR(file_path, root_graph.get()); | |||
| } | |||
| @@ -253,7 +253,7 @@ void AscendSession::BuildGraph(GraphId graph_id) { | |||
| debugger_->PreExecute(graph); | |||
| } | |||
| #endif | |||
| if (ms_context->precompile_only()) { | |||
| if (ms_context->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) { | |||
| MS_LOG(INFO) << "Precompile only, stop in build kernel step"; | |||
| } else { | |||
| // alloc memory, including static memory and dynamic memory | |||
| @@ -278,8 +278,8 @@ void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) { | |||
| child_graph->SetExecOrderByDefault(); | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||
| bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG); | |||
| auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH); | |||
| if (save_graphs_path.empty()) { | |||
| save_graphs_path = "."; | |||
| } | |||
| @@ -436,7 +436,7 @@ void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const { | |||
| } | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| if (ms_context->execution_mode() == kGraphMode) { | |||
| if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) { | |||
| if (raise_precision_count > 0) { | |||
| MS_LOG(WARNING) << "There has " << raise_precision_count | |||
| << " node/nodes used raise precision to selected the kernel!"; | |||
| @@ -481,8 +481,8 @@ void AscendSession::AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_grap | |||
| device::KernelAdjust::GetInstance().InsertSwitchLoop(kernel_graph); | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||
| bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG); | |||
| auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH); | |||
| if (save_graphs_path.empty()) { | |||
| save_graphs_path = "."; | |||
| } | |||
| @@ -601,11 +601,11 @@ void AscendSession::DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs) | |||
| #ifdef ENABLE_DUMP_IR | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||
| bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG); | |||
| if (!save_graphs) { | |||
| return; | |||
| } | |||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||
| auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH); | |||
| if (save_graphs_path.empty()) { | |||
| save_graphs_path = "."; | |||
| } | |||
| @@ -733,7 +733,7 @@ void AscendSession::MergeGraphExecOrder() { | |||
| if (graph_order.size() > 1) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (!context_ptr->enable_task_sink()) { | |||
| if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) { | |||
| MS_LOG(EXCEPTION) << "Control sink network should run with task-sink mode!"; | |||
| } | |||
| } | |||
| @@ -920,8 +920,8 @@ void AscendSession::IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<st | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||
| bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG); | |||
| auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH); | |||
| if (save_graphs) { | |||
| if (save_graphs_path.empty()) { | |||
| save_graphs_path = "."; | |||
| @@ -947,7 +947,7 @@ void AscendSession::SelectKernel(NotNull<KernelGraphPtr> root_graph) { | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| if (ms_context->execution_mode() == kGraphMode) { | |||
| if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) { | |||
| if (raise_precision_count > 0) { | |||
| MS_LOG(WARNING) << "There are " << raise_precision_count | |||
| << " node/nodes used raise precision to selected the kernel!"; | |||
| @@ -992,8 +992,8 @@ void AscendSession::RecurseSelectKernelInfo(NotNull<KernelGraphPtr> graph, | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||
| bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG); | |||
| auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH); | |||
| if (save_graphs) { | |||
| if (save_graphs_path.empty()) { | |||
| save_graphs_path = "."; | |||
| @@ -76,7 +76,7 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) { | |||
| pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>()); | |||
| pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>()); | |||
| pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>()); | |||
| if (!CheckInModeBlackList(kernel_graph) && context_ptr->execution_mode() != kPynativeMode) { | |||
| if (!CheckInModeBlackList(kernel_graph) && context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { | |||
| pm->AddPass(std::make_shared<opt::BatchNormReluFusion>()); | |||
| pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>()); | |||
| pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>()); | |||
| @@ -154,7 +154,7 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, | |||
| auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); | |||
| auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()); | |||
| bool need_sync = false; | |||
| if (ms_context->enable_pynative_infer()) { | |||
| if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) { | |||
| if (tensor_address == nullptr || tensor_address != device_address) { | |||
| need_sync = true; | |||
| } | |||
| @@ -223,7 +223,7 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList | |||
| // Prepare ms context info for dump .pb graph | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||
| bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG); | |||
| // Optimize | |||
| Optimize(graph); | |||
| // Select kernel build info | |||
| @@ -290,7 +290,7 @@ void GPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten | |||
| // Summary | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (context_ptr->enable_gpu_summary()) { | |||
| if (context_ptr->get_param<bool>(MS_CTX_ENABLE_GPU_SUMMARY)) { | |||
| Summary(kernel_graph.get()); | |||
| } | |||
| #ifdef ENABLE_DEBUGGER | |||
| @@ -268,7 +268,7 @@ void MSInferSession::RegAllOp() { | |||
| return; | |||
| } | |||
| Initialized = true; | |||
| MsContext::GetInstance()->set_execution_mode(kGraphMode); | |||
| MsContext::GetInstance()->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode); | |||
| Py_Initialize(); | |||
| auto c_expression = PyImport_ImportModule("mindspore._c_expression"); | |||
| if (c_expression == nullptr) { | |||
| @@ -357,13 +357,13 @@ Status MSInferSession::InitEnv(const std::string &device, uint32_t device_id) { | |||
| MS_LOG(ERROR) << "Get Context failed!"; | |||
| return FAILED; | |||
| } | |||
| ms_context->set_execution_mode(kGraphMode); | |||
| ms_context->set_device_id(device_id); | |||
| ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode); | |||
| ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id); | |||
| auto ajust_device = AjustTargetName(device); | |||
| if (ajust_device == "") { | |||
| return FAILED; | |||
| } | |||
| ms_context->set_device_target(device); | |||
| ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, device); | |||
| if (!context::OpenTsd(ms_context)) { | |||
| MS_LOG(ERROR) << "Session init OpenTsd failed!"; | |||
| return FAILED; | |||
| @@ -93,10 +93,11 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o | |||
| // if in paynative mode,data only copyed to host when user want to print data | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| if (ms_context->execution_mode() != kPynativeMode && ms_context->device_target() != kGPUDevice) { | |||
| if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && | |||
| ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) { | |||
| tensor->set_need_sync(true); | |||
| } | |||
| if (ms_context->execution_mode() != kPynativeMode) { | |||
| if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { | |||
| tensor->SetNeedWait(true); | |||
| } | |||
| tensor->set_dirty(false); | |||
| @@ -938,7 +939,7 @@ bool TensorNeedSync(const AnfNodePtr ¶meter, const tensor::TensorPtr &tensor | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| auto device_address = AnfAlgo::GetMutableOutputAddr(parameter, 0); | |||
| if (ms_context->enable_pynative_infer()) { | |||
| if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) { | |||
| return tensor->device_address().get() == nullptr || tensor->device_address() != device_address; | |||
| } | |||
| if (tensor->is_dirty()) { | |||
| @@ -979,7 +980,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) { | |||
| auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0); | |||
| if (ms_context->execution_mode() == kPynativeMode || | |||
| if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode || | |||
| AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) { | |||
| tensor->set_device_address(device_address); | |||
| } | |||
| @@ -1177,7 +1178,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std: | |||
| if (backend_anf != nullptr) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (context_ptr->execution_mode() == kPynativeMode) { | |||
| if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) { | |||
| return backend_anf; | |||
| } | |||
| @@ -118,7 +118,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||
| debugger_ = Debugger::GetInstance(); | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| debugger_->Init(device_id_, ms_context->device_target()); | |||
| debugger_->Init(device_id_, ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET)); | |||
| } | |||
| #endif | |||
| @@ -53,7 +53,7 @@ bool DataDumpParser::DumpEnabled() const { | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| if (context->execution_mode() == kPynativeMode) { | |||
| if (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) { | |||
| MS_LOG(EXCEPTION) << "[DataDump] PyNative mode not support data dump"; | |||
| } | |||
| return true; | |||
| @@ -142,7 +142,7 @@ void Debugger::EnableDebugger() { | |||
| // switch memory reuse on or off | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| context_ptr->set_enable_mem_reuse(partial_memory_); | |||
| context_ptr->set_param<bool>(MS_CTX_ENABLE_MEM_REUSE, partial_memory_); | |||
| // print some message about memory reuse to user | |||
| if (partial_memory_) { | |||
| MS_LOG(WARNING) << "Partial Memory Reuse is enabled. Note: 1. Please only set watchpoints before running the first " | |||
| @@ -530,7 +530,7 @@ void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix) { | |||
| MS_LOG(ERROR) << "ms_context is nullptr"; | |||
| return; | |||
| } | |||
| auto save_graphs_path = ms_context->save_graphs_path(); | |||
| auto save_graphs_path = ms_context->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH); | |||
| if (save_graphs_path.empty()) { | |||
| save_graphs_path = "."; | |||
| } | |||
| @@ -112,7 +112,7 @@ bool Dump::IsConfigValid(const nlohmann::json &dumpSettings) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| // dump_enable_ is true, close mem reuse | |||
| context_ptr->set_enable_mem_reuse(!dump_enable_); | |||
| context_ptr->set_param<bool>(MS_CTX_ENABLE_MEM_REUSE, !dump_enable_); | |||
| trans_flag_ = trans_flag; | |||
| dump_mode_ = mode; | |||
| dump_path_ = path; | |||
| @@ -135,7 +135,7 @@ bool Dump::SetDumpConfFromJsonFile() { | |||
| } | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| auto id = context_ptr->device_id(); | |||
| auto id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID); | |||
| char real_path[PATH_MAX] = {0}; | |||
| if (nullptr == realpath(config_path_str, real_path)) { | |||
| MS_LOG(ERROR) << "Env e2e dump path error, " << config_path_str; | |||
| @@ -34,7 +34,7 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt | |||
| manager_ptr->AddFuncGraph(func_graph); | |||
| auto multi_graph_sink = [&func_graph](const FuncGraphPtr &f) { | |||
| if (MsContext::GetInstance()->is_multi_graph_sink()) { | |||
| if (MsContext::GetInstance()->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) { | |||
| if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { | |||
| f->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); | |||
| } | |||
| @@ -182,7 +182,7 @@ void KPrim::TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bp | |||
| void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) { | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| bool check_bprop_flag = context->check_bprop_flag(); | |||
| bool check_bprop_flag = context->get_param<bool>(MS_CTX_CHECK_BPROP_FLAG); | |||
| // Skip checking if check_bprop not set | |||
| if (!check_bprop_flag) { | |||
| return; | |||
| @@ -29,7 +29,7 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr | |||
| PConstant const_2(node); | |||
| PConstant any_const(node); | |||
| if (MsContext::GetInstance()->execution_mode() != kPynativeMode) { | |||
| if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { | |||
| MATCH_REPLACE(node, x + zero_, x); // Add by zero | |||
| MATCH_REPLACE(node, x + zero_scalar_, x); // Add by zero | |||
| MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarAdd, x, zero_scalar_, true), x); // Scalar Add by zero | |||
| @@ -41,7 +41,7 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr | |||
| } | |||
| // Prim Eliminate (identity) | |||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimIdentity, x), x); | |||
| if (MsContext::GetInstance()->execution_mode() == kPynativeMode) { | |||
| if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) { | |||
| return nullptr; | |||
| } | |||
| @@ -75,7 +75,7 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr | |||
| } | |||
| AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | |||
| if (MsContext::GetInstance()->execution_mode() == kPynativeMode) { | |||
| if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) { | |||
| return nullptr; | |||
| } | |||
| PatternNode x, y; | |||
| @@ -181,7 +181,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||
| } | |||
| }; | |||
| use_profile ? (WITH(MsProfile::GetProfile()->Step(pass_names_[i])) opt_func) : opt_func(); | |||
| if (is_on_debug_ && MsContext::GetInstance()->save_graphs_flag()) { | |||
| if (is_on_debug_ && MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { | |||
| MS_LOG(DEBUG) << "The opt " << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end."; | |||
| auto fg_name = | |||
| "opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i]; | |||
| @@ -217,8 +217,8 @@ AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph | |||
| void DrawNode(string name, AnfNodePtr node) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||
| bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG); | |||
| auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH); | |||
| if (save_graphs_path.empty()) { | |||
| save_graphs_path = "."; | |||
| } | |||
| @@ -44,7 +44,7 @@ std::vector<PrimitivePtr> FindPrimtive(const FuncGraphPtr &graph, const std::str | |||
| } | |||
| void DumpGraph(const FuncGraphPtr &root, const std::string &name) { | |||
| if (MsContext::GetInstance()->save_graphs_flag()) { | |||
| if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { | |||
| draw::Draw(name + ".dot", root); | |||
| DumpIR(name + ".ir", root); | |||
| ExportIR(name + ".dat", "0", root); | |||
| @@ -69,7 +69,7 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) { | |||
| struct timeval start_time, end_time; | |||
| (void)gettimeofday(&start_time, nullptr); | |||
| if (MsContext::GetInstance()->save_graphs_flag()) { | |||
| if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { | |||
| draw::Draw(STEP_AUTO_PARALLEL_BEGIN, root); | |||
| } | |||
| MS_LOG(INFO) << "Now entering step auto parallel"; | |||
| @@ -271,7 +271,7 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes) | |||
| if (!result) { | |||
| MS_LOG(EXCEPTION) << "Pass running to end, failed in pass:" << pass.first; | |||
| } | |||
| if (MsContext::GetInstance()->save_graphs_flag() && res->func_graph() != nullptr) { | |||
| if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && res->func_graph() != nullptr) { | |||
| auto fg_name = "opt_pass_" + std::to_string(counter) + "_" + pass.first; | |||
| auto func_graph = res->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| @@ -295,20 +295,20 @@ bool PynativeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, | |||
| static bool IsCtrlSink() { | |||
| auto ms_ctx = MsContext::GetInstance(); | |||
| if (ms_ctx->execution_mode() != kGraphMode) { | |||
| if (ms_ctx->get_param<int>(MS_CTX_EXECUTION_MODE) != kGraphMode) { | |||
| return false; | |||
| } | |||
| std::string device_target = ms_ctx->device_target(); | |||
| std::string device_target = ms_ctx->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||
| if (device_target != kAscendDevice) { | |||
| return false; | |||
| } | |||
| if (!ms_ctx->enable_task_sink()) { | |||
| if (!ms_ctx->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) { | |||
| return false; | |||
| } | |||
| if (!ms_ctx->is_multi_graph_sink()) { | |||
| if (!ms_ctx->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) { | |||
| return false; | |||
| } | |||
| return true; | |||
| @@ -325,13 +325,13 @@ bool TaskEmitAction(const ResourcePtr &res) { | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (CompileGraphs::ContainMixedTarget(func_graph)) { | |||
| bc_ptr->set_is_multi_graph_sink(false); | |||
| context_ptr->set_is_multi_graph_sink(false); | |||
| context_ptr->set_loop_sink_flag(false); | |||
| } else if (context_ptr->execution_mode() != kPynativeMode) { | |||
| std::string device_target = context_ptr->device_target(); | |||
| context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false); | |||
| context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, false); | |||
| } else if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { | |||
| std::string device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||
| if (device_target == kAscendDevice && backend != kMsVm) { | |||
| bc_ptr->set_is_multi_graph_sink(true); | |||
| context_ptr->set_is_multi_graph_sink(true); | |||
| context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, true); | |||
| } | |||
| } | |||
| @@ -49,7 +49,7 @@ inline std::string GetFilePathName(const std::string &file_name) { | |||
| if (ms_context == nullptr) { | |||
| MS_LOG(EXCEPTION) << "ms_context is nullptr"; | |||
| } | |||
| auto save_graphs_path = ms_context->save_graphs_path(); | |||
| auto save_graphs_path = ms_context->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH); | |||
| if (save_graphs_path.empty()) { | |||
| save_graphs_path = "."; | |||
| } | |||
| @@ -48,6 +48,56 @@ using OpLib = mindspore::kernel::OpLib; | |||
| using OpInfoLoaderPy = mindspore::kernel::OpInfoLoaderPy; | |||
| using ParallelContext = mindspore::parallel::ParallelContext; | |||
| using CostModelContext = mindspore::parallel::CostModelContext; | |||
| using mindspore::MsCtxParam; | |||
| namespace mindspore { | |||
| void MsCtxSetParameter(std::shared_ptr<MsContext> ctx, MsCtxParam param, const py::object &value) { | |||
| MS_LOG(DEBUG) << "set param(" << param << ") with value '" << py::str(value) << "' of type '" | |||
| << py::str(value.get_type()) << "'."; | |||
| if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END && py::isinstance<py::bool_>(value)) { | |||
| ctx->set_param<bool>(param, value.cast<bool>()); | |||
| return; | |||
| } | |||
| if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END && py::isinstance<py::int_>(value)) { | |||
| ctx->set_param<int>(param, value.cast<int>()); | |||
| return; | |||
| } | |||
| if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END && py::isinstance<py::int_>(value)) { | |||
| ctx->set_param<uint32_t>(param, value.cast<uint32_t>()); | |||
| return; | |||
| } | |||
| if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END && py::isinstance<py::float_>(value)) { | |||
| ctx->set_param<float>(param, value.cast<float>()); | |||
| return; | |||
| } | |||
| if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END && py::isinstance<py::str>(value)) { | |||
| ctx->set_param<std::string>(param, value.cast<std::string>()); | |||
| return; | |||
| } | |||
| MS_LOG(EXCEPTION) << "Got illegal param " << param << " and value with type " << py::str(value.get_type()); | |||
| } | |||
| py::object MsCtxGetParameter(const std::shared_ptr<MsContext> &ctx, MsCtxParam param) { | |||
| if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END) { | |||
| return py::bool_(ctx->get_param<bool>(param)); | |||
| } | |||
| if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END) { | |||
| return py::int_(ctx->get_param<int>(param)); | |||
| } | |||
| if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END) { | |||
| return py::int_(ctx->get_param<uint32_t>(param)); | |||
| } | |||
| if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END) { | |||
| return py::float_(ctx->get_param<float>(param)); | |||
| } | |||
| if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END) { | |||
| return py::str(ctx->get_param<std::string>(param)); | |||
| } | |||
| MS_LOG(EXCEPTION) << "Got illegal param " << param << "."; | |||
| } | |||
| } // namespace mindspore | |||
| // Interface with python | |||
| PYBIND11_MODULE(_c_expression, m) { | |||
| @@ -101,53 +151,48 @@ PYBIND11_MODULE(_c_expression, m) { | |||
| (void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph."); | |||
| (void)m.def("ms_ctx_get_param", &mindspore::MsCtxGetParameter, "Get value of specified paramter."); | |||
| (void)m.def("ms_ctx_set_param", &mindspore::MsCtxSetParameter, "Set value for specified paramter."); | |||
| (void)py::enum_<MsCtxParam>(*m, "ms_ctx_param", py::arithmetic()) | |||
| .value("auto_mixed_precision_flag", MsCtxParam::MS_CTX_AUTO_MIXED_PRECISION_FLAG) | |||
| .value("check_bprop_flag", MsCtxParam::MS_CTX_CHECK_BPROP_FLAG) | |||
| .value("enable_dump", MsCtxParam::MS_CTX_ENABLE_DUMP) | |||
| .value("enable_dynamic_mem_pool", MsCtxParam::MS_CTX_ENABLE_DYNAMIC_MEM_POOL) | |||
| .value("enable_gpu_summary", MsCtxParam::MS_CTX_ENABLE_GPU_SUMMARY) | |||
| .value("enable_graph_kernel", MsCtxParam::MS_CTX_ENABLE_GRAPH_KERNEL) | |||
| .value("enable_hccl", MsCtxParam::MS_CTX_ENABLE_HCCL) | |||
| .value("enable_loop_sink", MsCtxParam::MS_CTX_ENABLE_LOOP_SINK) | |||
| .value("enable_mem_reuse", MsCtxParam::MS_CTX_ENABLE_MEM_REUSE) | |||
| .value("enable_pynative_hook", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_HOOK) | |||
| .value("enable_pynative_infer", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_INFER) | |||
| .value("enable_reduce_precision", MsCtxParam::MS_CTX_ENABLE_REDUCE_PRECISION) | |||
| .value("enable_sparse", MsCtxParam::MS_CTX_ENABLE_SPARSE) | |||
| .value("enable_task_sink", MsCtxParam::MS_CTX_ENABLE_TASK_SINK) | |||
| .value("ir_fusion_flag", MsCtxParam::MS_CTX_IR_FUSION_FLAG) | |||
| .value("is_multi_graph_sink", MsCtxParam::MS_CTX_IS_MULTI_GRAPH_SINK) | |||
| .value("is_pynative_ge_init", MsCtxParam::MS_CTX_IS_PYNATIVE_GE_INIT) | |||
| .value("precompile_only", MsCtxParam::MS_CTX_PRECOMPILE_ONLY) | |||
| .value("enable_profiling", MsCtxParam::MS_CTX_ENABLE_PROFILING) | |||
| .value("save_graphs_flag", MsCtxParam::MS_CTX_SAVE_GRAPHS_FLAG) | |||
| .value("max_device_memory", MsCtxParam::MS_CTX_MAX_DEVICE_MEMORY) | |||
| .value("execution_mode", MsCtxParam::MS_CTX_EXECUTION_MODE) | |||
| .value("device_target", MsCtxParam::MS_CTX_DEVICE_TARGET) | |||
| .value("graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE) | |||
| .value("print_file_path", MsCtxParam::MS_CTX_PRINT_FILE_PATH) | |||
| .value("profiling_options", MsCtxParam::MS_CTX_PROFILING_OPTIONS) | |||
| .value("save_dump_path", MsCtxParam::MS_CTX_SAVE_DUMP_PATH) | |||
| .value("save_graphs_path", MsCtxParam::MS_CTX_SAVE_GRAPHS_PATH) | |||
| .value("variable_memory_max_size", MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE) | |||
| .value("device_id", MsCtxParam::MS_CTX_DEVICE_ID) | |||
| .value("ge_ref", MsCtxParam::MS_CTX_GE_REF) | |||
| .value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH) | |||
| .value("tsd_ref", MsCtxParam::MS_CTX_TSD_REF); | |||
| (void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(m, "MSContext") | |||
| .def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.") | |||
| .def("get_backend_policy", &mindspore::MsContext::backend_policy, "Get backend policy.") | |||
| .def("set_backend_policy", &mindspore::MsContext::set_backend_policy, "Set backend policy.") | |||
| .def("get_execution_mode", &mindspore::MsContext::execution_mode, "Get execution mode.") | |||
| .def("set_execution_mode", &mindspore::MsContext::set_execution_mode, "Set execution mode.") | |||
| .def("set_precompile_only", &mindspore::MsContext::set_precompile_only, "Set enable precompile only.") | |||
| .def("get_precompile_only", &mindspore::MsContext::precompile_only, "Get enable precompile only.") | |||
| .def("get_device_target", &mindspore::MsContext::device_target, "Get device target.") | |||
| .def("set_device_target", &mindspore::MsContext::set_device_target, "Set device target.") | |||
| .def("get_device_id", &mindspore::MsContext::device_id, "Get device id.") | |||
| .def("set_device_id", &mindspore::MsContext::set_device_id, "Set device id.") | |||
| .def("get_max_call_depth", &mindspore::MsContext::max_call_depth, "Get max call depth.") | |||
| .def("set_max_call_depth", &mindspore::MsContext::set_max_call_depth, "Set max call depth.") | |||
| .def("get_save_graphs_flag", &mindspore::MsContext::save_graphs_flag, "Get whether to save graphs.") | |||
| .def("set_save_graphs_flag", &mindspore::MsContext::set_save_graphs_flag, "Set whether to save graphs.") | |||
| .def("get_auto_mixed_precision_flag", &mindspore::MsContext::auto_mixed_precision_flag, | |||
| "Get whether to enable auto mixed precision.") | |||
| .def("set_auto_mixed_precision_flag", &mindspore::MsContext::set_auto_mixed_precision_flag, | |||
| "Set whether to enable auto mixed precision.") | |||
| .def("get_enable_reduce_precision_flag", &mindspore::MsContext::enable_reduce_precision, | |||
| "Get whether to enable reduce precision.") | |||
| .def("set_enable_reduce_precision_flag", &mindspore::MsContext::set_enable_reduce_precision, | |||
| "Set whether to enable reduce precision.") | |||
| .def("get_save_graphs_path", &mindspore::MsContext::save_graphs_path, "Get save graphs path.") | |||
| .def("set_save_graphs_path", &mindspore::MsContext::set_save_graphs_path, "Set save graphs path.") | |||
| .def("get_enable_dump", &mindspore::MsContext::enable_dump, "Get whether to enable dump.") | |||
| .def("set_enable_dump", &mindspore::MsContext::set_enable_dump, "Set whether to enable dump.") | |||
| .def("get_save_dump_path", &mindspore::MsContext::save_dump_path, "Get path to dump.") | |||
| .def("set_save_dump_path", &mindspore::MsContext::set_save_dump_path, "Set path to dump.") | |||
| .def("set_graph_memory_max_size", &mindspore::MsContext::set_graph_memory_max_size, "set graph memory max size.") | |||
| .def("set_variable_memory_max_size", &mindspore::MsContext::set_variable_memory_max_size, | |||
| "set variable memory max size") | |||
| .def("get_enable_profiling", &mindspore::MsContext::enable_profiling, "Get whether to open profiling.") | |||
| .def("set_enable_profiling", &mindspore::MsContext::set_enable_profiling, "Set whether to open profiling.") | |||
| .def("get_profiling_options", &mindspore::MsContext::profiling_options, "Get options to profiling.") | |||
| .def("set_profiling_options", &mindspore::MsContext::set_profiling_options, "Set options to profiling.") | |||
| .def("get_check_bprop_flag", &mindspore::MsContext::check_bprop_flag, "Get whether to check bprop.") | |||
| .def("set_check_bprop_flag", &mindspore::MsContext::set_check_bprop_flag, "Set whether to check bprop.") | |||
| .def("get_max_device_memory", &mindspore::MsContext::max_device_memory, "Get deivce memory max size.") | |||
| .def("set_max_device_memory", &mindspore::MsContext::set_max_device_memory, "Set deivce memory max size.") | |||
| .def("set_print_file_path", &mindspore::MsContext::set_print_file_path, "Set path to print.") | |||
| .def("set_enable_graph_kernel", &mindspore::MsContext::set_enable_graph_kernel, | |||
| "Set the GraphKernel switch to on or off.") | |||
| .def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.") | |||
| .def("get_enable_sparse", &mindspore::MsContext::enable_sparse, "Get whether to enable sparsity.") | |||
| .def("set_enable_sparse", &mindspore::MsContext::set_enable_sparse, "Set whether to enable sparsity."); | |||
| .def("set_backend_policy", &mindspore::MsContext::set_backend_policy, "Set backend policy."); | |||
| (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig") | |||
| .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.") | |||
| @@ -271,7 +271,7 @@ void InitOpt(const ResourcePtr &res) { | |||
| g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass)); | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (!(context_ptr->enable_graph_kernel())) { | |||
| if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) { | |||
| g_pass_opts["opt_graph_kernel_a"]->set_enable(false); | |||
| g_pass_opts["opt_graph_kernel_b"]->set_enable(false); | |||
| } | |||
| @@ -88,7 +88,7 @@ std::string GetBaseNameForIR(int stage_idx, const std::string &action_name) { | |||
| if (ms_context == nullptr) { | |||
| MS_LOG(EXCEPTION) << "ms_context is nullptr"; | |||
| } | |||
| auto save_graphs_path = ms_context->save_graphs_path(); | |||
| auto save_graphs_path = ms_context->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH); | |||
| if (save_graphs_path.empty()) { | |||
| save_graphs_path = "."; | |||
| } | |||
| @@ -646,7 +646,7 @@ void Pipeline::Run() { | |||
| if (!result) { | |||
| MS_LOG(EXCEPTION) << "Pipeline running to end, failed in step:" << action.first; | |||
| } | |||
| if (MsContext::GetInstance()->save_graphs_flag() && resource_->func_graph() != nullptr) { | |||
| if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && resource_->func_graph() != nullptr) { | |||
| auto graph = resource_->func_graph(); | |||
| if (graph != nullptr) { | |||
| user_graph = graph; | |||
| @@ -688,7 +688,7 @@ void Pipeline::Run() { | |||
| MsProfile::Reset(); | |||
| #endif | |||
| if (MsContext::GetInstance()->save_graphs_flag() && (user_graph != nullptr)) { | |||
| if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && (user_graph != nullptr)) { | |||
| std::string user_graph_file = GetFilePathName("ModelDigraph.dot"); | |||
| MS_LOG(DEBUG) << "Save user graph to: " << user_graph_file; | |||
| draw::DrawUserFuncGraph(user_graph_file, user_graph); | |||
| @@ -710,7 +710,7 @@ void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef | |||
| if (!succ) { | |||
| MS_LOG(EXCEPTION) << "The " << i << "th arg convert failed."; | |||
| } | |||
| if (MsContext::GetInstance()->execution_mode() == 0 && !converted->isa<tensor::Tensor>()) { | |||
| if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == 0 && !converted->isa<tensor::Tensor>()) { | |||
| MS_EXCEPTION(TypeError) << "For 'graph mode', the " << i << "th arg: " << converted->ToString() | |||
| << " is not tensor."; | |||
| } | |||
| @@ -891,7 +891,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc | |||
| // Convert CNodeList to LinConvertResult. | |||
| ConfigManager::GetInstance().set_iter_num(1); | |||
| auto runner = convert_fn({app_init}, ""); | |||
| if (MsContext::GetInstance()->execution_mode() != kPynativeMode) { | |||
| if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { | |||
| backend->Link(runner.graph_id); | |||
| } | |||
| ConfigManager::GetInstance().set_iter_num(size); | |||
| @@ -965,10 +965,11 @@ void InitHccl() { | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| (void)context::OpenTsd(ms_context); | |||
| uint32_t device_id = ms_context->device_id(); | |||
| std::string device_name = ms_context->device_target(); | |||
| ms_context->set_enable_hccl(true); | |||
| if (ms_context->backend_policy() == "ms" && ms_context->device_target() == kAscendDevice) { | |||
| uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID); | |||
| std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||
| ms_context->set_param<bool>(MS_CTX_ENABLE_HCCL, true); | |||
| if (ms_context->backend_policy() == "ms" && | |||
| ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) { | |||
| auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(device_name, device_id); | |||
| MS_EXCEPTION_IF_NULL(runtime_instance); | |||
| if (!runtime_instance->Init()) { | |||
| @@ -214,7 +214,7 @@ bool AddDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const py::di | |||
| return false; | |||
| } | |||
| if (MsContext::GetInstance()->save_graphs_flag()) { | |||
| if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { | |||
| convertor.DrawComputeGraph(GetFilePathName("ge_graph.dot")); // for debug | |||
| convertor.DrawInitGraph(GetFilePathName("init_graph.dot")); // for debug | |||
| convertor.DrawSaveCheckpointGraph(GetFilePathName("save_checkpoint_graph.dot")); // for debug | |||
| @@ -244,7 +244,7 @@ FuncGraphPtr BuildDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, co | |||
| } | |||
| FuncGraphPtr anf_graph = info.at(phase)->func_graph; | |||
| if (MsContext::GetInstance()->save_graphs_flag()) { | |||
| if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { | |||
| draw::Draw(GetFilePathName("anf_graph.dot"), anf_graph); // for debug | |||
| DumpIR(GetFilePathName("anf_graph.ir"), anf_graph, true); | |||
| } | |||
| @@ -118,8 +118,9 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr | |||
| << ", current function call depth: " << engine->function_call_depth(); | |||
| AbstractBasePtr ret_base = nullptr; | |||
| engine->IncreaseFunctionCallDepth(); | |||
| if (engine->function_call_depth() > MsContext::GetInstance()->max_call_depth()) { | |||
| MS_LOG(EXCEPTION) << "Exceed function call depth limit " << MsContext::GetInstance()->max_call_depth() << "."; | |||
| if (engine->function_call_depth() > MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)) { | |||
| MS_LOG(EXCEPTION) << "Exceed function call depth limit " | |||
| << MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH) << "."; | |||
| } | |||
| std::vector<AnfNodePtr> nodes = FastShadowSort(func_node); | |||
| for (auto it = nodes.crbegin(); it != nodes.crend(); it++) { | |||
| @@ -409,7 +410,7 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg | |||
| bparams.push_back(SensitivityTransform(orig_func_)); | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| bool enable_sparse = context->enable_sparse(); | |||
| bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE); | |||
| (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams), | |||
| [&enable_sparse](const AbstractBasePtr &arg_spec) -> AbstractBasePtr { | |||
| if (enable_sparse && arg_spec->isa<AbstractTensor>()) { | |||
| @@ -62,7 +62,7 @@ class Evaluator : public Base { | |||
| virtual EvalResultPtr AbstractEval(const AbstractBasePtrList &args_spec_list) { | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| bool enable_sparse = context->enable_sparse(); | |||
| bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE); | |||
| if (!enable_sparse) { | |||
| return nullptr; | |||
| } | |||
| @@ -290,7 +290,7 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { | |||
| if (abs_base->isa<AbstractTensor>()) { | |||
| auto arg_tensor = dyn_cast<AbstractTensor>(abs_base); | |||
| dic["shape"] = arg_tensor->shape()->shape(); | |||
| if (MsContext::GetInstance()->execution_mode() == kGraphMode) { | |||
| if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) { | |||
| const auto &min_shape = arg_tensor->shape()->min_shape(); | |||
| const auto &max_shape = arg_tensor->shape()->max_shape(); | |||
| if (!min_shape.empty() && !max_shape.empty()) { | |||
| @@ -558,8 +558,8 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat | |||
| MS_EXCEPTION_IF_NULL(op_exec_info); | |||
| MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms"; | |||
| auto ms_context = MsContext::GetInstance(); | |||
| ms_context->set_enable_pynative_infer(true); | |||
| std::string device_target = ms_context->device_target(); | |||
| ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, true); | |||
| std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||
| if (device_target != kAscendDevice && device_target != kGPUDevice) { | |||
| MS_EXCEPTION(ArgumentError) << "Device target [" << device_target << "] is not supported in Pynative mode"; | |||
| } | |||
| @@ -567,7 +567,7 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat | |||
| if (session == nullptr) { | |||
| session = session::SessionFactory::Get().Create(device_target); | |||
| MS_EXCEPTION_IF_NULL(session); | |||
| session->Init(ms_context->device_id()); | |||
| session->Init(ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID)); | |||
| } | |||
| std::vector<tensor::TensorPtr> input_tensors; | |||
| @@ -578,7 +578,7 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat | |||
| session->BuildOpAsync(op_exec_info.get(), graph_info, input_tensors, tensors_mask); | |||
| EraseValueNodeTensor(tensors_mask, &input_tensors); | |||
| py::tuple result = session->RunOpAsync(op_exec_info.get(), graph_info, input_tensors); | |||
| ms_context->set_enable_pynative_infer(false); | |||
| ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false); | |||
| *status = PYNATIVE_SUCCESS; | |||
| MS_LOG(INFO) << "End run op[" << op_exec_info->op_name << "] with backend policy ms"; | |||
| return result; | |||
| @@ -1308,7 +1308,7 @@ void PynativeExecutor::Clear(const std::string &flag) { | |||
| // Maybe exit in the pynative runing op, so need reset pynative flag. | |||
| auto ms_context = MsContext::GetInstance(); | |||
| if (ms_context != nullptr) { | |||
| ms_context->set_enable_pynative_infer(false); | |||
| ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false); | |||
| } | |||
| ConfigManager::GetInstance().ResetIterNum(); | |||
| return; | |||
| @@ -89,7 +89,7 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args) | |||
| MS_EXCEPTION(ValueError) << "For user define net bprop, the gradients number: " << grads.size() | |||
| << " is not equal to the args number: " << py_args.size() - 2 << "."; | |||
| } | |||
| if (MsContext::GetInstance()->check_bprop_flag()) { | |||
| if (MsContext::GetInstance()->get_param<bool>(MS_CTX_CHECK_BPROP_FLAG)) { | |||
| for (size_t i = 0; i < grads.size(); i++) { | |||
| if (py::isinstance<tensor::Tensor>(py_args[i])) { | |||
| if (!py::isinstance<tensor::Tensor>(grads[i])) { | |||
| @@ -154,7 +154,7 @@ bool SyncDeviceToHostAndFloatToFloat64(void *dst, size_t dst_size, const void *s | |||
| DeviceAddressPtr AssignLaunchMemory(size_t size, const std::string &format, TypeId type) { | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| auto device_id = ms_context->device_id(); | |||
| auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID); | |||
| auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id); | |||
| MS_EXCEPTION_IF_NULL(runtime_instance); | |||
| auto address_ptr = runtime_instance->AssignSingleOpLaunchMemory(size, format, type); | |||
| @@ -261,11 +261,12 @@ void AscendDeviceAddress::SyncStream() const { | |||
| MS_LOG(INFO) << "Start!"; | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| if (ms_context->execution_mode() != kPynativeMode && !ms_context->enable_pynative_infer()) { | |||
| if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && | |||
| !ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) { | |||
| MS_LOG(INFO) << "Finish!"; | |||
| return; | |||
| } | |||
| auto device_id = ms_context->device_id(); | |||
| auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID); | |||
| auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id); | |||
| MS_EXCEPTION_IF_NULL(runtime_instance); | |||
| auto ret = runtime_instance->SyncStream(); | |||
| @@ -348,7 +349,7 @@ void AscendDeviceAddress::LaunchTransData(kernel::KernelModPtr kernel_mod_ptr, v | |||
| } | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| auto device_id = ms_context->device_id(); | |||
| auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID); | |||
| auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id); | |||
| MS_EXCEPTION_IF_NULL(runtime_instance); | |||
| auto ret = | |||
| @@ -475,7 +476,8 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const ShapeVector &sh | |||
| std::vector<size_t> device_shape = GetDeviceShape(&host_shape); | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| if (ms_context->execution_mode() != kGraphMode && ms_context->execution_mode() != kPynativeMode && | |||
| if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kGraphMode && | |||
| ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && | |||
| type_id_name_map.find(type_id_) != type_id_name_map.end()) { | |||
| std::pair<std::string, std::string> type_format = std::make_pair(type_id_name_map.at(type_id_), format_); | |||
| if (use_trans_data.find(type_format) != use_trans_data.end()) { | |||
| @@ -158,7 +158,7 @@ void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std | |||
| bool AscendKernelRuntime::NeedDestroyHccl() { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (!context_ptr->enable_hccl()) { | |||
| if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) { | |||
| MS_LOG(INFO) << "Hccl is not enabled"; | |||
| return false; | |||
| } | |||
| @@ -177,7 +177,7 @@ void AscendKernelRuntime::ReleaseDeviceRes() { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| auto ret = rtSetDevice(context_ptr->device_id()); | |||
| auto ret = rtSetDevice(context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID)); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << static_cast<int>(ret) << "]"; | |||
| } | |||
| @@ -461,12 +461,12 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { | |||
| MS_LOG(INFO) << "GenTask start. GraphId:" << graph->graph_id(); | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| bool is_task_sink = context_ptr->enable_task_sink(); | |||
| bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK); | |||
| if (!is_task_sink) { | |||
| return true; | |||
| } | |||
| #ifdef MEM_REUSE_DEBUG | |||
| if (!context_ptr->enable_mem_reuse()) { | |||
| if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_REUSE)) { | |||
| // Get normal graph ir for memreuse | |||
| mindspore::memreuse::MemReuseChecker::GetInstance().CheckNormalIR(graph); | |||
| } | |||
| @@ -518,7 +518,7 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) { | |||
| MS_LOG(INFO) << "LoadTask start. GraphId:" << graph->graph_id(); | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| bool is_task_sink = context_ptr->enable_task_sink(); | |||
| bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK); | |||
| if (!is_task_sink) { | |||
| return true; | |||
| } | |||
| @@ -658,7 +658,7 @@ bool AscendKernelRuntime::InitDevice() { | |||
| MS_LOG(ERROR) << "Get MsContext instance failed"; | |||
| return false; | |||
| } | |||
| if (context_ptr->enable_hccl()) { | |||
| if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) { | |||
| if (!HcclInit()) { | |||
| MS_LOG(ERROR) << "HcclInit init failed"; | |||
| return false; | |||
| @@ -746,7 +746,7 @@ bool AscendKernelRuntime::DestroyHccl() { | |||
| return false; | |||
| } | |||
| MS_LOG(INFO) << "Hccl destroy successful, status = " << res << "."; | |||
| context_ptr->set_enable_hccl(false); | |||
| context_ptr->set_param<bool>(MS_CTX_ENABLE_HCCL, false); | |||
| return true; | |||
| } | |||
| @@ -43,7 +43,7 @@ void AscendMemoryManager::MallocDeviceMemory() { | |||
| uint64_t AscendMemoryManager::GetDeviceMemSizeFromContext() { | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| auto variable_memory_max_size = context->variable_memory_max_size(); | |||
| auto variable_memory_max_size = context->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE); | |||
| if (variable_memory_max_size == "0") { | |||
| return 0; | |||
| } | |||
| @@ -1373,7 +1373,7 @@ vector<CNodePtr>::iterator AscendStreamAssign::FindTargetOp(vector<CNodePtr>::it | |||
| bool AscendStreamAssign::IsTaskSink() { | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| if (!ms_context->enable_task_sink()) { | |||
| if (!ms_context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) { | |||
| MS_LOG(INFO) << "Task sink mode is not enable"; | |||
| return false; | |||
| } else { | |||
| @@ -117,7 +117,7 @@ void DataDumper::SetOpMappingInfo(NotNull<aicpu::dump::OpMappingInfo *> dump_inf | |||
| if (!dump_path.has_value()) { | |||
| MS_LOG(EXCEPTION) << "Dump path invalid"; | |||
| } | |||
| auto device_id = context_ptr->device_id(); | |||
| auto device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID); | |||
| dump_info->set_dump_path("/" + dump_path.value() + "_" + std::to_string(device_id) + "/"); | |||
| MS_LOG(INFO) << "[DataDump] dump_path:" << dump_path.value(); | |||
| @@ -363,7 +363,7 @@ void PrecisionReduce(const std::vector<int> &node_mix_precision_datatype_index, | |||
| *precision_reduce = false; | |||
| return; | |||
| } | |||
| if (context_ptr->enable_reduce_precision()) { | |||
| if (context_ptr->get_param<bool>(MS_CTX_ENABLE_REDUCE_PRECISION)) { | |||
| selected_ret = RaiseOrReduceDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype, | |||
| kernel_support_datatype, &kernel_match_datatype_idx_copy); | |||
| } | |||
| @@ -117,7 +117,7 @@ bool ProfilingManager::StartupProfiling(uint32_t device_id) { | |||
| } | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| const string prof_options_str = context->profiling_options(); | |||
| const string prof_options_str = context->get_param<std::string>(MS_CTX_PROFILING_OPTIONS); | |||
| std::vector<string> opts = Split(prof_options_str, ':'); | |||
| if (opts.empty()) { | |||
| MS_LOG(WARNING) << "Profiling is enabled, but profiling option is not set!"; | |||
| @@ -41,7 +41,7 @@ class ProfilingManager { | |||
| inline bool IsProfiling() const { | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| return context->enable_profiling(); | |||
| return context->get_param<bool>(MS_CTX_ENABLE_PROFILING); | |||
| } | |||
| protected: | |||
| @@ -342,12 +342,12 @@ void ProfilingUtils::ReportProfilingData(const std::vector<uint32_t> &task_ids, | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| TaskDescReporter task_reporter(context->device_id(), "vm.task_desc_info", ret->second); | |||
| TaskDescReporter task_reporter(context->get_param<uint32_t>(MS_CTX_DEVICE_ID), "vm.task_desc_info", ret->second); | |||
| task_reporter.set_task_ids(task_ids); | |||
| task_reporter.set_stream_ids(stream_ids); | |||
| task_reporter.ReportData(); | |||
| GraphDescReporter graph_reporter(context->device_id(), "vm.graph_desc_info", ret->second); | |||
| GraphDescReporter graph_reporter(context->get_param<uint32_t>(MS_CTX_DEVICE_ID), "vm.graph_desc_info", ret->second); | |||
| graph_profiling_cnode_.erase(ret); | |||
| graph_reporter.ReportData(); | |||
| @@ -357,7 +357,7 @@ void ProfilingUtils::ReportProfilingData(const std::vector<uint32_t> &task_ids, | |||
| MS_LOG(ERROR) << "Graph id not found in graph_point"; | |||
| return; | |||
| } | |||
| PointReporter point_reporter(context->device_id(), "vm.point"); | |||
| PointReporter point_reporter(context->get_param<uint32_t>(MS_CTX_DEVICE_ID), "vm.point"); | |||
| for (const auto &point : point_iter->second) { | |||
| point_reporter.AddReportData(point); | |||
| } | |||
| @@ -416,7 +416,7 @@ void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) { | |||
| mem_manager_->ResetDynamicMemory(); | |||
| AssignStaticMemoryInput(graph); | |||
| AssignStaticMemoryValueNode(graph); | |||
| bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool(); | |||
| bool is_enable_dynamic_mem = context_ptr->get_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL); | |||
| if (is_enable_dynamic_mem) { | |||
| // Use the dynamic memory pool. | |||
| InitKernelRefCount(graph); | |||
| @@ -435,8 +435,8 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) { | |||
| bool ret = true; | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool(); | |||
| bool is_enable_pynative_infer = context_ptr->enable_pynative_infer(); | |||
| bool is_enable_dynamic_mem = context_ptr->get_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL); | |||
| bool is_enable_pynative_infer = context_ptr->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER); | |||
| if (is_enable_dynamic_mem && !is_enable_pynative_infer) { | |||
| auto graph_id = graph->graph_id(); | |||
| auto iter = mem_swap_map_.find(graph_id); | |||
| @@ -29,7 +29,7 @@ bool GPUMemoryAllocator::Init() { | |||
| size_t free_size = CudaDriver::free_mem_size(); | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| limited_device_memory_ = context_ptr->max_device_memory(); | |||
| limited_device_memory_ = context_ptr->get_param<float>(MS_CTX_MAX_DEVICE_MEMORY); | |||
| available_device_memory_ = FloatToSize(limited_device_memory_ * 1024 * 1024 * 1024); | |||
| if (total_size > 0 && free_size > 0 && available_device_memory_ > 0) { | |||
| MS_LOG(INFO) << "GPU device total memory size " << total_size << ", current free memory size " << free_size | |||
| @@ -44,7 +44,7 @@ bool GPUMemoryAllocator::Init() { | |||
| void GPUMemoryAllocator::CheckMaxDeviceMemory() const { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| auto max_device_memory = context_ptr->max_device_memory(); | |||
| auto max_device_memory = context_ptr->get_param<float>(MS_CTX_MAX_DEVICE_MEMORY); | |||
| // Currently not support modifying the max device memory. | |||
| if (limited_device_memory_ != max_device_memory) { | |||
| MS_LOG(EXCEPTION) | |||
| @@ -37,7 +37,7 @@ void GPUMemoryManager::MallocDeviceMemory() { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| // If use the dynamic memory pool, then alloc the first memory block to init. | |||
| if (context_ptr->enable_dynamic_mem_pool()) { | |||
| if (context_ptr->get_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL)) { | |||
| auto device_addr = MallocMemFromMemPool(1); | |||
| if (!device_addr) { | |||
| MS_LOG(EXCEPTION) << "Dynamic memory pool init error."; | |||
| @@ -65,7 +65,7 @@ void GPUMemoryManager::FreeDeviceMemory() { | |||
| uint8_t *GPUMemoryManager::MallocStaticMem(size_t size, bool) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (context_ptr->enable_dynamic_mem_pool()) { | |||
| if (context_ptr->get_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL)) { | |||
| auto device_ptr = MallocMemFromMemPool(size); | |||
| MS_EXCEPTION_IF_NULL(device_ptr); | |||
| return AddressOffset(device_ptr, 0); | |||
| @@ -162,7 +162,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co | |||
| bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type) { | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| if (ms_context->execution_mode() == kPynativeMode) { | |||
| if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) { | |||
| return false; | |||
| } | |||
| if (!AnfAlgo::IsRealCNodeKernel(kernel_node)) { | |||
| @@ -60,8 +60,8 @@ void KernelAdjust::ReorderGetNext(const std::shared_ptr<session::KernelGraph> &k | |||
| bool KernelAdjust::NeedInsertSwitch() { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| return (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && | |||
| ConfigManager::GetInstance().iter_num() > 1); | |||
| return (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK) && | |||
| context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK) && ConfigManager::GetInstance().iter_num() > 1); | |||
| } | |||
| CNodePtr KernelAdjust::CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, | |||
| @@ -50,7 +50,7 @@ bool KernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) { | |||
| struct timeval start_time, end_time; | |||
| (void)gettimeofday(&start_time, nullptr); | |||
| #endif | |||
| bool is_task_sink = context_ptr->enable_task_sink(); | |||
| bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK); | |||
| if (is_task_sink) { | |||
| ret = RunTask(graph); | |||
| } else { | |||
| @@ -502,7 +502,7 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode | |||
| MS_LOG(INFO) << "communication op addr exist"; | |||
| continue; | |||
| } | |||
| if (context_ptr->enable_hccl()) { | |||
| if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) { | |||
| mem_size = mem_manager_->GetCommonAlignSize(mem_size); | |||
| } | |||
| total_size += mem_size; | |||
| @@ -646,7 +646,8 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const | |||
| DeviceAddressPtr address = nullptr; | |||
| address = CreateDeviceAddress(nullptr, node_size, output_format, output_type_id); | |||
| MS_EXCEPTION_IF_NULL(address); | |||
| if (ms_context->enable_pynative_infer() && !mem_manager_->MallocMemFromMemPool(address, node_size)) { | |||
| if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER) && | |||
| !mem_manager_->MallocMemFromMemPool(address, node_size)) { | |||
| MS_LOG(EXCEPTION) << "Cannot alloc address from memory pool when tensor size is: " << node_size; | |||
| } else if (mem_manager_->MallocMem(kStaticMem, node_size, address) == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << node_size; | |||
| @@ -682,7 +683,8 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) { | |||
| DeviceAddressPtr address = nullptr; | |||
| address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8); | |||
| MS_EXCEPTION_IF_NULL(address); | |||
| if (ms_context->enable_pynative_infer() && !mem_manager_->MallocMemFromMemPool(address, tensor_size)) { | |||
| if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER) && | |||
| !mem_manager_->MallocMemFromMemPool(address, tensor_size)) { | |||
| MS_LOG(EXCEPTION) << "Cannot alloc address from memory pool when tensor size is: " << tensor_size; | |||
| } else if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size; | |||
| @@ -701,7 +703,7 @@ void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(mem_manager_); | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| bool is_enable_mem_reuse = context_ptr->enable_mem_reuse(); | |||
| bool is_enable_mem_reuse = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_REUSE); | |||
| auto mem_type = kDynamicMem; | |||
| if (is_enable_mem_reuse) { | |||
| mem_manager_->MallocReusedDynamicMem(graph); | |||
| @@ -54,7 +54,7 @@ uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, Me | |||
| uint8_t *ptr = nullptr; | |||
| if (AnfAlgo::IsCommunicationOp(node)) { | |||
| bool communication_mem = false; | |||
| if (context_ptr->enable_hccl()) { | |||
| if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) { | |||
| communication_mem = true; | |||
| } | |||
| if (type == kStaticMem) { | |||
| @@ -1070,7 +1070,7 @@ void DfGraphConvertor::ProcessSubgraph(AnfNodePtr node, const std::vector<AnfNod | |||
| convertor.inputs_ = inputs; | |||
| (void)convertor.ConvertAllNode().BuildGraph(); | |||
| std::string name = graph_node->ToString() + "_ge_graph.dot"; | |||
| if (MsContext::GetInstance()->save_graphs_flag()) { | |||
| if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { | |||
| convertor.DrawComputeGraph(name); | |||
| } | |||
| branches_map_[node.get()] = *(convertor.df_graph_); | |||
| @@ -41,13 +41,13 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) { | |||
| MS_LOG(EXCEPTION) << "nullptr"; | |||
| } | |||
| if (ms_context_ptr->is_pynative_ge_init()) { | |||
| if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT)) { | |||
| return true; | |||
| } | |||
| if (ms_context_ptr->tsd_ref()) { | |||
| if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF)) { | |||
| MS_LOG(DEBUG) << "TDT Dataset client is already opened."; | |||
| ms_context_ptr->set_tsd_ref("++"); | |||
| ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF); | |||
| return true; | |||
| } | |||
| @@ -59,7 +59,7 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) { | |||
| unsigned int device_id; | |||
| unsigned int rank_size = 1; | |||
| device_id = ms_context_ptr->device_id(); | |||
| device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID); | |||
| auto rank_size_env = common::GetEnv("RANK_SIZE"); | |||
| if (rank_size_env.empty()) { | |||
| @@ -79,7 +79,7 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) { | |||
| MS_LOG(EXCEPTION) << "Device " << device_id << " is occupied, open tsd failed, status = " << status << "."; | |||
| return false; | |||
| } | |||
| ms_context_ptr->set_tsd_ref("++"); | |||
| ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF); | |||
| #ifdef ENABLE_TDTQUE | |||
| int32_t initStatus = tdt::TdtHostInit(device_id); | |||
| if (initStatus != TDT_OK_CODE) { | |||
| @@ -88,7 +88,8 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) { | |||
| } | |||
| ms_context_ptr->tdt_print_ = std::thread(TensorPrint()); | |||
| #endif | |||
| MS_LOG(INFO) << "Open and init tsd successful, tsd reference = " << ms_context_ptr->tsd_ref() << "."; | |||
| MS_LOG(INFO) << "Open and init tsd successful, tsd reference = " | |||
| << ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) << "."; | |||
| return true; | |||
| } | |||
| @@ -96,12 +97,12 @@ bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) { | |||
| if (ms_context_ptr == nullptr) { | |||
| MS_LOG(EXCEPTION) << "nullptr"; | |||
| } | |||
| if (ms_context_ptr->tsd_ref() == 0) { | |||
| if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) { | |||
| return true; | |||
| } | |||
| ms_context_ptr->set_tsd_ref("--"); | |||
| if (force || ms_context_ptr->tsd_ref() == 0) { | |||
| ms_context_ptr->set_tsd_ref(" "); | |||
| ms_context_ptr->decrease_param<uint32_t>(MS_CTX_TSD_REF); | |||
| if (force || ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) { | |||
| ms_context_ptr->set_param<uint32_t>(MS_CTX_TSD_REF, 0); | |||
| #ifdef ENABLE_TDTQUE | |||
| int32_t stopStatus = tdt::TdtHostStop(KNpuLog); | |||
| if (stopStatus != TDT_OK_CODE) { | |||
| @@ -123,17 +124,17 @@ bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) { | |||
| MS_LOG(ERROR) << "tdt thread join failed: " << e.what(); | |||
| } | |||
| #endif | |||
| auto device_id = ms_context_ptr->device_id(); | |||
| auto device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID); | |||
| TDT_StatusT status = TsdClose(device_id); | |||
| if (status != TDT_OK) { | |||
| MS_LOG(EXCEPTION) << "Close tsd failed, status = " << status << "."; | |||
| return false; | |||
| } | |||
| ms_context_ptr->set_pynative_ge_init(false); | |||
| ms_context_ptr->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false); | |||
| MS_LOG(INFO) << "Destroy and close tsd successful, status = " << status << "."; | |||
| } else { | |||
| MS_LOG(DEBUG) << "TDT Dataset client is used, no need to close, tsd reference = " << ms_context_ptr->tsd_ref() | |||
| << "."; | |||
| MS_LOG(DEBUG) << "TDT Dataset client is used, no need to close, tsd reference = " | |||
| << ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) << "."; | |||
| } | |||
| return true; | |||
| @@ -159,14 +160,14 @@ void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std | |||
| } | |||
| #ifdef ENABLE_GE | |||
| (*ge_options)["device_id"] = "0"; | |||
| (*ge_options)["ge.exec.enableDump"] = std::to_string(ms_context_ptr->enable_dump()); | |||
| (*ge_options)["ge.exec.dumpPath"] = ms_context_ptr->save_dump_path(); | |||
| (*ge_options)["ge.exec.enableDump"] = std::to_string(ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_DUMP)); | |||
| (*ge_options)["ge.exec.dumpPath"] = ms_context_ptr->get_param<std::string>(MS_CTX_SAVE_DUMP_PATH); | |||
| (*ge_options)["ge.exec.dumpMode"] = "output"; | |||
| MS_LOG(INFO) << "The enable dump state is " << std::to_string(ms_context_ptr->enable_dump()) | |||
| << " and save dump path is " << ms_context_ptr->save_dump_path() << "."; | |||
| (*ge_options)["ge.exec.profilingMode"] = std::to_string(ms_context_ptr->enable_profiling()); | |||
| if (ms_context_ptr->enable_profiling()) { | |||
| (*ge_options)["ge.exec.profilingOptions"] = ms_context_ptr->profiling_options(); | |||
| MS_LOG(INFO) << "The enable dump state is " << std::to_string(ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_DUMP)) | |||
| << " and save dump path is " << ms_context_ptr->get_param<std::string>(MS_CTX_SAVE_DUMP_PATH) << "."; | |||
| (*ge_options)["ge.exec.profilingMode"] = std::to_string(ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_PROFILING)); | |||
| if (ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_PROFILING)) { | |||
| (*ge_options)["ge.exec.profilingOptions"] = ms_context_ptr->get_param<std::string>(MS_CTX_PROFILING_OPTIONS); | |||
| } | |||
| (*ge_options)["rank_table_file"] = ""; | |||
| @@ -178,12 +179,12 @@ void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std | |||
| } | |||
| (*ge_options)["graphType"] = "1"; | |||
| if (ms_context_ptr->graph_memory_max_size() != "0") { | |||
| (*ge_options)["ge.graphMemoryMaxSize"] = ms_context_ptr->graph_memory_max_size(); | |||
| if (ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE) != "0") { | |||
| (*ge_options)["ge.graphMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE); | |||
| } | |||
| if (ms_context_ptr->variable_memory_max_size() != "0") { | |||
| (*ge_options)["ge.variableMemoryMaxSize"] = ms_context_ptr->variable_memory_max_size(); | |||
| if (ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE) != "0") { | |||
| (*ge_options)["ge.variableMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE); | |||
| } | |||
| #if ENABLE_TRAIN == 1 | |||
| @@ -224,7 +225,7 @@ void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std | |||
| } | |||
| // Enable auto mixed precision according to the context options | |||
| if (ms_context_ptr->auto_mixed_precision_flag()) { | |||
| if (ms_context_ptr->get_param<bool>(MS_CTX_AUTO_MIXED_PRECISION_FLAG)) { | |||
| (*ge_options)["ge.exec.precision_mode"] = "allow_mix_precision"; | |||
| } else { | |||
| (*ge_options)["ge.exec.precision_mode"] = "allow_fp32_to_fp16"; | |||
| @@ -240,7 +241,7 @@ void SetHcclOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<s | |||
| } | |||
| auto env_table_file = common::GetEnv("RANK_TABLE_FILE"); | |||
| auto env_rank_id = common::GetEnv("RANK_ID"); | |||
| auto env_device_id = std::to_string(ms_context_ptr->device_id()); | |||
| auto env_device_id = std::to_string(ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID)); | |||
| if (!(env_table_file.empty() || env_rank_id.empty())) { | |||
| MS_LOG(INFO) << "Initialize Ge for distribute parameter"; | |||
| MS_LOG(INFO) << "Use hccl, make sure hccl lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH."; | |||
| @@ -275,12 +276,12 @@ bool InitGe(const std::shared_ptr<MsContext> &ms_context_ptr) { | |||
| MS_LOG(EXCEPTION) << "nullptr"; | |||
| } | |||
| #ifdef ENABLE_GE | |||
| if (ms_context_ptr->is_pynative_ge_init()) { | |||
| if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT)) { | |||
| return true; | |||
| } | |||
| if (ms_context_ptr->ge_ref()) { | |||
| ms_context_ptr->set_ge_ref("++"); | |||
| if (ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF)) { | |||
| ms_context_ptr->increase_param<uint32_t>(MS_CTX_GE_REF); | |||
| return true; | |||
| } | |||
| @@ -293,8 +294,8 @@ bool InitGe(const std::shared_ptr<MsContext> &ms_context_ptr) { | |||
| MS_LOG(EXCEPTION) << "Initialize GE failed!"; | |||
| } | |||
| } | |||
| ms_context_ptr->set_ge_ref("++"); | |||
| MS_LOG(INFO) << "Init ge successful, ge reference = " << ms_context_ptr->ge_ref() << "."; | |||
| ms_context_ptr->increase_param<uint32_t>(MS_CTX_GE_REF); | |||
| MS_LOG(INFO) << "Init ge successful, ge reference = " << ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) << "."; | |||
| #endif | |||
| return true; | |||
| } | |||
| @@ -303,12 +304,13 @@ bool PynativeInitGe(const std::shared_ptr<MsContext> &ms_context_ptr) { | |||
| if (ms_context_ptr == nullptr) { | |||
| MS_LOG(EXCEPTION) << "nullptr"; | |||
| } | |||
| if (ms_context_ptr->is_pynative_ge_init() || ms_context_ptr->ge_ref() || ms_context_ptr->tsd_ref()) { | |||
| if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT) || | |||
| ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) || ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF)) { | |||
| return true; | |||
| } | |||
| (void)OpenTsd(ms_context_ptr); | |||
| (void)InitGe(ms_context_ptr); | |||
| ms_context_ptr->set_pynative_ge_init(true); | |||
| ms_context_ptr->set_param(MS_CTX_IS_PYNATIVE_GE_INIT, true); | |||
| return true; | |||
| } | |||
| @@ -317,12 +319,12 @@ bool FinalizeGe(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) { | |||
| MS_LOG(EXCEPTION) << "nullptr"; | |||
| } | |||
| #ifdef ENABLE_GE | |||
| if (ms_context_ptr->ge_ref() == 0) { | |||
| if (ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) == 0) { | |||
| return true; | |||
| } | |||
| ms_context_ptr->set_ge_ref("--"); | |||
| if (force || ms_context_ptr->ge_ref() == 0) { | |||
| ms_context_ptr->set_ge_ref(" "); | |||
| ms_context_ptr->decrease_param<uint32_t>(MS_CTX_GE_REF); | |||
| if (force || ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) == 0) { | |||
| ms_context_ptr->set_param<uint32_t>(MS_CTX_GE_REF, 0); | |||
| try { | |||
| DfGraphManager::GetInstance().DeleteGraphRunner(); | |||
| DfGraphManager::GetInstance().DeleteGeSession(); | |||
| @@ -337,7 +339,8 @@ bool FinalizeGe(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) { | |||
| } | |||
| ms_context_ptr->set_pynative_ge_init(false); | |||
| } else { | |||
| MS_LOG(INFO) << "Ge is used, no need to finalize, tsd reference = " << ms_context_ptr->ge_ref() << "."; | |||
| MS_LOG(INFO) << "Ge is used, no need to finalize, tsd reference = " | |||
| << ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) << "."; | |||
| } | |||
| #endif | |||
| return true; | |||
| @@ -347,14 +350,14 @@ bool IsTsdOpened(const std::shared_ptr<MsContext> &ms_context_ptr) { | |||
| if (ms_context_ptr == nullptr) { | |||
| MS_LOG(EXCEPTION) << "nullptr"; | |||
| } | |||
| return ms_context_ptr->IsTsdOpened(); | |||
| return ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) > 0; | |||
| } | |||
| bool IsGeInited(const std::shared_ptr<MsContext> &ms_context_ptr) { | |||
| if (ms_context_ptr == nullptr) { | |||
| MS_LOG(EXCEPTION) << "nullptr"; | |||
| } | |||
| return ms_context_ptr->IsGeInited(); | |||
| return ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) > 0; | |||
| } | |||
| // Register for device type. | |||
| @@ -353,7 +353,7 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py | |||
| // When sparse enabled, the undetermined might be raised and eliminated in opt passes | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| bool enable_sparse = context->enable_sparse(); | |||
| bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE); | |||
| if (enable_sparse) { | |||
| return std::make_shared<abstract::AbstractUndetermined>(); | |||
| } | |||
| @@ -273,7 +273,7 @@ void TensorPrint::operator()() { | |||
| prntpb::Print print; | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| std::string print_file_path = ms_context->print_file_path(); | |||
| std::string print_file_path = ms_context->get_param<std::string>(MS_CTX_PRINT_FILE_PATH); | |||
| if (print_file_path == "") { | |||
| while (true) { | |||
| std::vector<tdt::DataItem> bundle; | |||
| @@ -59,7 +59,7 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri | |||
| graph_id = target_sess_->CompileGraphAsync(lst, outputs); | |||
| } | |||
| if (MsContext::GetInstance()->precompile_only()) { | |||
| if (MsContext::GetInstance()->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) { | |||
| MS_LOG(INFO) << "PrecompileOnly, stop run graph"; | |||
| return result; | |||
| } | |||
| @@ -180,7 +180,7 @@ void MsBackend::CreateOtherSession(const std::string &target) { | |||
| } | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| uint32_t device_id = context_ptr->device_id(); | |||
| uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID); | |||
| other_sess_->Init(device_id); | |||
| other_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback); | |||
| other_device_ = target; | |||
| @@ -56,7 +56,7 @@ namespace { | |||
| bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| std::string last_target = context_ptr->device_target(); | |||
| std::string last_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||
| for (auto &node : nodes) { | |||
| if (node->isa<CNode>()) { | |||
| std::string cur_target = GetCNodeTarget(node); | |||
| @@ -348,7 +348,7 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) { | |||
| if (prim->name() == prim::kPrimBpropCut->name()) { | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| ms_context->set_enable_pynative_hook(true); | |||
| ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK, true); | |||
| } | |||
| if (backend_->name() == kMsConvert && prim->name() == prim::kPrimMakeTuple->name()) { | |||
| @@ -412,7 +412,7 @@ VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) { | |||
| if (ContainMultiTarget(nodes)) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| std::string default_target = context_ptr->device_target(); | |||
| std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||
| nodes = SplitSort(graph, default_target); | |||
| return SplitNodesWithTarget(nodes, graph); | |||
| } | |||
| @@ -920,17 +920,17 @@ BackendPtr CreateBackend() { | |||
| } | |||
| if (name == kMsConvert) { | |||
| std::string target = context_ptr->device_target(); | |||
| uint32_t device_id = context_ptr->device_id(); | |||
| std::string target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||
| uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID); | |||
| auto backend = std::make_shared<MsBackend>(name, target, device_id); | |||
| std::string device_target = MsContext::GetInstance()->device_target(); | |||
| std::string device_target = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||
| if (device_target == kAscendDevice) { | |||
| if (MsContext::GetInstance()->execution_mode() == kPynativeMode) { | |||
| if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) { | |||
| backend->set_is_multi_graph_sink(false); | |||
| context_ptr->set_is_multi_graph_sink(false); | |||
| context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false); | |||
| } else { | |||
| backend->set_is_multi_graph_sink(true); | |||
| context_ptr->set_is_multi_graph_sink(true); | |||
| context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, true); | |||
| } | |||
| } | |||
| return backend; | |||
| @@ -22,7 +22,7 @@ import threading | |||
| from collections import namedtuple | |||
| from types import FunctionType | |||
| from mindspore import log as logger | |||
| from mindspore._c_expression import MSContext | |||
| from mindspore._c_expression import MSContext, ms_ctx_param, ms_ctx_get_param, ms_ctx_set_param | |||
| from mindspore._checkparam import args_type_check | |||
| from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \ | |||
| _reset_auto_parallel_context | |||
| @@ -157,9 +157,15 @@ class _Context: | |||
| raise ValueError("Context handle is none in context!!!") | |||
| return value | |||
| def get_param(self, param): | |||
| return ms_ctx_get_param(self._context_handle, param) | |||
| def set_param(self, param, value): | |||
| ms_ctx_set_param(self._context_handle, param, value) | |||
| @property | |||
| def mode(self): | |||
| return self._context_handle.get_execution_mode() | |||
| return self.get_param(ms_ctx_param.execution_mode) | |||
| @mode.setter | |||
| def mode(self, mode): | |||
| @@ -169,15 +175,17 @@ class _Context: | |||
| Args: | |||
| mode (int): GRAPH_MODE or PYNATIVE_MODE. | |||
| """ | |||
| self._context_handle.set_execution_mode(mode) | |||
| if mode == PYNATIVE_MODE: | |||
| if self.enable_debug_runtime: | |||
| self.set_backend_policy("vm") | |||
| self._context_switches.push(True, None) | |||
| else: | |||
| elif mode == GRAPH_MODE: | |||
| if self.enable_debug_runtime: | |||
| self.set_backend_policy("ge") | |||
| self._context_switches.push(False, None) | |||
| else: | |||
| raise ValueError(f'The execution mode {mode} is invalid!') | |||
| self.set_param(ms_ctx_param.execution_mode, mode) | |||
| def set_backend_policy(self, policy): | |||
| success = self._context_handle.set_backend_policy(policy) | |||
| @@ -186,110 +194,106 @@ class _Context: | |||
| @property | |||
| def precompile_only(self): | |||
| return self._context_handle.get_precompile_only() | |||
| return self.get_param(ms_ctx_param.precompile_only) | |||
| @precompile_only.setter | |||
| def precompile_only(self, precompile_only): | |||
| self._context_handle.set_precompile_only(precompile_only) | |||
| self.set_param(ms_ctx_param.precompile_only, precompile_only) | |||
| @property | |||
| def save_graphs(self): | |||
| return self._context_handle.get_save_graphs_flag() | |||
| return self.get_param(ms_ctx_param.save_graphs_flag) | |||
| @save_graphs.setter | |||
| def save_graphs(self, save_graphs_flag): | |||
| self._context_handle.set_save_graphs_flag(save_graphs_flag) | |||
| self.set_param(ms_ctx_param.save_graphs_flag, save_graphs_flag) | |||
| @property | |||
| def save_graphs_path(self): | |||
| return self._context_handle.get_save_graphs_path() | |||
| return self.get_param(ms_ctx_param.save_graphs_path) | |||
| @save_graphs_path.setter | |||
| def save_graphs_path(self, save_graphs_path): | |||
| self._context_handle.set_save_graphs_path( | |||
| _make_directory(save_graphs_path)) | |||
| self.set_param(ms_ctx_param.save_graphs_path, _make_directory(save_graphs_path)) | |||
| @property | |||
| def device_target(self): | |||
| return self._context_handle.get_device_target() | |||
| return self.get_param(ms_ctx_param.device_target) | |||
| @device_target.setter | |||
| def device_target(self, target): | |||
| success = self._context_handle.set_device_target(target) | |||
| if not success: | |||
| raise ValueError("Target device name is invalid!!!") | |||
| if self.enable_debug_runtime and self.device_target == "CPU": | |||
| valid_targets = ["CPU", "GPU", "Ascend", "Davinci"] | |||
| if not target in valid_targets: | |||
| raise ValueError(f"Target device name {target} is invalid! It must be one of {valid_targets}") | |||
| if target == "Davinci": | |||
| target = "Ascend" | |||
| self.set_param(ms_ctx_param.device_target, target) | |||
| if self.enable_debug_runtime and target == "CPU": | |||
| self.set_backend_policy("vm") | |||
| @property | |||
| def device_id(self): | |||
| return self._context_handle.get_device_id() | |||
| return self.get_param(ms_ctx_param.device_id) | |||
| @device_id.setter | |||
| def device_id(self, device_id): | |||
| if device_id < 0 or device_id > 4095: | |||
| raise ValueError( | |||
| "Device id must be in [0, 4095], but got {}".format(device_id)) | |||
| success = self._context_handle.set_device_id(device_id) | |||
| if not success: | |||
| raise RuntimeError("Device id set failed!!!") | |||
| raise ValueError(f"Device id must be in [0, 4095], but got {device_id}") | |||
| self.set_param(ms_ctx_param.device_id, device_id) | |||
| @property | |||
| def max_call_depth(self): | |||
| return self._context_handle.get_max_call_depth() | |||
| return self.get_param(ms_ctx_param.max_call_depth) | |||
| @max_call_depth.setter | |||
| def max_call_depth(self, max_call_depth): | |||
| if max_call_depth <= 0: | |||
| raise ValueError( | |||
| "Max call depth must be greater than 0, but got {}".format(max_call_depth)) | |||
| self._context_handle.set_max_call_depth(max_call_depth) | |||
| raise ValueError(f"Max call depth must be greater than 0, but got {max_call_depth}") | |||
| self.set_param(ms_ctx_param.max_call_depth, max_call_depth) | |||
| @property | |||
| def enable_auto_mixed_precision(self): | |||
| return self._context_handle.get_auto_mixed_precision_flag() | |||
| return self.get_param(ms_ctx_param.auto_mixed_precision_flag) | |||
| @enable_auto_mixed_precision.setter | |||
| def enable_auto_mixed_precision(self, enable_auto_mixed_precision): | |||
| self._context_handle.set_auto_mixed_precision_flag( | |||
| enable_auto_mixed_precision) | |||
| self.set_param(ms_ctx_param.auto_mixed_precision_flag, enable_auto_mixed_precision) | |||
| @property | |||
| def enable_reduce_precision(self): | |||
| return self._context_handle.get_enable_reduce_precision_flag() | |||
| return self.get_param(ms_ctx_param.enable_reduce_precision_flag) | |||
| @enable_reduce_precision.setter | |||
| def enable_reduce_precision(self, enable_reduce_precision): | |||
| self._context_handle.set_enable_reduce_precision_flag( | |||
| enable_reduce_precision) | |||
| self.set_param(ms_ctx_param.enable_reduce_precision_flag, enable_reduce_precision) | |||
| @property | |||
| def enable_dump(self): | |||
| return self._context_handle.get_enable_dump() | |||
| return self.get_param(ms_ctx_param.enable_dump) | |||
| @enable_dump.setter | |||
| def enable_dump(self, enable_dump): | |||
| self._context_handle.set_enable_dump(enable_dump) | |||
| self.set_param(ms_ctx_param.enable_dump, enable_dump) | |||
| @property | |||
| def save_dump_path(self): | |||
| return self._context_handle.get_save_dump_path() | |||
| return self.get_param(ms_ctx_param.save_dump_path) | |||
| @save_dump_path.setter | |||
| def save_dump_path(self, save_dump_path): | |||
| self._context_handle.set_save_dump_path(save_dump_path) | |||
| self.set_param(ms_ctx_param.save_dump_path, save_dump_path) | |||
| @property | |||
| def enable_profiling(self): | |||
| return self._context_handle.get_enable_profiling() | |||
| return self.get_param(ms_ctx_param.enable_profiling) | |||
| @enable_profiling.setter | |||
| def enable_profiling(self, flag): | |||
| self._context_handle.set_enable_profiling(flag) | |||
| self.set_param(ms_ctx_param.enable_profiling, flag) | |||
| @property | |||
| def profiling_options(self): | |||
| return self._context_handle.get_profiling_options() | |||
| return self.get_param(ms_ctx_param.profiling_options) | |||
| @profiling_options.setter | |||
| def profiling_options(self, option): | |||
| @@ -298,15 +302,15 @@ class _Context: | |||
| if option not in options: | |||
| raise ValueError("Profiling options must be in 'training_trace' 'task_trace' " | |||
| "'task_trace:training_trace' 'training_trace:task_trace' or 'op_trace'.") | |||
| self._context_handle.set_profiling_options(option) | |||
| self.set_param(ms_ctx_param.profiling_options, option) | |||
| @property | |||
| def enable_graph_kernel(self): | |||
| return self._context_handle.get_enable_graph_kernel() | |||
| return self.get_param(ms_ctx_param.enable_graph_kernel) | |||
| @enable_graph_kernel.setter | |||
| def enable_graph_kernel(self, graph_kernel_switch_): | |||
| self._context_handle.set_enable_graph_kernel(graph_kernel_switch_) | |||
| self.set_param(ms_ctx_param.enable_graph_kernel, graph_kernel_switch_) | |||
| @property | |||
| def reserve_class_name_in_scope(self): | |||
| @@ -325,20 +329,14 @@ class _Context: | |||
| @variable_memory_max_size.setter | |||
| def variable_memory_max_size(self, variable_memory_max_size): | |||
| if not check_input_format(variable_memory_max_size): | |||
| raise ValueError( | |||
| "Context param variable_memory_max_size should be in correct format! Such as \"5GB\"") | |||
| raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"") | |||
| if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE: | |||
| raise ValueError( | |||
| "Context param variable_memory_max_size should be less than 31GB.") | |||
| variable_memory_max_size_ = variable_memory_max_size[:- | |||
| 2] + " * 1024 * 1024 * 1024" | |||
| graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - \ | |||
| int(variable_memory_max_size[:-2]) | |||
| graph_memory_max_size_ = str( | |||
| graph_memory_max_size) + " * 1024 * 1024 * 1024" | |||
| self._context_handle.set_variable_memory_max_size( | |||
| variable_memory_max_size_) | |||
| self._context_handle.set_graph_memory_max_size(graph_memory_max_size_) | |||
| raise ValueError("Context param variable_memory_max_size should be less than 31GB.") | |||
| variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024" | |||
| graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - int(variable_memory_max_size[:-2]) | |||
| graph_memory_max_size_ = str(graph_memory_max_size) + " * 1024 * 1024 * 1024" | |||
| self.set_param(ms_ctx_param.variable_memory_max_size, variable_memory_max_size_) | |||
| self.set_param(ms_ctx_param.graph_memory_max_size, graph_memory_max_size_) | |||
| @property | |||
| def enable_ge(self): | |||
| @@ -355,15 +353,15 @@ class _Context: | |||
| @property | |||
| def check_bprop(self): | |||
| return self._context_handle.get_check_bprop_flag() | |||
| return self.get_param(ms_ctx_param.check_bprop_flag) | |||
| @check_bprop.setter | |||
| def check_bprop(self, check_bprop_flag): | |||
| self._context_handle.set_check_bprop_flag(check_bprop_flag) | |||
| self.set_param(ms_ctx_param.check_bprop_flag, check_bprop_flag) | |||
| @property | |||
| def max_device_memory(self): | |||
| return self._context_handle.get_max_device_memory() | |||
| return self.get_param(ms_ctx_param.max_device_memory) | |||
| @max_device_memory.setter | |||
| def max_device_memory(self, max_device_memory): | |||
| @@ -372,7 +370,7 @@ class _Context: | |||
| max_device_memory_value = float(max_device_memory[:-2]) | |||
| if max_device_memory_value == 0: | |||
| raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"") | |||
| self._context_handle.set_max_device_memory(max_device_memory_value) | |||
| self.set_param(ms_ctx_param.max_device_memory, max_device_memory_value) | |||
| @property | |||
| def print_file_path(self): | |||
| @@ -392,15 +390,15 @@ class _Context: | |||
| full_file_name = os.path.join(path, file_name) | |||
| else: | |||
| full_file_name = print_file_path | |||
| self._context_handle.set_print_file_path(full_file_name) | |||
| self.set_param(ms_ctx_param.print_file_path, full_file_name) | |||
| @property | |||
| def enable_sparse(self): | |||
| return self._context_handle.get_enable_sparse() | |||
| return self.get_param(ms_ctx_param.enable_sparse) | |||
| @enable_sparse.setter | |||
| def enable_sparse(self, enable_sparse): | |||
| self._context_handle.set_enable_sparse(enable_sparse) | |||
| self.set_param(ms_ctx_param.enable_sparse, enable_sparse) | |||
| def check_input_format(x): | |||
| import re | |||
| @@ -486,8 +484,6 @@ def set_auto_parallel_context(**kwargs): | |||
| full_batch (bool): Whether to load the whole batch on each device. Default: False. | |||
| enable_parallel_optimizer(bool): This is a developing feature, which shards the weight update computation in | |||
| data parallel training in the benefit of time and memory saving. | |||
| max_call_depth(int): Specify the function call depth limit. Default: 1000. | |||
| Raises: | |||
| ValueError: If input key is not attribute in auto parallel context. | |||
| @@ -501,7 +497,6 @@ def set_auto_parallel_context(**kwargs): | |||
| >>> context.set_auto_parallel_context(parameter_broadcast=False) | |||
| >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt") | |||
| >>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt") | |||
| >>> context.set_auto_parallel_context(max_call_depth=80) | |||
| """ | |||
| _set_auto_parallel_context(**kwargs) | |||
| @@ -603,6 +598,7 @@ def set_context(**kwargs): | |||
| a file by default, and turn off printing to the screen. If the file already exists, add a timestamp | |||
| suffix to the file. | |||
| enable_sparse (bool): Whether to enable sparsity feature. Default: False. | |||
| max_call_depth(int): Specify the function call depth limit. Default: 1000. | |||
| Raises: | |||
| ValueError: If input key is not an attribute in context. | |||
| @@ -623,6 +619,7 @@ def set_context(**kwargs): | |||
| >>> context.set_context(enable_profiling=True, profiling_options="training_trace") | |||
| >>> context.set_context(max_device_memory="3.5GB") | |||
| >>> context.set_context(print_file_path="print.pb") | |||
| >>> context.set_context(max_call_depth=80) | |||
| """ | |||
| for key, value in kwargs.items(): | |||
| if not hasattr(_context(), key): | |||
| @@ -51,7 +51,7 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| bool enable_sparse = context->enable_sparse(); | |||
| bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE); | |||
| if (enable_sparse && dflt->isa<AbstractTensor>()) { | |||
| auto dflt_tensor = dflt->cast<AbstractTensorPtr>(); | |||
| return std::make_shared<AbstractUndetermined>(dflt_tensor->element()->Clone(), dflt_tensor->shape()->Clone()); | |||
| @@ -232,7 +232,7 @@ std::string GetMaketupleNodeTarget(const CNodePtr &cnode) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| std::string default_target = context_ptr->device_target(); | |||
| std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||
| return default_target; | |||
| } | |||
| @@ -248,7 +248,7 @@ std::string GetTupleGetItemTarget(const CNodePtr &cnode, const PrimitivePtr &pri | |||
| std::string GetCNodeTarget(const AnfNodePtr &node) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| std::string default_target = context_ptr->device_target(); | |||
| std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||
| if (!node->isa<CNode>()) { | |||
| return default_target; | |||
| } | |||
| @@ -652,7 +652,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP | |||
| new_func_graph->set_param_default_value(item.first, cloner[item.second]); | |||
| } | |||
| if (MsContext::GetInstance()->is_multi_graph_sink()) { | |||
| if (MsContext::GetInstance()->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) { | |||
| if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { | |||
| new_func_graph->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); | |||
| } | |||
| @@ -30,7 +30,7 @@ abstract::AbstractBasePtr MetaFuncGraph::ToAbstract() { | |||
| FuncGraphPtr MetaFuncGraph::GenerateStubFunc(const TypePtrList &types) { | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| bool enable_sparse = context->enable_sparse(); | |||
| bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE); | |||
| if (!enable_sparse) { | |||
| return nullptr; | |||
| } | |||
| @@ -32,49 +32,50 @@ std::map<std::string, MsBackendPolicy> MsContext::policy_map_ = {{"ge", kMsBacke | |||
| {"vm_prior", kMsBackendVmPrior}}; | |||
| MsContext::MsContext(const std::string &policy, const std::string &target) { | |||
| save_graphs_flag_ = false; | |||
| save_graphs_path_ = "."; | |||
| enable_dump_ = false; | |||
| save_dump_path_ = "."; | |||
| tsd_ref_ = 0; | |||
| ge_ref_ = 0; | |||
| is_multi_graph_sink_ = false; | |||
| is_pynative_ge_init_ = false; | |||
| enable_reduce_precision_ = true; | |||
| set_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG, false); | |||
| set_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH, "."); | |||
| set_param<std::string>(MS_CTX_SAVE_DUMP_PATH, "."); | |||
| set_param<uint32_t>(MS_CTX_TSD_REF, 0); | |||
| set_param<uint32_t>(MS_CTX_GE_REF, 0); | |||
| set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false); | |||
| set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false); | |||
| set_param<bool>(MS_CTX_ENABLE_REDUCE_PRECISION, true); | |||
| auto env_device = common::GetEnv("DEVICE_ID"); | |||
| if (!env_device.empty()) { | |||
| device_id_ = UlongToUint(std::stoul(env_device.c_str())); | |||
| uint32_t device_id = UlongToUint(std::stoul(env_device.c_str())); | |||
| set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id); | |||
| } else { | |||
| device_id_ = 0; | |||
| set_param<uint32_t>(MS_CTX_DEVICE_ID, 0); | |||
| } | |||
| max_call_depth_ = MAX_CALL_DEPTH_DEFAULT; | |||
| backend_policy_ = policy_map_[policy]; | |||
| device_target_ = target; | |||
| execution_mode_ = kPynativeMode; | |||
| enable_task_sink_ = true; | |||
| ir_fusion_flag_ = true; | |||
| enable_hccl_ = false; | |||
| set_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH, MAX_CALL_DEPTH_DEFAULT); | |||
| set_param<std::string>(MS_CTX_DEVICE_TARGET, target); | |||
| set_param<int>(MS_CTX_EXECUTION_MODE, kPynativeMode); | |||
| set_param<bool>(MS_CTX_ENABLE_TASK_SINK, true); | |||
| set_param<bool>(MS_CTX_IR_FUSION_FLAG, true); | |||
| set_param<bool>(MS_CTX_ENABLE_HCCL, false); | |||
| #ifdef ENABLE_DEBUGGER | |||
| enable_mem_reuse_ = false; | |||
| set_param<bool>(MS_CTX_ENABLE_MEM_REUSE, false); | |||
| #else | |||
| enable_mem_reuse_ = true; | |||
| set_param<bool>(MS_CTX_ENABLE_MEM_REUSE, true); | |||
| #endif | |||
| enable_gpu_summary_ = true; | |||
| precompile_only_ = false; | |||
| auto_mixed_precision_flag_ = false; | |||
| enable_pynative_infer_ = false; | |||
| enable_pynative_hook_ = false; | |||
| enable_dynamic_mem_pool_ = true; | |||
| graph_memory_max_size_ = "0"; | |||
| variable_memory_max_size_ = "0"; | |||
| enable_loop_sink_ = target == kAscendDevice || target == kDavinciDevice; | |||
| profiling_mode_ = false; | |||
| profiling_options_ = "training_trace"; | |||
| check_bprop_flag_ = false; | |||
| max_device_memory_ = kDefaultMaxDeviceMemory; | |||
| print_file_path_ = ""; | |||
| enable_graph_kernel_ = false; | |||
| enable_sparse_ = false; | |||
| set_param<bool>(MS_CTX_ENABLE_GPU_SUMMARY, true); | |||
| set_param<bool>(MS_CTX_PRECOMPILE_ONLY, false); | |||
| set_param<bool>(MS_CTX_AUTO_MIXED_PRECISION_FLAG, false); | |||
| set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false); | |||
| set_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK, false); | |||
| set_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL, true); | |||
| set_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE, "0"); | |||
| set_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE, "0"); | |||
| set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, target == kAscendDevice || target == kDavinciDevice); | |||
| set_param<bool>(MS_CTX_ENABLE_PROFILING, false); | |||
| set_param<std::string>(MS_CTX_PROFILING_OPTIONS, "training_trace"); | |||
| set_param<bool>(MS_CTX_CHECK_BPROP_FLAG, false); | |||
| set_param<float>(MS_CTX_MAX_DEVICE_MEMORY, kDefaultMaxDeviceMemory); | |||
| set_param<std::string>(MS_CTX_PRINT_FILE_PATH, ""); | |||
| set_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL, false); | |||
| set_param<bool>(MS_CTX_ENABLE_SPARSE, false); | |||
| backend_policy_ = policy_map_[policy]; | |||
| } | |||
| std::shared_ptr<MsContext> MsContext::GetInstance() { | |||
| @@ -106,54 +107,4 @@ std::string MsContext::backend_policy() const { | |||
| } | |||
| return "unknown"; | |||
| } | |||
| void MsContext::set_execution_mode(int execution_mode) { | |||
| if (execution_mode != kGraphMode && execution_mode != kPynativeMode) { | |||
| MS_LOG(EXCEPTION) << "The execution mode is invalid!"; | |||
| } | |||
| execution_mode_ = execution_mode; | |||
| } | |||
| bool MsContext::set_device_target(const std::string &target) { | |||
| if (kTargetSet.find(target) == kTargetSet.end()) { | |||
| MS_LOG(ERROR) << "invalid device target name: " << target; | |||
| return false; | |||
| } | |||
| if (target == kDavinciDevice) { | |||
| device_target_ = kAscendDevice; | |||
| } else { | |||
| device_target_ = target; | |||
| } | |||
| if (seter_) { | |||
| seter_(device_target_); | |||
| } | |||
| MS_LOG(INFO) << "ms set context device target:" << target; | |||
| return true; | |||
| } | |||
| bool MsContext::set_device_id(uint32_t device_id) { | |||
| device_id_ = device_id; | |||
| MS_LOG(INFO) << "ms set context device id:" << device_id; | |||
| return true; | |||
| } | |||
| void MsContext::set_tsd_ref(const std::string &op) { | |||
| if (op == "--") { | |||
| tsd_ref_--; | |||
| } else if (op == "++") { | |||
| tsd_ref_++; | |||
| } else { | |||
| tsd_ref_ = 0; | |||
| } | |||
| } | |||
| void MsContext::set_ge_ref(const std::string &op) { | |||
| if (op == "--") { | |||
| ge_ref_--; | |||
| } else if (op == "++") { | |||
| ge_ref_++; | |||
| } else { | |||
| ge_ref_ = 0; | |||
| } | |||
| } | |||
| } // namespace mindspore | |||
| @@ -49,6 +49,69 @@ const std::set<std::string> kTargetSet = {kCPUDevice, kGPUDevice, kAscendDevice, | |||
| // The default max available device memory is 1024GB. | |||
| const float kDefaultMaxDeviceMemory = 1024; | |||
| // enum definition for MindSpore Context Parameter | |||
| enum MsCtxParam : unsigned { | |||
| // paramater of type bool | |||
| MS_CTX_TYPE_BOOL_BEGIN, | |||
| MS_CTX_AUTO_MIXED_PRECISION_FLAG = MS_CTX_TYPE_BOOL_BEGIN, | |||
| MS_CTX_CHECK_BPROP_FLAG, | |||
| MS_CTX_ENABLE_DUMP, | |||
| MS_CTX_ENABLE_DYNAMIC_MEM_POOL, | |||
| MS_CTX_ENABLE_GPU_SUMMARY, | |||
| MS_CTX_ENABLE_GRAPH_KERNEL, | |||
| MS_CTX_ENABLE_HCCL, | |||
| MS_CTX_ENABLE_LOOP_SINK, | |||
| MS_CTX_ENABLE_MEM_REUSE, | |||
| MS_CTX_ENABLE_PYNATIVE_HOOK, | |||
| MS_CTX_ENABLE_PYNATIVE_INFER, | |||
| MS_CTX_ENABLE_REDUCE_PRECISION, | |||
| MS_CTX_ENABLE_SPARSE, | |||
| MS_CTX_ENABLE_TASK_SINK, | |||
| MS_CTX_IR_FUSION_FLAG, | |||
| MS_CTX_IS_MULTI_GRAPH_SINK, | |||
| MS_CTX_IS_PYNATIVE_GE_INIT, | |||
| MS_CTX_PRECOMPILE_ONLY, | |||
| MS_CTX_ENABLE_PROFILING, | |||
| MS_CTX_SAVE_GRAPHS_FLAG, | |||
| MS_CTX_TYPE_BOOL_END, | |||
| // paramater of type int | |||
| MS_CTX_TYPE_INT_BEGIN = MS_CTX_TYPE_BOOL_END, | |||
| MS_CTX_EXECUTION_MODE = MS_CTX_TYPE_INT_BEGIN, | |||
| MS_CTX_TYPE_INT_END, | |||
| // paramater of type uint32 | |||
| MS_CTX_TYPE_UINT32_BEGIN = MS_CTX_TYPE_INT_END, | |||
| MS_CTX_DEVICE_ID = MS_CTX_TYPE_UINT32_BEGIN, | |||
| MS_CTX_GE_REF, | |||
| MS_CTX_MAX_CALL_DEPTH, | |||
| MS_CTX_TSD_REF, | |||
| MS_CTX_TYPE_UINT32_END, | |||
| // paramater of type float | |||
| MS_CTX_TYPE_FLOAT_BEGIN = MS_CTX_TYPE_UINT32_END, | |||
| MS_CTX_MAX_DEVICE_MEMORY = MS_CTX_TYPE_FLOAT_BEGIN, | |||
| MS_CTX_TYPE_FLOAT_END, | |||
| // paramater of type string | |||
| MS_CTX_TYPE_STRING_BEGIN = MS_CTX_TYPE_FLOAT_END, | |||
| MS_CTX_DEVICE_TARGET = MS_CTX_TYPE_STRING_BEGIN, | |||
| MS_CTX_GRAPH_MEMORY_MAX_SIZE, | |||
| MS_CTX_PRINT_FILE_PATH, | |||
| MS_CTX_PROFILING_OPTIONS, | |||
| MS_CTX_SAVE_DUMP_PATH, | |||
| MS_CTX_SAVE_GRAPHS_PATH, | |||
| MS_CTX_VARIABLE_MEMORY_MAX_SIZE, | |||
| MS_CTX_TYPE_STRING_END, | |||
| // parameter numbers of each type | |||
| NUM_BOOL_PARAMS = MS_CTX_TYPE_BOOL_END - MS_CTX_TYPE_BOOL_BEGIN, | |||
| NUM_INT_PARAMS = MS_CTX_TYPE_INT_END - MS_CTX_TYPE_INT_BEGIN, | |||
| NUM_UINT32_PARAMS = MS_CTX_TYPE_UINT32_END - MS_CTX_TYPE_UINT32_BEGIN, | |||
| NUM_FLOAT_PARAMS = MS_CTX_TYPE_FLOAT_END - MS_CTX_TYPE_FLOAT_BEGIN, | |||
| NUM_STRING_PARAMS = MS_CTX_TYPE_STRING_END - MS_CTX_TYPE_STRING_BEGIN | |||
| }; | |||
| class MsContext { | |||
| public: | |||
| MsContext(const std::string &backend_policy, const std::string &target); | |||
| @@ -62,156 +125,113 @@ class MsContext { | |||
| std::string backend_policy() const; | |||
| bool set_backend_policy(const std::string &policy); | |||
| int execution_mode() const { return execution_mode_; } | |||
| void set_execution_mode(int execution_mode); | |||
| bool enable_pynative_infer() const { return enable_pynative_infer_; } | |||
| void set_enable_pynative_infer(bool enable_pynative_infer) { enable_pynative_infer_ = enable_pynative_infer; } | |||
| bool enable_pynative_hook() const { return enable_pynative_hook_; } | |||
| void set_enable_pynative_hook(bool enable_pynative_hook) { enable_pynative_hook_ = enable_pynative_hook; } | |||
| bool enable_task_sink() const { return enable_task_sink_; } | |||
| void set_precompile_only(bool precompile_only) { precompile_only_ = precompile_only; } | |||
| bool precompile_only() const { return precompile_only_; } | |||
| std::string device_target() const { return device_target_; } | |||
| bool set_device_target(const std::string &target); | |||
| static void device_seter(DeviceSeter device) { seter_ = device; } | |||
| static void device_type_seter(DeviceTypeSeter device_type) { device_type_seter_ = device_type; } | |||
| uint32_t device_id() const { return device_id_; } | |||
| bool set_device_id(uint32_t device_id); | |||
| std::thread tdt_print_; | |||
| // uint32_t max_call_depth_ | |||
| uint32_t max_call_depth() const { return max_call_depth_; } | |||
| inline bool set_max_call_depth(uint32_t max_call_depth) { | |||
| max_call_depth_ = max_call_depth; | |||
| return true; | |||
| template <typename T> | |||
| void set_param(MsCtxParam param, const T &value) { | |||
| MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << "."; | |||
| } | |||
| bool save_graphs_flag() const { return save_graphs_flag_; } | |||
| void set_save_graphs_flag(bool save_graphs_flag) { save_graphs_flag_ = save_graphs_flag; } | |||
| std::string save_graphs_path() const { return save_graphs_path_; } | |||
| void set_save_graphs_path(const std::string &save_paths) { save_graphs_path_ = save_paths; } | |||
| bool IsGeInited() { return ge_ref_ > 0; } | |||
| void set_enable_hccl(bool enable_hccl) { enable_hccl_ = enable_hccl; } | |||
| bool enable_hccl() const { return enable_hccl_; } | |||
| bool ir_fusion_flag() const { return ir_fusion_flag_; } | |||
| bool loop_sink_flag() const { return enable_loop_sink_; } | |||
| void set_loop_sink_flag(bool enable_loop_sink) { enable_loop_sink_ = enable_loop_sink; } | |||
| void set_enable_mem_reuse(bool enable_mem_reuse) { enable_mem_reuse_ = enable_mem_reuse; } | |||
| bool enable_mem_reuse() const { return enable_mem_reuse_; } | |||
| void set_enable_gpu_summary(bool enable_gpu_summary) { enable_gpu_summary_ = enable_gpu_summary; } | |||
| bool enable_gpu_summary() const { return enable_gpu_summary_; } | |||
| void set_auto_mixed_precision_flag(bool auto_mixed_precision_flag) { | |||
| auto_mixed_precision_flag_ = auto_mixed_precision_flag; | |||
| template <typename T> | |||
| const T &get_param(MsCtxParam param) const { | |||
| MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << "."; | |||
| } | |||
| bool auto_mixed_precision_flag() const { return auto_mixed_precision_flag_; } | |||
| void set_enable_reduce_precision(bool flag) { enable_reduce_precision_ = flag; } | |||
| bool enable_reduce_precision() const { return enable_reduce_precision_; } | |||
| void set_enable_dump(bool flag) { enable_dump_ = flag; } | |||
| bool enable_dump() const { return enable_dump_; } | |||
| void set_save_dump_path(const std::string &path) { save_dump_path_ = path; } | |||
| std::string save_dump_path() const { return save_dump_path_; } | |||
| bool IsTsdOpened() const { return tsd_ref_ > 0; } | |||
| void set_tsd_ref(const std::string &op); | |||
| uint32_t tsd_ref() const { return tsd_ref_; } | |||
| void set_ge_ref(const std::string &op); | |||
| uint32_t ge_ref() const { return ge_ref_; } | |||
| bool is_pynative_ge_init() { return is_pynative_ge_init_; } | |||
| void set_pynative_ge_init(bool flag) { is_pynative_ge_init_ = flag; } | |||
| bool is_multi_graph_sink() const { return is_multi_graph_sink_; } | |||
| void set_is_multi_graph_sink(bool flag) { is_multi_graph_sink_ = flag; } | |||
| void set_enable_dynamic_mem_pool(bool enable_dynamic_mem_pool) { enable_dynamic_mem_pool_ = enable_dynamic_mem_pool; } | |||
| bool enable_dynamic_mem_pool() const { return enable_dynamic_mem_pool_; } | |||
| void set_graph_memory_max_size(const std::string &graph_memory_max_size) { | |||
| graph_memory_max_size_ = graph_memory_max_size; | |||
| template <typename T> | |||
| void increase_param(MsCtxParam param) { | |||
| MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << "."; | |||
| } | |||
| void set_variable_memory_max_size(const std::string &variable_memory_max_size) { | |||
| variable_memory_max_size_ = variable_memory_max_size; | |||
| template <typename T> | |||
| void decrease_param(MsCtxParam param) { | |||
| MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << "."; | |||
| } | |||
| const std::string &variable_memory_max_size() const { return variable_memory_max_size_; } | |||
| const std::string &graph_memory_max_size() const { return graph_memory_max_size_; } | |||
| void set_enable_profiling(bool flag) { profiling_mode_ = flag; } | |||
| bool enable_profiling() const { return profiling_mode_; } | |||
| void set_profiling_options(const std::string &options) { profiling_options_ = options; } | |||
| std::string profiling_options() const { return profiling_options_; } | |||
| bool check_bprop_flag() const { return check_bprop_flag_; } | |||
| void set_check_bprop_flag(bool check_bprop_flag) { check_bprop_flag_ = check_bprop_flag; } | |||
| void set_print_file_path(const std::string &file) { print_file_path_ = file; } | |||
| const std::string &print_file_path() const { return print_file_path_; } | |||
| float max_device_memory() const { return max_device_memory_; } | |||
| void set_max_device_memory(float max_device_memory) { max_device_memory_ = max_device_memory; } | |||
| void set_enable_graph_kernel(bool enable_graph_kernel) { enable_graph_kernel_ = enable_graph_kernel; } | |||
| bool enable_graph_kernel() const { return enable_graph_kernel_; } | |||
| bool enable_sparse() const { return enable_sparse_; } | |||
| void set_enable_sparse(bool enable_sparse) { enable_sparse_ = enable_sparse; } | |||
| static void device_seter(DeviceSeter device) { seter_ = device; } | |||
| static void device_type_seter(DeviceTypeSeter device_type) { device_type_seter_ = device_type; } | |||
| std::thread tdt_print_; | |||
| private: | |||
| inline static DeviceSeter seter_ = nullptr; | |||
| inline static DeviceTypeSeter device_type_seter_ = nullptr; | |||
| static std::shared_ptr<MsContext> inst_context_; | |||
| static std::map<std::string, MsBackendPolicy> policy_map_; | |||
| bool bool_params_[MsCtxParam::NUM_BOOL_PARAMS]; | |||
| int int_params_[MsCtxParam::NUM_INT_PARAMS]; | |||
| uint32_t uint32_params_[MsCtxParam::NUM_UINT32_PARAMS]; | |||
| float float_params_[MsCtxParam::NUM_FLOAT_PARAMS]; | |||
| std::string string_params_[MsCtxParam::NUM_STRING_PARAMS]; | |||
| MsBackendPolicy backend_policy_; | |||
| std::string device_target_; | |||
| uint32_t device_id_; | |||
| uint32_t max_call_depth_; | |||
| int execution_mode_; | |||
| bool enable_pynative_infer_; | |||
| bool enable_pynative_hook_; | |||
| bool save_graphs_flag_; | |||
| std::string save_graphs_path_; | |||
| uint32_t tsd_ref_; | |||
| uint32_t ge_ref_; | |||
| bool enable_task_sink_; | |||
| bool enable_hccl_; | |||
| bool precompile_only_; | |||
| bool ir_fusion_flag_; | |||
| bool auto_mixed_precision_flag_; | |||
| bool enable_reduce_precision_; | |||
| bool enable_loop_sink_; | |||
| bool enable_mem_reuse_; | |||
| bool enable_gpu_summary_; | |||
| bool enable_dump_; | |||
| std::string save_dump_path_; | |||
| bool is_multi_graph_sink_; | |||
| bool is_pynative_ge_init_; | |||
| bool enable_dynamic_mem_pool_; | |||
| std::string graph_memory_max_size_; | |||
| std::string variable_memory_max_size_; | |||
| bool profiling_mode_; | |||
| std::string profiling_options_; | |||
| bool check_bprop_flag_; | |||
| float max_device_memory_; | |||
| std::string print_file_path_; | |||
| bool enable_graph_kernel_; | |||
| bool enable_sparse_; | |||
| }; | |||
| // set method implementation for type bool/int/uint32_t/float/std::string | |||
| template <> | |||
| inline void MsContext::set_param<bool>(MsCtxParam param, const bool &value) { | |||
| bool_params_[param - MS_CTX_TYPE_BOOL_BEGIN] = value; | |||
| } | |||
| template <> | |||
| inline void MsContext::set_param<int>(MsCtxParam param, const int &value) { | |||
| int_params_[param - MS_CTX_TYPE_INT_BEGIN] = value; | |||
| } | |||
| template <> | |||
| inline void MsContext::set_param<uint32_t>(MsCtxParam param, const uint32_t &value) { | |||
| uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN] = value; | |||
| } | |||
| template <> | |||
| inline void MsContext::set_param<float>(MsCtxParam param, const float &value) { | |||
| float_params_[param - MS_CTX_TYPE_FLOAT_BEGIN] = value; | |||
| } | |||
| template <> | |||
| inline void MsContext::set_param<std::string>(MsCtxParam param, const std::string &value) { | |||
| if (seter_ != nullptr && param == MS_CTX_DEVICE_TARGET) { | |||
| MS_LOG(INFO) << "ms set context device target:" << value; | |||
| seter_(value); | |||
| } | |||
| string_params_[param - MS_CTX_TYPE_STRING_BEGIN] = value; | |||
| } | |||
| // get method implementation for type bool/int/uint32_t/float/std::string | |||
| template <> | |||
| inline const bool &MsContext::get_param<bool>(MsCtxParam param) const { | |||
| return bool_params_[param - MS_CTX_TYPE_BOOL_BEGIN]; | |||
| } | |||
| template <> | |||
| inline const int &MsContext::get_param<int>(MsCtxParam param) const { | |||
| return int_params_[param - MS_CTX_TYPE_INT_BEGIN]; | |||
| } | |||
| template <> | |||
| inline const uint32_t &MsContext::get_param<uint32_t>(MsCtxParam param) const { | |||
| return uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN]; | |||
| } | |||
| template <> | |||
| inline const float &MsContext::get_param<float>(MsCtxParam param) const { | |||
| return float_params_[param - MS_CTX_TYPE_FLOAT_BEGIN]; | |||
| } | |||
| template <> | |||
| inline const std::string &MsContext::get_param<std::string>(MsCtxParam param) const { | |||
| return string_params_[param - MS_CTX_TYPE_STRING_BEGIN]; | |||
| } | |||
| // increate method implementation for type uint32_t | |||
| template <> | |||
| inline void MsContext::increase_param<uint32_t>(MsCtxParam param) { | |||
| uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN]++; | |||
| } | |||
| // decreate method implementation for type uint32_t | |||
| template <> | |||
| inline void MsContext::decrease_param<uint32_t>(MsCtxParam param) { | |||
| uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN]--; | |||
| } | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_UTILS_MS_CONTEXT_H_ | |||
| @@ -42,7 +42,7 @@ class TestOptLib : public UT::Common { | |||
| parse::data_converter::ClearObjectCache(); | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| ms_context->set_execution_mode(kGraphMode); | |||
| ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode); | |||
| } | |||
| FuncGraphPtr RunTransform(FuncGraphPtr gbefore, const SubstitutionList &transform) { | |||
| equiv_node.clear(); | |||
| @@ -112,7 +112,7 @@ TEST_F(TestHWInsertTransOp, test_insert_trans_op_for_single_output) { | |||
| */ | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| ms_context->set_execution_mode(kGraphMode); | |||
| ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode); | |||
| auto fg = GetSingleOutputGraph("test_insert_trans_op_for_single_output", "before", "NC1HWC0"); | |||
| // Do insert_trans_op_ pass of hardware opt | |||
| auto graph_optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| @@ -112,7 +112,7 @@ class MockRemoveInternalOutputTransOpKernelSelect : public KernelSelect { | |||
| TEST_F(TestHWRemoveInternalOutput, test_remove_internal_output_trans_op_for_single_output) { | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| ms_context->set_execution_mode(kGraphMode); | |||
| ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode); | |||
| auto kg = GetSingleOutputGraph("test_remove_internal_output_trans_op_for_single_output", "before"); | |||
| // insert trans op for output | |||
| auto graph_optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| @@ -104,7 +104,7 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_fraz_nchw) { | |||
| */ | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| ms_context->set_execution_mode(kGraphMode); | |||
| ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode); | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transdata_split_fraz_nchw", "before"); | |||
| std::vector<int> shp{2, 4, 8, 16}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| @@ -83,7 +83,7 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) { | |||
| */ | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| ms_context->set_execution_mode(kGraphMode); | |||
| ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode); | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transpose_transdata_fusion", "before"); | |||
| std::vector<int> shp{2, 4, 8, 16}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| @@ -76,7 +76,7 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_5to4_4to5) { | |||
| */ | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| ms_context->set_execution_mode(kGraphMode); | |||
| ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode); | |||
| FuncGraphPtr g = getPyFun_.CallAndParseRet("test_eliminate_5to4_4to5", "before"); | |||
| // Renormalize func_graph to infer and set shape and type information. | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| @@ -71,15 +71,15 @@ OpExecInfoPtr ConstructOpExecInfo() { | |||
| TEST_F(TestPynativeExecute, TestCreateContext) { | |||
| auto ctx3 = MsContext::GetInstance(); | |||
| ASSERT_EQ(ctx3->backend_policy(), "vm"); | |||
| ASSERT_EQ(ctx3->device_target(), "CPU"); | |||
| ASSERT_EQ(ctx3->get_param<std::string>(MS_CTX_DEVICE_TARGET), "CPU"); | |||
| ctx3->set_backend_policy("ge_only"); | |||
| ctx3->set_device_target("GPU"); | |||
| ctx3->set_param<std::string>(MS_CTX_DEVICE_TARGET, "GPU"); | |||
| auto ctx4 = MsContext::GetInstance(); | |||
| ASSERT_EQ(ctx3.get(), ctx4.get()); | |||
| ASSERT_EQ(ctx4->backend_policy(), "ge_only"); | |||
| ASSERT_EQ(ctx4->device_target(), "GPU"); | |||
| ASSERT_EQ(ctx4->get_param<std::string>(MS_CTX_DEVICE_TARGET), "GPU"); | |||
| } | |||
| TEST_F(TestPynativeExecute, TestDefaultContext) { | |||