Merge pull request !5576 from fary86/simplify_context_implementationtags/v1.0.0
| @@ -50,55 +50,6 @@ using ParallelContext = mindspore::parallel::ParallelContext; | |||||
| using CostModelContext = mindspore::parallel::CostModelContext; | using CostModelContext = mindspore::parallel::CostModelContext; | ||||
| using mindspore::MsCtxParam; | using mindspore::MsCtxParam; | ||||
| namespace mindspore { | |||||
| void MsCtxSetParameter(std::shared_ptr<MsContext> ctx, MsCtxParam param, const py::object &value) { | |||||
| MS_LOG(DEBUG) << "set param(" << param << ") with value '" << py::str(value) << "' of type '" | |||||
| << py::str(value.get_type()) << "'."; | |||||
| if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END && py::isinstance<py::bool_>(value)) { | |||||
| ctx->set_param<bool>(param, value.cast<bool>()); | |||||
| return; | |||||
| } | |||||
| if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END && py::isinstance<py::int_>(value)) { | |||||
| ctx->set_param<int>(param, value.cast<int>()); | |||||
| return; | |||||
| } | |||||
| if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END && py::isinstance<py::int_>(value)) { | |||||
| ctx->set_param<uint32_t>(param, value.cast<uint32_t>()); | |||||
| return; | |||||
| } | |||||
| if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END && py::isinstance<py::float_>(value)) { | |||||
| ctx->set_param<float>(param, value.cast<float>()); | |||||
| return; | |||||
| } | |||||
| if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END && py::isinstance<py::str>(value)) { | |||||
| ctx->set_param<std::string>(param, value.cast<std::string>()); | |||||
| return; | |||||
| } | |||||
| MS_LOG(EXCEPTION) << "Got illegal param " << param << " and value with type " << py::str(value.get_type()); | |||||
| } | |||||
| py::object MsCtxGetParameter(const std::shared_ptr<MsContext> &ctx, MsCtxParam param) { | |||||
| if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END) { | |||||
| return py::bool_(ctx->get_param<bool>(param)); | |||||
| } | |||||
| if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END) { | |||||
| return py::int_(ctx->get_param<int>(param)); | |||||
| } | |||||
| if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END) { | |||||
| return py::int_(ctx->get_param<uint32_t>(param)); | |||||
| } | |||||
| if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END) { | |||||
| return py::float_(ctx->get_param<float>(param)); | |||||
| } | |||||
| if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END) { | |||||
| return py::str(ctx->get_param<std::string>(param)); | |||||
| } | |||||
| MS_LOG(EXCEPTION) << "Got illegal param " << param << "."; | |||||
| } | |||||
| } // namespace mindspore | |||||
| // Interface with python | // Interface with python | ||||
| PYBIND11_MODULE(_c_expression, m) { | PYBIND11_MODULE(_c_expression, m) { | ||||
| m.doc() = "MindSpore c plugin"; | m.doc() = "MindSpore c plugin"; | ||||
| @@ -151,49 +102,6 @@ PYBIND11_MODULE(_c_expression, m) { | |||||
| (void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph."); | (void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph."); | ||||
| (void)m.def("ms_ctx_get_param", &mindspore::MsCtxGetParameter, "Get value of specified paramter."); | |||||
| (void)m.def("ms_ctx_set_param", &mindspore::MsCtxSetParameter, "Set value for specified paramter."); | |||||
| (void)py::enum_<MsCtxParam>(*m, "ms_ctx_param", py::arithmetic()) | |||||
| .value("auto_mixed_precision_flag", MsCtxParam::MS_CTX_AUTO_MIXED_PRECISION_FLAG) | |||||
| .value("check_bprop_flag", MsCtxParam::MS_CTX_CHECK_BPROP_FLAG) | |||||
| .value("enable_dump", MsCtxParam::MS_CTX_ENABLE_DUMP) | |||||
| .value("enable_dynamic_mem_pool", MsCtxParam::MS_CTX_ENABLE_DYNAMIC_MEM_POOL) | |||||
| .value("enable_gpu_summary", MsCtxParam::MS_CTX_ENABLE_GPU_SUMMARY) | |||||
| .value("enable_graph_kernel", MsCtxParam::MS_CTX_ENABLE_GRAPH_KERNEL) | |||||
| .value("enable_hccl", MsCtxParam::MS_CTX_ENABLE_HCCL) | |||||
| .value("enable_loop_sink", MsCtxParam::MS_CTX_ENABLE_LOOP_SINK) | |||||
| .value("enable_mem_reuse", MsCtxParam::MS_CTX_ENABLE_MEM_REUSE) | |||||
| .value("enable_pynative_hook", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_HOOK) | |||||
| .value("enable_pynative_infer", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_INFER) | |||||
| .value("enable_reduce_precision", MsCtxParam::MS_CTX_ENABLE_REDUCE_PRECISION) | |||||
| .value("enable_sparse", MsCtxParam::MS_CTX_ENABLE_SPARSE) | |||||
| .value("enable_task_sink", MsCtxParam::MS_CTX_ENABLE_TASK_SINK) | |||||
| .value("ir_fusion_flag", MsCtxParam::MS_CTX_IR_FUSION_FLAG) | |||||
| .value("is_multi_graph_sink", MsCtxParam::MS_CTX_IS_MULTI_GRAPH_SINK) | |||||
| .value("is_pynative_ge_init", MsCtxParam::MS_CTX_IS_PYNATIVE_GE_INIT) | |||||
| .value("precompile_only", MsCtxParam::MS_CTX_PRECOMPILE_ONLY) | |||||
| .value("enable_profiling", MsCtxParam::MS_CTX_ENABLE_PROFILING) | |||||
| .value("save_graphs_flag", MsCtxParam::MS_CTX_SAVE_GRAPHS_FLAG) | |||||
| .value("max_device_memory", MsCtxParam::MS_CTX_MAX_DEVICE_MEMORY) | |||||
| .value("execution_mode", MsCtxParam::MS_CTX_EXECUTION_MODE) | |||||
| .value("device_target", MsCtxParam::MS_CTX_DEVICE_TARGET) | |||||
| .value("graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE) | |||||
| .value("print_file_path", MsCtxParam::MS_CTX_PRINT_FILE_PATH) | |||||
| .value("profiling_options", MsCtxParam::MS_CTX_PROFILING_OPTIONS) | |||||
| .value("save_dump_path", MsCtxParam::MS_CTX_SAVE_DUMP_PATH) | |||||
| .value("save_graphs_path", MsCtxParam::MS_CTX_SAVE_GRAPHS_PATH) | |||||
| .value("variable_memory_max_size", MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE) | |||||
| .value("device_id", MsCtxParam::MS_CTX_DEVICE_ID) | |||||
| .value("ge_ref", MsCtxParam::MS_CTX_GE_REF) | |||||
| .value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH) | |||||
| .value("tsd_ref", MsCtxParam::MS_CTX_TSD_REF); | |||||
| (void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(m, "MSContext") | |||||
| .def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.") | |||||
| .def("get_backend_policy", &mindspore::MsContext::backend_policy, "Get backend policy.") | |||||
| .def("set_backend_policy", &mindspore::MsContext::set_backend_policy, "Set backend policy."); | |||||
| (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig") | (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig") | ||||
| .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.") | .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.") | ||||
| .def("get_enable_mpi", &mindspore::MpiConfig::enable_mpi, "Get whether enable mpi.") | .def("get_enable_mpi", &mindspore::MpiConfig::enable_mpi, "Get whether enable mpi.") | ||||
| @@ -0,0 +1,117 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "utils/ms_context.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "pybind_api/api_register.h" | |||||
| namespace mindspore { | |||||
| namespace { | |||||
| void MsCtxSetParameter(std::shared_ptr<MsContext> ctx, MsCtxParam param, const py::object &value) { | |||||
| MS_LOG(DEBUG) << "set param(" << param << ") with value '" << py::str(value).cast<std::string>() << "' of type '" | |||||
| << py::str(value.get_type()).cast<std::string>() << "'."; | |||||
| if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END && py::isinstance<py::bool_>(value)) { | |||||
| ctx->set_param<bool>(param, value.cast<bool>()); | |||||
| return; | |||||
| } | |||||
| if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END && py::isinstance<py::int_>(value)) { | |||||
| ctx->set_param<int>(param, value.cast<int>()); | |||||
| return; | |||||
| } | |||||
| if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END && py::isinstance<py::int_>(value)) { | |||||
| ctx->set_param<uint32_t>(param, value.cast<uint32_t>()); | |||||
| return; | |||||
| } | |||||
| if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END && py::isinstance<py::float_>(value)) { | |||||
| ctx->set_param<float>(param, value.cast<float>()); | |||||
| return; | |||||
| } | |||||
| if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END && py::isinstance<py::str>(value)) { | |||||
| ctx->set_param<std::string>(param, value.cast<std::string>()); | |||||
| return; | |||||
| } | |||||
| MS_LOG(EXCEPTION) << "Got illegal param " << param << " and value with type " | |||||
| << py::str(value.get_type()).cast<std::string>(); | |||||
| } | |||||
| py::object MsCtxGetParameter(const std::shared_ptr<MsContext> &ctx, MsCtxParam param) { | |||||
| if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END) { | |||||
| return py::bool_(ctx->get_param<bool>(param)); | |||||
| } | |||||
| if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END) { | |||||
| return py::int_(ctx->get_param<int>(param)); | |||||
| } | |||||
| if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END) { | |||||
| return py::int_(ctx->get_param<uint32_t>(param)); | |||||
| } | |||||
| if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END) { | |||||
| return py::float_(ctx->get_param<float>(param)); | |||||
| } | |||||
| if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END) { | |||||
| return py::str(ctx->get_param<std::string>(param)); | |||||
| } | |||||
| MS_LOG(EXCEPTION) << "Got illegal param " << param << "."; | |||||
| } | |||||
| } // namespace | |||||
| REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) { | |||||
| (void)py::enum_<MsCtxParam>(*m, "ms_ctx_param", py::arithmetic()) | |||||
| .value("enable_auto_mixed_precision", MsCtxParam::MS_CTX_ENABLE_AUTO_MIXED_PRECISION) | |||||
| .value("check_bprop", MsCtxParam::MS_CTX_CHECK_BPROP_FLAG) | |||||
| .value("enable_dump", MsCtxParam::MS_CTX_ENABLE_DUMP) | |||||
| .value("enable_dynamic_mem_pool", MsCtxParam::MS_CTX_ENABLE_DYNAMIC_MEM_POOL) | |||||
| .value("enable_gpu_summary", MsCtxParam::MS_CTX_ENABLE_GPU_SUMMARY) | |||||
| .value("enable_graph_kernel", MsCtxParam::MS_CTX_ENABLE_GRAPH_KERNEL) | |||||
| .value("enable_hccl", MsCtxParam::MS_CTX_ENABLE_HCCL) | |||||
| .value("enable_loop_sink", MsCtxParam::MS_CTX_ENABLE_LOOP_SINK) | |||||
| .value("enable_mem_reuse", MsCtxParam::MS_CTX_ENABLE_MEM_REUSE) | |||||
| .value("enable_pynative_hook", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_HOOK) | |||||
| .value("enable_pynative_infer", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_INFER) | |||||
| .value("enable_reduce_precision", MsCtxParam::MS_CTX_ENABLE_REDUCE_PRECISION) | |||||
| .value("enable_sparse", MsCtxParam::MS_CTX_ENABLE_SPARSE) | |||||
| .value("enable_task_sink", MsCtxParam::MS_CTX_ENABLE_TASK_SINK) | |||||
| .value("ir_fusion_flag", MsCtxParam::MS_CTX_IR_FUSION_FLAG) | |||||
| .value("is_multi_graph_sink", MsCtxParam::MS_CTX_IS_MULTI_GRAPH_SINK) | |||||
| .value("is_pynative_ge_init", MsCtxParam::MS_CTX_IS_PYNATIVE_GE_INIT) | |||||
| .value("precompile_only", MsCtxParam::MS_CTX_PRECOMPILE_ONLY) | |||||
| .value("enable_profiling", MsCtxParam::MS_CTX_ENABLE_PROFILING) | |||||
| .value("save_graphs", MsCtxParam::MS_CTX_SAVE_GRAPHS_FLAG) | |||||
| .value("max_device_memory", MsCtxParam::MS_CTX_MAX_DEVICE_MEMORY) | |||||
| .value("mode", MsCtxParam::MS_CTX_EXECUTION_MODE) | |||||
| .value("device_target", MsCtxParam::MS_CTX_DEVICE_TARGET) | |||||
| .value("graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE) | |||||
| .value("print_file_path", MsCtxParam::MS_CTX_PRINT_FILE_PATH) | |||||
| .value("profiling_options", MsCtxParam::MS_CTX_PROFILING_OPTIONS) | |||||
| .value("save_dump_path", MsCtxParam::MS_CTX_SAVE_DUMP_PATH) | |||||
| .value("save_graphs_path", MsCtxParam::MS_CTX_SAVE_GRAPHS_PATH) | |||||
| .value("variable_memory_max_size", MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE) | |||||
| .value("device_id", MsCtxParam::MS_CTX_DEVICE_ID) | |||||
| .value("ge_ref", MsCtxParam::MS_CTX_GE_REF) | |||||
| .value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH) | |||||
| .value("tsd_ref", MsCtxParam::MS_CTX_TSD_REF); | |||||
| (void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(*m, "MSContext") | |||||
| .def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.") | |||||
| .def("get_param", &mindspore::MsCtxGetParameter, "Get value of specified paramter.") | |||||
| .def("set_param", &mindspore::MsCtxSetParameter, "Set value for specified paramter.") | |||||
| .def("get_backend_policy", &mindspore::MsContext::backend_policy, "Get backend policy.") | |||||
| .def("set_backend_policy", &mindspore::MsContext::set_backend_policy, "Set backend policy."); | |||||
| })); | |||||
| } // namespace mindspore | |||||
| @@ -225,7 +225,7 @@ void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std | |||||
| } | } | ||||
| // Enable auto mixed precision according to the context options | // Enable auto mixed precision according to the context options | ||||
| if (ms_context_ptr->get_param<bool>(MS_CTX_AUTO_MIXED_PRECISION_FLAG)) { | |||||
| if (ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_AUTO_MIXED_PRECISION)) { | |||||
| (*ge_options)["ge.exec.precision_mode"] = "allow_mix_precision"; | (*ge_options)["ge.exec.precision_mode"] = "allow_mix_precision"; | ||||
| } else { | } else { | ||||
| (*ge_options)["ge.exec.precision_mode"] = "allow_fp32_to_fp16"; | (*ge_options)["ge.exec.precision_mode"] = "allow_fp32_to_fp16"; | ||||
| @@ -337,7 +337,7 @@ bool FinalizeGe(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) { | |||||
| if (ge::GEFinalize() != ge::GRAPH_SUCCESS) { | if (ge::GEFinalize() != ge::GRAPH_SUCCESS) { | ||||
| MS_LOG(WARNING) << "Finalize GE failed!"; | MS_LOG(WARNING) << "Finalize GE failed!"; | ||||
| } | } | ||||
| ms_context_ptr->set_pynative_ge_init(false); | |||||
| ms_context_ptr->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false); | |||||
| } else { | } else { | ||||
| MS_LOG(INFO) << "Ge is used, no need to finalize, tsd reference = " | MS_LOG(INFO) << "Ge is used, no need to finalize, tsd reference = " | ||||
| << ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) << "."; | << ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) << "."; | ||||
| @@ -22,7 +22,7 @@ import threading | |||||
| from collections import namedtuple | from collections import namedtuple | ||||
| from types import FunctionType | from types import FunctionType | ||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| from mindspore._c_expression import MSContext, ms_ctx_param, ms_ctx_get_param, ms_ctx_set_param | |||||
| from mindspore._c_expression import MSContext, ms_ctx_param | |||||
| from mindspore._checkparam import args_type_check | from mindspore._checkparam import args_type_check | ||||
| from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \ | from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \ | ||||
| _reset_auto_parallel_context | _reset_auto_parallel_context | ||||
| @@ -158,17 +158,12 @@ class _Context: | |||||
| return value | return value | ||||
| def get_param(self, param): | def get_param(self, param): | ||||
| return ms_ctx_get_param(self._context_handle, param) | |||||
| return self._context_handle.get_param(param) | |||||
| def set_param(self, param, value): | def set_param(self, param, value): | ||||
| ms_ctx_set_param(self._context_handle, param, value) | |||||
| self._context_handle.set_param(param, value) | |||||
| @property | |||||
| def mode(self): | |||||
| return self.get_param(ms_ctx_param.execution_mode) | |||||
| @mode.setter | |||||
| def mode(self, mode): | |||||
| def set_mode(self, mode): | |||||
| """ | """ | ||||
| Switch between Graph mode and PyNative mode. | Switch between Graph mode and PyNative mode. | ||||
| @@ -185,43 +180,17 @@ class _Context: | |||||
| self._context_switches.push(False, None) | self._context_switches.push(False, None) | ||||
| else: | else: | ||||
| raise ValueError(f'The execution mode {mode} is invalid!') | raise ValueError(f'The execution mode {mode} is invalid!') | ||||
| self.set_param(ms_ctx_param.execution_mode, mode) | |||||
| self.set_param(ms_ctx_param.mode, mode) | |||||
| def set_backend_policy(self, policy): | def set_backend_policy(self, policy): | ||||
| success = self._context_handle.set_backend_policy(policy) | success = self._context_handle.set_backend_policy(policy) | ||||
| if not success: | if not success: | ||||
| raise RuntimeError("Backend policy must be one of ge, vm, ms.") | raise RuntimeError("Backend policy must be one of ge, vm, ms.") | ||||
| @property | |||||
| def precompile_only(self): | |||||
| return self.get_param(ms_ctx_param.precompile_only) | |||||
| @precompile_only.setter | |||||
| def precompile_only(self, precompile_only): | |||||
| self.set_param(ms_ctx_param.precompile_only, precompile_only) | |||||
| @property | |||||
| def save_graphs(self): | |||||
| return self.get_param(ms_ctx_param.save_graphs_flag) | |||||
| @save_graphs.setter | |||||
| def save_graphs(self, save_graphs_flag): | |||||
| self.set_param(ms_ctx_param.save_graphs_flag, save_graphs_flag) | |||||
| @property | |||||
| def save_graphs_path(self): | |||||
| return self.get_param(ms_ctx_param.save_graphs_path) | |||||
| @save_graphs_path.setter | |||||
| def save_graphs_path(self, save_graphs_path): | |||||
| def set_save_graphs_path(self, save_graphs_path): | |||||
| self.set_param(ms_ctx_param.save_graphs_path, _make_directory(save_graphs_path)) | self.set_param(ms_ctx_param.save_graphs_path, _make_directory(save_graphs_path)) | ||||
| @property | |||||
| def device_target(self): | |||||
| return self.get_param(ms_ctx_param.device_target) | |||||
| @device_target.setter | |||||
| def device_target(self, target): | |||||
| def set_device_target(self, target): | |||||
| valid_targets = ["CPU", "GPU", "Ascend", "Davinci"] | valid_targets = ["CPU", "GPU", "Ascend", "Davinci"] | ||||
| if not target in valid_targets: | if not target in valid_targets: | ||||
| raise ValueError(f"Target device name {target} is invalid! It must be one of {valid_targets}") | raise ValueError(f"Target device name {target} is invalid! It must be one of {valid_targets}") | ||||
| @@ -231,72 +200,17 @@ class _Context: | |||||
| if self.enable_debug_runtime and target == "CPU": | if self.enable_debug_runtime and target == "CPU": | ||||
| self.set_backend_policy("vm") | self.set_backend_policy("vm") | ||||
| @property | |||||
| def device_id(self): | |||||
| return self.get_param(ms_ctx_param.device_id) | |||||
| @device_id.setter | |||||
| def device_id(self, device_id): | |||||
| def set_device_id(self, device_id): | |||||
| if device_id < 0 or device_id > 4095: | if device_id < 0 or device_id > 4095: | ||||
| raise ValueError(f"Device id must be in [0, 4095], but got {device_id}") | raise ValueError(f"Device id must be in [0, 4095], but got {device_id}") | ||||
| self.set_param(ms_ctx_param.device_id, device_id) | self.set_param(ms_ctx_param.device_id, device_id) | ||||
| @property | |||||
| def max_call_depth(self): | |||||
| return self.get_param(ms_ctx_param.max_call_depth) | |||||
| @max_call_depth.setter | |||||
| def max_call_depth(self, max_call_depth): | |||||
| def set_max_call_depth(self, max_call_depth): | |||||
| if max_call_depth <= 0: | if max_call_depth <= 0: | ||||
| raise ValueError(f"Max call depth must be greater than 0, but got {max_call_depth}") | raise ValueError(f"Max call depth must be greater than 0, but got {max_call_depth}") | ||||
| self.set_param(ms_ctx_param.max_call_depth, max_call_depth) | self.set_param(ms_ctx_param.max_call_depth, max_call_depth) | ||||
| @property | |||||
| def enable_auto_mixed_precision(self): | |||||
| return self.get_param(ms_ctx_param.auto_mixed_precision_flag) | |||||
| @enable_auto_mixed_precision.setter | |||||
| def enable_auto_mixed_precision(self, enable_auto_mixed_precision): | |||||
| self.set_param(ms_ctx_param.auto_mixed_precision_flag, enable_auto_mixed_precision) | |||||
| @property | |||||
| def enable_reduce_precision(self): | |||||
| return self.get_param(ms_ctx_param.enable_reduce_precision_flag) | |||||
| @enable_reduce_precision.setter | |||||
| def enable_reduce_precision(self, enable_reduce_precision): | |||||
| self.set_param(ms_ctx_param.enable_reduce_precision_flag, enable_reduce_precision) | |||||
| @property | |||||
| def enable_dump(self): | |||||
| return self.get_param(ms_ctx_param.enable_dump) | |||||
| @enable_dump.setter | |||||
| def enable_dump(self, enable_dump): | |||||
| self.set_param(ms_ctx_param.enable_dump, enable_dump) | |||||
| @property | |||||
| def save_dump_path(self): | |||||
| return self.get_param(ms_ctx_param.save_dump_path) | |||||
| @save_dump_path.setter | |||||
| def save_dump_path(self, save_dump_path): | |||||
| self.set_param(ms_ctx_param.save_dump_path, save_dump_path) | |||||
| @property | |||||
| def enable_profiling(self): | |||||
| return self.get_param(ms_ctx_param.enable_profiling) | |||||
| @enable_profiling.setter | |||||
| def enable_profiling(self, flag): | |||||
| self.set_param(ms_ctx_param.enable_profiling, flag) | |||||
| @property | |||||
| def profiling_options(self): | |||||
| return self.get_param(ms_ctx_param.profiling_options) | |||||
| @profiling_options.setter | |||||
| def profiling_options(self, option): | |||||
| def set_profiling_options(self, option): | |||||
| options = ["training_trace", "task_trace", | options = ["training_trace", "task_trace", | ||||
| "task_trace:training_trace", "training_trace:task_trace", "op_trace"] | "task_trace:training_trace", "training_trace:task_trace", "op_trace"] | ||||
| if option not in options: | if option not in options: | ||||
| @@ -304,30 +218,7 @@ class _Context: | |||||
| "'task_trace:training_trace' 'training_trace:task_trace' or 'op_trace'.") | "'task_trace:training_trace' 'training_trace:task_trace' or 'op_trace'.") | ||||
| self.set_param(ms_ctx_param.profiling_options, option) | self.set_param(ms_ctx_param.profiling_options, option) | ||||
| @property | |||||
| def enable_graph_kernel(self): | |||||
| return self.get_param(ms_ctx_param.enable_graph_kernel) | |||||
| @enable_graph_kernel.setter | |||||
| def enable_graph_kernel(self, graph_kernel_switch_): | |||||
| self.set_param(ms_ctx_param.enable_graph_kernel, graph_kernel_switch_) | |||||
| @property | |||||
| def reserve_class_name_in_scope(self): | |||||
| """Gets whether to save the network class name in the scope.""" | |||||
| return self._thread_local_info.reserve_class_name_in_scope | |||||
| @reserve_class_name_in_scope.setter | |||||
| def reserve_class_name_in_scope(self, reserve_class_name_in_scope): | |||||
| """Sets whether to save the network class name in the scope.""" | |||||
| self._thread_local_info.reserve_class_name_in_scope = reserve_class_name_in_scope | |||||
| @property | |||||
| def variable_memory_max_size(self): | |||||
| return None | |||||
| @variable_memory_max_size.setter | |||||
| def variable_memory_max_size(self, variable_memory_max_size): | |||||
| def set_variable_memory_max_size(self, variable_memory_max_size): | |||||
| if not check_input_format(variable_memory_max_size): | if not check_input_format(variable_memory_max_size): | ||||
| raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"") | raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"") | ||||
| if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE: | if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE: | ||||
| @@ -338,33 +229,7 @@ class _Context: | |||||
| self.set_param(ms_ctx_param.variable_memory_max_size, variable_memory_max_size_) | self.set_param(ms_ctx_param.variable_memory_max_size, variable_memory_max_size_) | ||||
| self.set_param(ms_ctx_param.graph_memory_max_size, graph_memory_max_size_) | self.set_param(ms_ctx_param.graph_memory_max_size, graph_memory_max_size_) | ||||
| @property | |||||
| def enable_ge(self): | |||||
| return self._context_handle.get_backend_policy() == 'ge' | |||||
| @property | |||||
| def enable_debug_runtime(self): | |||||
| return self._thread_local_info.debug_runtime | |||||
| @enable_debug_runtime.setter | |||||
| def enable_debug_runtime(self, enable): | |||||
| thread_info = self._thread_local_info | |||||
| thread_info.debug_runtime = enable | |||||
| @property | |||||
| def check_bprop(self): | |||||
| return self.get_param(ms_ctx_param.check_bprop_flag) | |||||
| @check_bprop.setter | |||||
| def check_bprop(self, check_bprop_flag): | |||||
| self.set_param(ms_ctx_param.check_bprop_flag, check_bprop_flag) | |||||
| @property | |||||
| def max_device_memory(self): | |||||
| return self.get_param(ms_ctx_param.max_device_memory) | |||||
| @max_device_memory.setter | |||||
| def max_device_memory(self, max_device_memory): | |||||
| def set_max_device_memory(self, max_device_memory): | |||||
| if not check_input_format(max_device_memory): | if not check_input_format(max_device_memory): | ||||
| raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"") | raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"") | ||||
| max_device_memory_value = float(max_device_memory[:-2]) | max_device_memory_value = float(max_device_memory[:-2]) | ||||
| @@ -372,12 +237,7 @@ class _Context: | |||||
| raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"") | raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"") | ||||
| self.set_param(ms_ctx_param.max_device_memory, max_device_memory_value) | self.set_param(ms_ctx_param.max_device_memory, max_device_memory_value) | ||||
| @property | |||||
| def print_file_path(self): | |||||
| return None | |||||
| @print_file_path.setter | |||||
| def print_file_path(self, file_path): | |||||
| def set_print_file_path(self, file_path): | |||||
| """Add timestamp suffix to file name. Sets print file path.""" | """Add timestamp suffix to file name. Sets print file path.""" | ||||
| print_file_path = os.path.realpath(file_path) | print_file_path = os.path.realpath(file_path) | ||||
| if os.path.isdir(print_file_path): | if os.path.isdir(print_file_path): | ||||
| @@ -392,13 +252,42 @@ class _Context: | |||||
| full_file_name = print_file_path | full_file_name = print_file_path | ||||
| self.set_param(ms_ctx_param.print_file_path, full_file_name) | self.set_param(ms_ctx_param.print_file_path, full_file_name) | ||||
| setters = { | |||||
| 'mode': set_mode, | |||||
| 'backend_policy': set_backend_policy, | |||||
| 'save_graphs_path': set_save_graphs_path, | |||||
| 'device_target': set_device_target, | |||||
| 'device_id': set_device_id, | |||||
| 'max_call_depth': set_max_call_depth, | |||||
| 'profiling_options': set_profiling_options, | |||||
| 'variable_memory_max_size': set_variable_memory_max_size, | |||||
| 'max_device_memory': set_max_device_memory, | |||||
| 'print_file_path': set_print_file_path | |||||
| } | |||||
| @property | @property | ||||
| def enable_sparse(self): | |||||
| return self.get_param(ms_ctx_param.enable_sparse) | |||||
| def reserve_class_name_in_scope(self): | |||||
| """Gets whether to save the network class name in the scope.""" | |||||
| return self._thread_local_info.reserve_class_name_in_scope | |||||
| @reserve_class_name_in_scope.setter | |||||
| def reserve_class_name_in_scope(self, reserve_class_name_in_scope): | |||||
| """Sets whether to save the network class name in the scope.""" | |||||
| self._thread_local_info.reserve_class_name_in_scope = reserve_class_name_in_scope | |||||
| @property | |||||
| def enable_ge(self): | |||||
| return self._context_handle.get_backend_policy() == 'ge' | |||||
| @property | |||||
| def enable_debug_runtime(self): | |||||
| return self._thread_local_info.debug_runtime | |||||
| @enable_debug_runtime.setter | |||||
| def enable_debug_runtime(self, enable): | |||||
| thread_info = self._thread_local_info | |||||
| thread_info.debug_runtime = enable | |||||
| @enable_sparse.setter | |||||
| def enable_sparse(self, enable_sparse): | |||||
| self.set_param(ms_ctx_param.enable_sparse, enable_sparse) | |||||
| def check_input_format(x): | def check_input_format(x): | ||||
| import re | import re | ||||
| @@ -621,10 +510,18 @@ def set_context(**kwargs): | |||||
| >>> context.set_context(print_file_path="print.pb") | >>> context.set_context(print_file_path="print.pb") | ||||
| >>> context.set_context(max_call_depth=80) | >>> context.set_context(max_call_depth=80) | ||||
| """ | """ | ||||
| ctx = _context() | |||||
| for key, value in kwargs.items(): | for key, value in kwargs.items(): | ||||
| if not hasattr(_context(), key): | |||||
| raise ValueError("Set context keyword %s is not recognized!" % key) | |||||
| setattr(_context(), key, value) | |||||
| if hasattr(ctx, key): | |||||
| setattr(ctx, key, value) | |||||
| continue | |||||
| if key in ctx.setters: | |||||
| ctx.setters[key](ctx, value) | |||||
| continue | |||||
| if key in ms_ctx_param.__members__: | |||||
| ctx.set_param(ms_ctx_param.__members__[key], value) | |||||
| continue | |||||
| raise ValueError("Set context keyword %s is not recognized!" % key) | |||||
| def get_context(attr_key): | def get_context(attr_key): | ||||
| @@ -640,10 +537,13 @@ def get_context(attr_key): | |||||
| Raises: | Raises: | ||||
| ValueError: If input key is not an attribute in context. | ValueError: If input key is not an attribute in context. | ||||
| """ | """ | ||||
| if not hasattr(_context(), attr_key): | |||||
| raise ValueError( | |||||
| "Get context keyword %s is not recognized!" % attr_key) | |||||
| return getattr(_context(), attr_key) | |||||
| ctx = _context() | |||||
| if hasattr(ctx, attr_key): | |||||
| return getattr(ctx, attr_key) | |||||
| if attr_key in ms_ctx_param.__members__: | |||||
| return ctx.get_param(ms_ctx_param.__members__[attr_key]) | |||||
| raise ValueError("Get context keyword %s is not recognized!" % attr_key) | |||||
| class ParallelMode: | class ParallelMode: | ||||
| """ | """ | ||||
| @@ -60,7 +60,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) { | |||||
| #endif | #endif | ||||
| set_param<bool>(MS_CTX_ENABLE_GPU_SUMMARY, true); | set_param<bool>(MS_CTX_ENABLE_GPU_SUMMARY, true); | ||||
| set_param<bool>(MS_CTX_PRECOMPILE_ONLY, false); | set_param<bool>(MS_CTX_PRECOMPILE_ONLY, false); | ||||
| set_param<bool>(MS_CTX_AUTO_MIXED_PRECISION_FLAG, false); | |||||
| set_param<bool>(MS_CTX_ENABLE_AUTO_MIXED_PRECISION, false); | |||||
| set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false); | set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false); | ||||
| set_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK, false); | set_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK, false); | ||||
| set_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL, true); | set_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL, true); | ||||
| @@ -53,7 +53,7 @@ const float kDefaultMaxDeviceMemory = 1024; | |||||
| enum MsCtxParam : unsigned { | enum MsCtxParam : unsigned { | ||||
| // paramater of type bool | // paramater of type bool | ||||
| MS_CTX_TYPE_BOOL_BEGIN, | MS_CTX_TYPE_BOOL_BEGIN, | ||||
| MS_CTX_AUTO_MIXED_PRECISION_FLAG = MS_CTX_TYPE_BOOL_BEGIN, | |||||
| MS_CTX_ENABLE_AUTO_MIXED_PRECISION = MS_CTX_TYPE_BOOL_BEGIN, | |||||
| MS_CTX_CHECK_BPROP_FLAG, | MS_CTX_CHECK_BPROP_FLAG, | ||||
| MS_CTX_ENABLE_DUMP, | MS_CTX_ENABLE_DUMP, | ||||
| MS_CTX_ENABLE_DYNAMIC_MEM_POOL, | MS_CTX_ENABLE_DYNAMIC_MEM_POOL, | ||||
| @@ -132,22 +132,22 @@ class MsContext { | |||||
| template <typename T> | template <typename T> | ||||
| void set_param(MsCtxParam param, const T &value) { | void set_param(MsCtxParam param, const T &value) { | ||||
| MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << "."; | |||||
| MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << "."; | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| const T &get_param(MsCtxParam param) const { | const T &get_param(MsCtxParam param) const { | ||||
| MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << "."; | |||||
| MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << "."; | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| void increase_param(MsCtxParam param) { | void increase_param(MsCtxParam param) { | ||||
| MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << "."; | |||||
| MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << "."; | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| void decrease_param(MsCtxParam param) { | void decrease_param(MsCtxParam param) { | ||||
| MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << "."; | |||||
| MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << "."; | |||||
| } | } | ||||
| private: | private: | ||||