Browse Source

add enable_auto_mixed_precision

tags/v0.3.0-alpha
jinyaohui 5 years ago
parent
commit
c0c694d99a
3 changed files with 19 additions and 1 deletion
  1. +4
    -0
      mindspore/ccsrc/pipeline/init.cc
  2. +5
    -0
      mindspore/ccsrc/utils/context/ms_context.h
  3. +10
    -1
      mindspore/context.py

+ 4
- 0
mindspore/ccsrc/pipeline/init.cc View File

@@ -117,6 +117,10 @@ PYBIND11_MODULE(_c_expression, m) {
.def("close_tsd", &mindspore::MsContext::CloseTsd, "Close tdt dataset client.")
.def("get_save_graphs_flag", &mindspore::MsContext::save_graphs_flag, "Get whether to save graphs.")
.def("set_save_graphs_flag", &mindspore::MsContext::set_save_graphs_flag, "Set whether to save graphs.")
.def("get_auto_mixed_precision_flag", &mindspore::MsContext::auto_mixed_precision_flag,
"Get whether to enable auto mixed precision.")
.def("set_auto_mixed_precision_flag", &mindspore::MsContext::set_auto_mixed_precision_flag,
"Set whether to enable auto mixed precision.")
.def("get_enable_reduce_precision_flag", &mindspore::MsContext::enable_reduce_precision,
"Get whether to enable reduce precision.")
.def("set_enable_reduce_precision_flag", &mindspore::MsContext::set_enable_reduce_precision,


+ 5
- 0
mindspore/ccsrc/utils/context/ms_context.h View File

@@ -105,6 +105,11 @@ class MsContext {
// Accessors for the enable_gpu_summary_ context flag.
// NOTE(review): exact semantics of "GPU summary" are not visible in this
// hunk — presumably controls summary collection on the GPU backend; confirm
// against the full ms_context.h.
void set_enable_gpu_summary(bool enable_gpu_summary) { enable_gpu_summary_ = enable_gpu_summary; }
bool enable_gpu_summary() const { return enable_gpu_summary_; }

// Accessors for the auto_mixed_precision_flag_ context flag added by this
// commit. Exposed to Python through the _c_expression bindings as
// get_auto_mixed_precision_flag / set_auto_mixed_precision_flag
// (see the pipeline/init.cc hunk in this same commit).
void set_auto_mixed_precision_flag(bool auto_mixed_precision_flag) {
auto_mixed_precision_flag_ = auto_mixed_precision_flag;
}
bool auto_mixed_precision_flag() const { return auto_mixed_precision_flag_; }

// Accessors for the enable_reduce_precision_ context flag. Exposed to Python
// as get_enable_reduce_precision_flag / set_enable_reduce_precision_flag
// (see the pipeline/init.cc bindings in this commit).
void set_enable_reduce_precision(bool flag) { enable_reduce_precision_ = flag; }
bool enable_reduce_precision() const { return enable_reduce_precision_; }



+ 10
- 1
mindspore/context.py View File

@@ -233,6 +233,14 @@ class _Context:
# Setter body for the save_ms_model_path property: forwards the path string
# to the C++ context handle. NOTE(review): the @save_ms_model_path.setter
# decorator sits above this diff hunk and is not visible here — confirm
# against the full context.py.
def save_ms_model_path(self, save_ms_model_path):
    self._context_handle.set_save_ms_model_path(save_ms_model_path)

@property
def enable_auto_mixed_precision(self):
    """bool: whether automatic mixed precision is enabled.

    Delegates to the underlying C++ context handle
    (``get_auto_mixed_precision_flag``).
    """
    flag = self._context_handle.get_auto_mixed_precision_flag()
    return flag

@enable_auto_mixed_precision.setter
def enable_auto_mixed_precision(self, value):
    """Forward the auto-mixed-precision flag to the C++ context handle."""
    self._context_handle.set_auto_mixed_precision_flag(value)

@property
def enable_reduce_precision(self):
    """bool: whether reduce precision is enabled, as reported by the
    underlying C++ context handle (``get_enable_reduce_precision_flag``)."""
    flag = self._context_handle.get_enable_reduce_precision_flag()
    return flag
@@ -441,7 +449,7 @@ def reset_auto_parallel_context():
@args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool,
save_graphs_path=str, save_ms_model=bool, save_ms_model_path=str, enable_dump=bool,
save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
enable_profiling=bool, profiling_options=str)
enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool)
def set_context(**kwargs):
"""
Sets context for running environment.
@@ -469,6 +477,7 @@ def set_context(**kwargs):
save_ms_model (bool): Whether to save lite model converted by graph. Default: False.
save_ms_model_path (str): Path to save converted lite model. Default: "."
save_graphs_path (str): Path to save graphs. Default: "."
enable_auto_mixed_precision (bool): Whether to enable auto mixed precision. Default: True.
reserve_class_name_in_scope (bool) : Whether to save the network class name in the scope. Default: True.
enable_reduce_precision (bool): Whether to enable precision reduction. Default: True.
enable_dump (bool): Whether to enable dump. Default: False.


Loading…
Cancel
Save