Browse Source

unify backend_policy using env variable

feature/build-system-rewrite
xiao_yao1994 4 years ago
parent
commit
521a4f8fb6
3 changed files with 20 additions and 13 deletions
  1. +7
    -1
      mindspore/ccsrc/utils/context/context_extends.cc
  2. +12
    -4
      mindspore/core/utils/ms_context.cc
  3. +1
    -8
      mindspore/python/mindspore/context.py

+ 7
- 1
mindspore/ccsrc/utils/context/context_extends.cc View File

@@ -15,6 +15,7 @@
*/

#include "utils/context/context_extends.h"
#include <cstdlib>
#include <map>
#include <string>
#include <memory>
@@ -377,7 +378,12 @@ struct DeviceTypeSetRegister {
DeviceTypeSetRegister() {
MsContext::device_type_seter([](std::shared_ptr<MsContext> &device_type_seter) {
#if defined(ENABLE_D)
device_type_seter.reset(new (std::nothrow) MsContext("ms", kAscendDevice));
auto enable_ge = std::getenv("MS_ENABLE_GE");
if (enable_ge != nullptr && std::string(enable_ge) == "1") {
device_type_seter.reset(new (std::nothrow) MsContext("ge", kAscendDevice));
} else {
device_type_seter.reset(new (std::nothrow) MsContext("ms", kAscendDevice));
}
#elif defined(ENABLE_GPU)
device_type_seter.reset(new (std::nothrow) MsContext("ms", kGPUDevice));
#else


+ 12
- 4
mindspore/core/utils/ms_context.cc View File

@@ -15,6 +15,7 @@
*/

#include "utils/ms_context.h"
#include <cstdlib>
#include <thread>
#include <atomic>
#include <fstream>
@@ -117,12 +118,19 @@ std::shared_ptr<MsContext> MsContext::GetInstance() {
}

bool MsContext::set_backend_policy(const std::string &policy) {
if (policy_map_.find(policy) == policy_map_.end()) {
MS_LOG(ERROR) << "invalid backend policy name: " << policy;
auto policy_new = policy;
#if defined(ENABLE_D)
auto enable_ge = std::getenv("MS_ENABLE_GE");
if (enable_ge != nullptr && std::string(enable_ge) == "1") {
policy_new = "ge";
}
#endif
if (policy_map_.find(policy_new) == policy_map_.end()) {
MS_LOG(ERROR) << "invalid backend policy name: " << policy_new;
return false;
}
backend_policy_ = policy_map_[policy];
MS_LOG(INFO) << "ms set context backend policy:" << policy;
backend_policy_ = policy_map_[policy_new];
MS_LOG(INFO) << "ms set context backend policy:" << policy_new;
return true;
}



+ 1
- 8
mindspore/python/mindspore/context.py View File

@@ -337,7 +337,6 @@ class _Context:
'mempool_block_size': set_mempool_block_size,
'print_file_path': set_print_file_path,
'env_config_path': set_env_config_path,
'backend_policy': set_backend_policy,
'runtime_num_threads': set_runtime_num_threads
}

@@ -615,8 +614,7 @@ def _check_target_specific_cfgs(device, arg_key):
enable_graph_kernel=bool, reserve_class_name_in_scope=bool, check_bprop=bool,
max_device_memory=str, print_file_path=str, enable_sparse=bool, max_call_depth=int,
env_config_path=str, graph_kernel_flags=str, save_compile_cache=bool, runtime_num_threads=int,
load_compile_cache=bool, grad_for_scalar=bool, pynative_synchronize=bool, mempool_block_size=str,
backend_policy=str)
load_compile_cache=bool, grad_for_scalar=bool, pynative_synchronize=bool, mempool_block_size=str)
def set_context(**kwargs):
"""
Set context for running environment.
@@ -689,8 +687,6 @@ def set_context(**kwargs):
| | runtime_num_threads | CPU/GPU/Ascend |
| +------------------------------+----------------------------+
| | compile_cache_path | CPU/GPU/Ascend |
| +------------------------------+----------------------------+
| | backend_policy | Ascend |
+-------------------------+------------------------------+----------------------------+

Args:
@@ -828,9 +824,6 @@ def set_context(**kwargs):
If the specified directory does not exist, the system will automatically create the directory.
The cache will be saved to the directory of `compile_cache_path/rank_${rank_id}/`. The `rank_id` is
the ID of the current device in the cluster.
backend_policy (str): Used to choose a backend. ("ge", "vm" or "ms").
Through context.set_context(backend_policy="ms")
Default: The value must be in ['ge', 'vm', 'ms'].
runtime_num_threads(int): The thread pool number of cpu kernel and actor used in runtime,
which must bigger than 0.
Raises:


Loading…
Cancel
Save