
Add device specific config check

tags/v1.0.0
fary86 committed 5 years ago
commit b0f89685b4
3 changed files with 60 additions and 18 deletions
  1. +2  -1   mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
  2. +3  -14  mindspore/ccsrc/pybind_api/utils/ms_context_py.cc
  3. +55 -3   mindspore/context.py

+2 -1   mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc

@@ -120,7 +120,8 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
   engine->IncreaseFunctionCallDepth();
   if (engine->function_call_depth() > MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)) {
     MS_LOG(EXCEPTION) << "Exceed function call depth limit "
-                      << MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH) << ".";
+                      << MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)
+                      << ", please call 'context.set_context(max_call_depth=value)' to adjust this value.";
   }
   std::vector<AnfNodePtr> nodes = FastShadowSort(func_node);
   for (auto it = nodes.crbegin(); it != nodes.crend(); it++) {
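
The improved message points users at the Python-side knob that controls this limit. As a quick illustration (a hedged sketch, not part of this commit; the depth value is arbitrary), the fix it suggests looks like:

>>> from mindspore import context
>>> # lift the analyzer's call-depth budget before compiling a deeply recursive net
>>> context.set_context(max_call_depth=10000)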


+3 -14  mindspore/ccsrc/pybind_api/utils/ms_context_py.cc

@@ -71,40 +71,29 @@ py::object MsCtxGetParameter(const std::shared_ptr<MsContext> &ctx, MsCtxParam p
   }
 }  // namespace
 
+// Note: exported python enum variables beginning with '_' are for internal use
 REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) {
                          (void)py::enum_<MsCtxParam>(*m, "ms_ctx_param", py::arithmetic())
                            .value("enable_auto_mixed_precision", MsCtxParam::MS_CTX_ENABLE_AUTO_MIXED_PRECISION)
                            .value("check_bprop", MsCtxParam::MS_CTX_CHECK_BPROP_FLAG)
                            .value("enable_dump", MsCtxParam::MS_CTX_ENABLE_DUMP)
-                           .value("enable_dynamic_mem_pool", MsCtxParam::MS_CTX_ENABLE_DYNAMIC_MEM_POOL)
-                           .value("enable_gpu_summary", MsCtxParam::MS_CTX_ENABLE_GPU_SUMMARY)
                            .value("enable_graph_kernel", MsCtxParam::MS_CTX_ENABLE_GRAPH_KERNEL)
-                           .value("enable_hccl", MsCtxParam::MS_CTX_ENABLE_HCCL)
-                           .value("enable_mem_reuse", MsCtxParam::MS_CTX_ENABLE_MEM_REUSE)
-                           .value("enable_pynative_hook", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_HOOK)
-                           .value("enable_pynative_infer", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_INFER)
                            .value("enable_reduce_precision", MsCtxParam::MS_CTX_ENABLE_REDUCE_PRECISION)
                            .value("enable_sparse", MsCtxParam::MS_CTX_ENABLE_SPARSE)
-                           .value("enable_task_sink", MsCtxParam::MS_CTX_ENABLE_TASK_SINK)
-                           .value("ir_fusion_flag", MsCtxParam::MS_CTX_IR_FUSION_FLAG)
-                           .value("is_multi_graph_sink", MsCtxParam::MS_CTX_IS_MULTI_GRAPH_SINK)
-                           .value("is_pynative_ge_init", MsCtxParam::MS_CTX_IS_PYNATIVE_GE_INIT)
                            .value("precompile_only", MsCtxParam::MS_CTX_PRECOMPILE_ONLY)
                            .value("enable_profiling", MsCtxParam::MS_CTX_ENABLE_PROFILING)
                            .value("save_graphs", MsCtxParam::MS_CTX_SAVE_GRAPHS_FLAG)
                            .value("max_device_memory", MsCtxParam::MS_CTX_MAX_DEVICE_MEMORY)
                            .value("mode", MsCtxParam::MS_CTX_EXECUTION_MODE)
                            .value("device_target", MsCtxParam::MS_CTX_DEVICE_TARGET)
-                           .value("graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE)
+                           .value("_graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE)
                            .value("print_file_path", MsCtxParam::MS_CTX_PRINT_FILE_PATH)
                            .value("profiling_options", MsCtxParam::MS_CTX_PROFILING_OPTIONS)
                            .value("save_dump_path", MsCtxParam::MS_CTX_SAVE_DUMP_PATH)
                            .value("save_graphs_path", MsCtxParam::MS_CTX_SAVE_GRAPHS_PATH)
                            .value("variable_memory_max_size", MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE)
                            .value("device_id", MsCtxParam::MS_CTX_DEVICE_ID)
-                           .value("ge_ref", MsCtxParam::MS_CTX_GE_REF)
-                           .value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH)
-                           .value("tsd_ref", MsCtxParam::MS_CTX_TSD_REF);
+                           .value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH);
 
                          (void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(*m, "MSContext")
                            .def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.")


+55 -3  mindspore/context.py

@@ -219,6 +219,7 @@ class _Context:
         self.set_param(ms_ctx_param.profiling_options, option)
 
     def set_variable_memory_max_size(self, variable_memory_max_size):
+        """set values of variable_memory_max_size and graph_memory_max_size"""
         if not _check_input_format(variable_memory_max_size):
             raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"")
         if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE:
@@ -227,7 +228,8 @@ class _Context:
         graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - int(variable_memory_max_size[:-2])
         graph_memory_max_size_ = str(graph_memory_max_size) + " * 1024 * 1024 * 1024"
         self.set_param(ms_ctx_param.variable_memory_max_size, variable_memory_max_size_)
-        self.set_param(ms_ctx_param.graph_memory_max_size, graph_memory_max_size_)
+        # pylint: disable=protected-access
+        self.set_param(ms_ctx_param._graph_memory_max_size, graph_memory_max_size_)
 
     def set_max_device_memory(self, max_device_memory):
         if not _check_input_format(max_device_memory):
@@ -425,6 +427,26 @@ def reset_auto_parallel_context():
     _reset_auto_parallel_context()
 
 
+def _check_target_specific_cfgs(device, arg_key):
+    """Check whether a config is suitable for the specified device."""
+    device_cfgs = {
+        'enable_auto_mixed_precision': ['Ascend'],
+        'enable_dump': ['Ascend'],
+        'enable_profiling': ['Ascend'],
+        'variable_memory_max_size': ['Ascend'],
+        'max_device_memory': ['GPU']
+    }
+    # configs not in map device_cfgs are supposed to be suitable for all devices
+    if arg_key not in device_cfgs:
+        return True
+    supported_devices = device_cfgs[arg_key]
+    if device in supported_devices:
+        return True
+    logger.warning(f"Config '{arg_key}' only supports devices in {supported_devices}, current device is '{device}'"
+                   ", ignoring it.")
+    return False
+
+
 @args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool,
                  save_graphs_path=str, enable_dump=bool,
                  save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
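
The guard's behavior, as a doctest-style sketch (outcomes follow directly from the device_cfgs mapping above):

>>> _check_target_specific_cfgs('Ascend', 'mode')               # not device specific
True
>>> _check_target_specific_cfgs('GPU', 'max_device_memory')     # GPU-only key on GPU
True
>>> _check_target_specific_cfgs('Ascend', 'max_device_memory')  # warns, then rejects
False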
@@ -450,6 +472,26 @@ def set_context(**kwargs):
         The mode is not recommended to be changed after net was initialized because the implementations of some
         operations are different in graph mode and pynative mode. Default: PYNATIVE_MODE.
 
+        Some configurations are device specific, see the table below for details:
+
+        =========================== =========================== =================
+        Common(CPU/GPU/Ascend)      Ascend                      GPU
+        =========================== =========================== =================
+        check_bprop                 enable_auto_mixed_precision max_device_memory
+        device_id                   enable_dump
+        device_target               enable_profiling
+        enable_graph_kernel         variable_memory_max_size
+        enable_reduce_precision
+        enable_sparse
+        mode
+        print_file_path
+        profiling_options
+        reserve_class_name_in_scope
+        save_dump_path
+        save_graphs
+        save_graphs_path
+        =========================== =========================== =================
+
     Args:
         mode (int): Running in GRAPH_MODE(0) or PYNATIVE_MODE(1).
         device_target (str): The target device to run, support "Ascend", "GPU", "CPU". Default: "Ascend".
@@ -513,14 +555,21 @@ def set_context(**kwargs):
         >>> context.set_context(max_call_depth=80)
     """
     ctx = _context()
+    # set device target first
+    if 'device_target' in kwargs:
+        ctx.set_device_target(kwargs['device_target'])
+    device = ctx.get_param(ms_ctx_param.device_target)
     for key, value in kwargs.items():
+        if not _check_target_specific_cfgs(device, key):
+            continue
         if hasattr(ctx, key):
             setattr(ctx, key, value)
             continue
         if key in ctx.setters:
             ctx.setters[key](ctx, value)
             continue
-        if key in ms_ctx_param.__members__:
+        # enum variables beginning with '_' are for internal use
+        if key in ms_ctx_param.__members__ and key[0] != '_':
             ctx.set_param(ms_ctx_param.__members__[key], value)
             continue
         raise ValueError("Set context keyword %s is not recognized!" % key)
@@ -540,9 +589,12 @@ def get_context(attr_key):
         ValueError: If input key is not an attribute in context.
     """
     ctx = _context()
+    device = ctx.get_param(ms_ctx_param.device_target)
+    _ = _check_target_specific_cfgs(device, attr_key)
     if hasattr(ctx, attr_key):
         return getattr(ctx, attr_key)
-    if attr_key in ms_ctx_param.__members__:
+    # enum variables beginning with '_' are for internal use
+    if attr_key in ms_ctx_param.__members__ and attr_key[0] != '_':
         return ctx.get_param(ms_ctx_param.__members__[attr_key])
     raise ValueError("Get context keyword %s is not recognized!" % attr_key)



