From: @dayschan Reviewed-by: @gaoxiong1 Signed-off-by:pull/13720/MERGE
| @@ -21,6 +21,7 @@ | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "utils/context/graph_kernel_flags.h" | |||
| #include "backend/kernel_compiler/akg/akg_kernel_json_generator.h" | |||
| #include "backend/kernel_compiler/common_utils.h" | |||
| #include "backend/kernel_compiler/kernel_build_info.h" | |||
| @@ -67,6 +68,29 @@ std::unordered_set<PrimitivePtr> GetExpandOps() { | |||
| prim::kPrimAssignAdd, | |||
| #endif | |||
| }; | |||
| auto new_prim = [](const std::string &name) { return std::make_shared<Primitive>(name); }; | |||
| auto &flags = context::GraphKernelFlags::GetInstance(); | |||
| auto &enable_ops_only = flags.enable_expand_ops_only; | |||
| if (!enable_ops_only.empty()) { | |||
| expand_ops.clear(); | |||
| std::transform(enable_ops_only.begin(), enable_ops_only.end(), std::inserter(expand_ops, expand_ops.end()), | |||
| new_prim); | |||
| } else { | |||
| auto &enable_ops = flags.enable_expand_ops; | |||
| auto &disable_ops = flags.disable_expand_ops; | |||
| if (!enable_ops.empty()) { | |||
| std::transform(enable_ops.begin(), enable_ops.end(), std::inserter(expand_ops, expand_ops.end()), new_prim); | |||
| } | |||
| if (!disable_ops.empty()) { | |||
| for (auto iter = expand_ops.begin(); iter != expand_ops.end();) { | |||
| if (std::find(disable_ops.begin(), disable_ops.end(), (*iter)->name()) != disable_ops.end()) { | |||
| expand_ops.erase(iter++); | |||
| } else { | |||
| ++iter; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return expand_ops; | |||
| } | |||
| } // namespace | |||
| @@ -17,6 +17,7 @@ | |||
| #include "backend/optimizer/mem_reuse/mem_reuse.h" | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include "utils/context/graph_kernel_flags.h" | |||
| #include "backend/optimizer/mem_reuse/mem_reuse_checker.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| @@ -462,9 +463,7 @@ void MemReuseUtil::SetAllInfo(const KernelGraph *graph) { | |||
| MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph); | |||
| #endif | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| enable_visit_kernel_cache_ = context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL); | |||
| enable_visit_kernel_cache_ = context::GraphKernelFlags::GetInstance().IsEnableGraphKernel(); | |||
| } | |||
| uint8_t *MemReuseUtil::GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const { | |||
| @@ -46,6 +46,7 @@ | |||
| #include "runtime/device/ascend/ascend_stream_assign.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "utils/ms_utils.h" | |||
| #include "utils/context/graph_kernel_flags.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| #include "runtime/device/kernel_runtime_manager.h" | |||
| #include "utils/config_manager.h" | |||
| @@ -846,9 +847,7 @@ void AscendSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_ | |||
| } | |||
| void AscendSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) { | |||
| if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) { | |||
| return; | |||
| } | |||
| opt::GraphKernelOptimize(kernel_graph); | |||
| @@ -69,6 +69,7 @@ | |||
| #include "utils/ms_utils.h" | |||
| #include "utils/config_manager.h" | |||
| #include "utils/ms_context.h" | |||
| #include "utils/context/graph_kernel_flags.h" | |||
| #include "utils/utils.h" | |||
| #if ENABLE_CPU && ENABLE_GPU | |||
| #include "ps/util.h" | |||
| @@ -127,8 +128,6 @@ void GPUSession::StartKernelRT() const { | |||
| void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) { | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>()); | |||
| @@ -136,7 +135,7 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) { | |||
| pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>()); | |||
| pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>()); | |||
| pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayFusion>()); | |||
| if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) { | |||
| if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) { | |||
| pm->AddPass(std::make_shared<opt::CastAllFusion>("cast_all")); | |||
| } | |||
| pm->AddPass(std::make_shared<opt::CombineMomentumFusion>("combine_momentum")); | |||
| @@ -181,9 +180,7 @@ void GPUSession::RunOpHardwareOptimize(const std::shared_ptr<KernelGraph> &kerne | |||
| } | |||
| void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) { | |||
| if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) { | |||
| return; | |||
| } | |||
| opt::GraphKernelOptimize(kernel_graph); | |||
| @@ -40,6 +40,7 @@ | |||
| #include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h" | |||
| #include "frontend/optimizer/recompute.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/context/graph_kernel_flags.h" | |||
| #include "pipeline/jit/pipeline_split.h" | |||
| #include "pipeline/jit/static_analysis/auto_monad.h" | |||
| #include "frontend/optimizer/irpass/gradient_eliminate.h" | |||
| @@ -354,9 +355,7 @@ void InitOpt(const ResourcePtr &res) { | |||
| g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass)); | |||
| g_pass_opts["opt_after_recompute"] = | |||
| Optimizer::MakeOptimizer("opt_after_recompute", res, GetAfterRecomputePass(irpass)); | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) { | |||
| if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) { | |||
| g_pass_opts["opt_graph_kernel_a"]->set_enable(false); | |||
| g_pass_opts["opt_graph_kernel_b"]->set_enable(false); | |||
| } | |||
| @@ -97,6 +97,7 @@ REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) { | |||
| .value("tune_mode", MsCtxParam::MS_CTX_TUNE_MODE) | |||
| .value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH) | |||
| .value("env_config_path", MsCtxParam::MS_CTX_ENV_CONFIG_PATH) | |||
| .value("graph_kernel_flags", MsCtxParam::MS_CTX_GRAPH_KERNEL_FLAGS) | |||
| .value("grad_for_scalar", MsCtxParam::MS_CTX_GRAD_FOR_SCALAR); | |||
| (void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(*m, "MSContext") | |||
| .def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.") | |||
| @@ -24,6 +24,7 @@ | |||
| #include "runtime/device/gpu/distribution/collective_init.h" | |||
| #include "utils/convert_utils.h" | |||
| #include "utils/ms_context.h" | |||
| #include "utils/context/graph_kernel_flags.h" | |||
| #include "runtime/device/kernel_runtime_manager.h" | |||
| #include "runtime/device/gpu/gpu_common.h" | |||
| #include "utils/ms_utils.h" | |||
| @@ -66,9 +67,7 @@ bool GPUKernelRuntime::SyncStream() { | |||
| } | |||
| bool GPUKernelRuntime::Init() { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| enable_relation_cache_ = context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL); | |||
| enable_relation_cache_ = context::GraphKernelFlags::GetInstance().IsEnableGraphKernel(); | |||
| if (device_init_ == true) { | |||
| GPUMemoryAllocator::GetInstance().CheckMaxDeviceMemory(); | |||
| @@ -0,0 +1,196 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "utils/context/graph_kernel_flags.h" | |||
| #include <map> | |||
| #include <string> | |||
| #include <cstring> | |||
| #include <vector> | |||
| #include <utility> | |||
| #include "nlohmann/json.hpp" | |||
| #include "utils/ms_context.h" | |||
| namespace mindspore { | |||
| namespace context { | |||
| namespace { | |||
// Split string to tokens.
// Any character in `delim` acts as a separator; consecutive separators are
// collapsed and leading/trailing separators are skipped (strtok semantics),
// so no empty token is ever produced.
std::vector<std::string> GetTokens(const std::string &str, const std::string &delim) {
  std::vector<std::string> tokens;
  // Pure std::string scanning: strtok_r is POSIX-only (MSVC has strtok_s) and
  // required a writable copy of the input.
  size_t pos = str.find_first_not_of(delim);
  while (pos != std::string::npos) {
    const size_t end = str.find_first_of(delim, pos);
    if (end == std::string::npos) {
      tokens.emplace_back(str.substr(pos));
      break;
    }
    tokens.emplace_back(str.substr(pos, end - pos));
    pos = str.find_first_not_of(delim, end);
  }
  return tokens;
}
// Parse flag string to key-value pair.
// Flag format: "--key=value"; a bool flag's value can be implicit, the "--key" means "--key=true".
// Returns an empty pair for invalid input: no "--" prefix, empty key ("--=..."),
// empty value ("--key="), or a second '=' in the string.
std::pair<std::string, std::string> ParseFlag(const std::string &flag) {
  // check the string starts with "--" and has a non-empty body after it.
  if (flag.rfind("--", 0) != 0 || flag.size() <= 2) {
    return std::pair<std::string, std::string>();
  }
  const std::string body = flag.substr(2);
  const auto eq_pos = body.find('=');
  if (eq_pos == std::string::npos) {
    // no value, treated as bool flag.
    return std::make_pair(body, "");
  }
  if (eq_pos == 0) {
    // the key should not be empty, "--=value" is invalid.
    // (The previous search started past the first key character, so keys
    // beginning with '=' slipped through.)
    return std::pair<std::string, std::string>();
  }
  const std::string value = body.substr(eq_pos + 1);
  if (value.empty() || value.find('=') != std::string::npos) {
    // empty value, or a string with two "=", is invalid.
    return std::pair<std::string, std::string>();
  }
  // normal "--key=value" format
  return std::make_pair(body.substr(0, eq_pos), value);
}
| std::map<std::string, std::string> ParseFlags(const std::string &flags) { | |||
| std::map<std::string, std::string> flag_map; | |||
| auto tokens = GetTokens(flags, " "); | |||
| for (const auto &token : tokens) { | |||
| auto flag = ParseFlag(token); | |||
| if (flag.first != "") { | |||
| if (!flag_map.insert(flag).second) { | |||
| MS_LOG(WARNING) << "Repeated GraphKernel flag: " << flag.first; | |||
| } | |||
| } else { | |||
| MS_LOG(WARNING) << "Invalid GraphKernel flag: " << token; | |||
| } | |||
| } | |||
| return flag_map; | |||
| } | |||
| class FlagRegister { | |||
| public: | |||
| explicit FlagRegister(std::map<std::string, std::string> *flag_map) : flag_map_(*flag_map) {} | |||
| ~FlagRegister() = default; | |||
| template <typename T> | |||
| void AddFlag(std::string flag_name, T *flag_var) { | |||
| auto iter = flag_map_.find(flag_name); | |||
| if (iter != flag_map_.end()) { | |||
| T var; | |||
| bool ret = ParseValue(iter->second, &var); | |||
| if (ret) { | |||
| *flag_var = std::move(var); | |||
| } else { | |||
| if (iter->second.empty()) { | |||
| MS_LOG(WARNING) << "Invalid GraphKernel flag: --" << iter->first; | |||
| } else { | |||
| MS_LOG(WARNING) << "Invalid GraphKernel flag: --" << iter->first << "=" << iter->second; | |||
| } | |||
| } | |||
| flag_map_.erase(iter); | |||
| } | |||
| } | |||
| private: | |||
| bool ParseValue(const std::string &s, std::vector<std::string> *result) { | |||
| *result = GetTokens(s, ","); | |||
| return !result->empty(); | |||
| } | |||
| bool ParseValue(const std::string &s, bool *result) { | |||
| *result = (s.empty() || s == "true" || s == "on" || s == "1"); | |||
| return *result || s == "false" || s == "off" || s == "0"; | |||
| } | |||
| template <typename T> | |||
| bool ParseValue(const std::string &s, T *result) { | |||
| if (s.empty()) { | |||
| return false; | |||
| } | |||
| std::istringstream iss(s); | |||
| iss >> (*result); | |||
| return iss.eof(); | |||
| } | |||
| template <typename T> | |||
| bool ParseValue(const std::string &s, std::vector<T> *result) { | |||
| result->clear(); | |||
| auto tokens = GetTokens(s, ","); | |||
| if (tokens.empty()) { | |||
| return false; | |||
| } | |||
| for (const auto &tok : tokens) { | |||
| T temp; | |||
| if (!ParseValue(tok, &temp)) { | |||
| result->clear(); | |||
| return false; | |||
| } | |||
| result->emplace_back(temp); | |||
| } | |||
| return true; | |||
| } | |||
| std::map<std::string, std::string> &flag_map_; | |||
| }; | |||
| } // namespace | |||
| void GraphKernelFlags::Refresh() { | |||
| auto flag_map = ParseFlags(flags_cache_); | |||
| RegisterFlags(&flag_map); | |||
| for (auto &item : flag_map) { | |||
| MS_LOG(WARNING) << "Unknown GraphKernel flag: " << item.first; | |||
| } | |||
| } | |||
// Register all known flags; each matched entry is consumed from `flag_map`
// so the caller can report whatever remains as unknown.
void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_map) {
  FlagRegister reg(flag_map);
  // Bool flag.
  reg.AddFlag("dump_as_text", &dump_as_text);
  // Unsigned integer flags.
  reg.AddFlag("opt_level", &opt_level);
  reg.AddFlag("auto_tune", &auto_tune);
  reg.AddFlag("cluster_limit", &cluster_limit);
  // String-list flags (comma-separated values).
  reg.AddFlag("enable_expand_ops", &enable_expand_ops);
  reg.AddFlag("enable_expand_ops_only", &enable_expand_ops_only);
  reg.AddFlag("disable_expand_ops", &disable_expand_ops);
  reg.AddFlag("enable_cluster_ops", &enable_cluster_ops);
  reg.AddFlag("enable_cluster_ops_only", &enable_cluster_ops_only);
  reg.AddFlag("disable_cluster_ops", &disable_cluster_ops);
  reg.AddFlag("enable_pass_only", &enable_pass_only);
  reg.AddFlag("disable_pass", &disable_pass);
}
| std::string GraphKernelFlags::DumpAllFlags() const { | |||
| nlohmann::json json; | |||
| json["dump_as_text"] = dump_as_text; | |||
| json["opt_level"] = opt_level; | |||
| json["auto_tune"] = auto_tune; | |||
| json["cluster_limit"] = cluster_limit; | |||
| json["enable_expand_ops"] = enable_expand_ops; | |||
| json["enable_expand_ops_only"] = enable_expand_ops_only; | |||
| json["disable_expand_ops"] = disable_expand_ops; | |||
| json["enable_cluster_ops"] = enable_cluster_ops; | |||
| json["enable_cluster_ops_only"] = enable_cluster_ops_only; | |||
| json["disable_cluster_ops"] = disable_cluster_ops; | |||
| json["enable_pass_only"] = enable_pass_only; | |||
| json["disable_pass"] = disable_pass; | |||
| return json.dump(); | |||
| } | |||
| } // namespace context | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,148 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_UTILS_GRAPH_KERNEL_FLAGS_H | |||
| #define MINDSPORE_CCSRC_UTILS_GRAPH_KERNEL_FLAGS_H | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <utility> | |||
| #include "utils/ms_context.h" | |||
| namespace mindspore { | |||
| namespace context { | |||
| class GraphKernelFlags { | |||
| public: | |||
| static const GraphKernelFlags &GetInstance() { | |||
| static std::unique_ptr<GraphKernelFlags> flags(nullptr); | |||
| auto contexts = GetGraphKernelContext(); | |||
| if (flags == nullptr || contexts.first != flags->flags_cache_ || contexts.second != flags->enable_cache_) { | |||
| flags.reset(new GraphKernelFlags(contexts.first, contexts.second)); | |||
| flags->Refresh(); | |||
| } | |||
| return *flags; | |||
| } | |||
| // Dump all flags to json-format string | |||
| std::string DumpAllFlags() const; | |||
| // Check whether graph_kernel is enabled | |||
| bool IsEnableGraphKernel() const { return opt_level > 0; } | |||
| GraphKernelFlags(const GraphKernelFlags &flags) = delete; | |||
| ~GraphKernelFlags() = default; | |||
| public: | |||
| /** | |||
| * dump_as_text, unsupported now. | |||
| */ | |||
| bool dump_as_text{false}; | |||
| /** | |||
| * opt_level, value from 0 to 3. | |||
| * 0: GraphKernel disabled | |||
| * 1: GraphKernel enabled | |||
| * 2 and 3 are not supported now. | |||
| * the default value is controlled by context `enable_graph_kernel`, | |||
| * but if it's also set in `graph_kernel_flags`, then the flag will prevail. | |||
| */ | |||
| unsigned int opt_level{0}; | |||
| /** | |||
| * auto_tune, unsupported now. | |||
| */ | |||
| unsigned int auto_tune{0}; | |||
| /** | |||
| * cluster_limit, unsupported now. | |||
| */ | |||
| unsigned int cluster_limit{30}; | |||
| /** | |||
| * Additional expanding operators (case sensitive). | |||
| * The operators to be added into the default expanding operator list. | |||
| */ | |||
| std::vector<std::string> enable_expand_ops; | |||
| /** | |||
| * Expanding operators to be enabled (case sensitive). | |||
| * Unlike the "enable_expand_ops", the default list will be overwritten by this list. | |||
| * Note that the "enable_expand_ops" and "disable_expand_ops" will be ignored if this flag is set. | |||
| */ | |||
| std::vector<std::string> enable_expand_ops_only; | |||
| /** | |||
| * Expanding operators to be disabled (case sensitive). | |||
| * The behavior is undefined when this list overlaps with "enable_expand_ops". | |||
| */ | |||
| std::vector<std::string> disable_expand_ops; | |||
| /** | |||
| * enable_cluster_ops, unsupported now. | |||
| */ | |||
| std::vector<std::string> enable_cluster_ops; | |||
| /** | |||
| * enable_cluster_ops_only, unsupported now. | |||
| */ | |||
| std::vector<std::string> enable_cluster_ops_only; | |||
| /** | |||
| * disable_cluster_ops, unsupported now. | |||
| */ | |||
| std::vector<std::string> disable_cluster_ops; | |||
| /** | |||
| * enable_pass_only, unsupported now. | |||
| */ | |||
| std::vector<std::string> enable_pass_only; | |||
| /** | |||
| * disable_pass, unsupported now. | |||
| */ | |||
| std::vector<std::string> disable_pass; | |||
| private: | |||
| GraphKernelFlags(const std::string &graph_kernel_flags, bool enable_graph_kernel) | |||
| : flags_cache_(graph_kernel_flags), enable_cache_(enable_graph_kernel) { | |||
| opt_level = enable_graph_kernel ? 1 : 0; | |||
| } | |||
| // get the `graph_kernel_flags` and `enable_graph_kernel` | |||
| static std::pair<std::string, bool> GetGraphKernelContext() { | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| // Use the environment variable in priority | |||
| auto env_flags = std::getenv("MS_GRAPH_KERNEL_FLAGS"); | |||
| std::string flags = env_flags ? std::string(env_flags) : context->get_param<std::string>(MS_CTX_GRAPH_KERNEL_FLAGS); | |||
| return std::make_pair(flags, context->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL)); | |||
| } | |||
| // parse and refresh the flags | |||
| void Refresh(); | |||
| // register the flags defined above | |||
| void RegisterFlags(std::map<std::string, std::string> *flag_map); | |||
| // cache the flag string to check whether the flags is changed. | |||
| std::string flags_cache_; | |||
| // cache the enable_graph_kernel value to check whether the context is changed. | |||
| bool enable_cache_; | |||
| }; | |||
| } // namespace context | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_UTILS_GRAPH_KERNEL_FLAGS_H | |||
| @@ -489,6 +489,7 @@ def _check_target_specific_cfgs(device, arg_key): | |||
| 'enable_dump': ['Ascend'], | |||
| 'save_dump_path': ['Ascend'], | |||
| 'enable_graph_kernel': ['Ascend', 'GPU'], | |||
| 'graph_kernel_flags': ['Ascend', 'GPU'], | |||
| 'enable_reduce_precision': ['Ascend'], | |||
| 'enable_profiling': ['Ascend'], | |||
| 'profiling_options': ['Ascend'], | |||
| @@ -513,7 +514,7 @@ def _check_target_specific_cfgs(device, arg_key): | |||
| save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str, | |||
| enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool, | |||
| enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str, | |||
| enable_sparse=bool, max_call_depth=int, env_config_path=str) | |||
| enable_sparse=bool, max_call_depth=int, env_config_path=str, graph_kernel_flags=str) | |||
| def set_context(**kwargs): | |||
| """ | |||
| Set context for running environment. | |||
| @@ -540,14 +541,14 @@ def set_context(**kwargs): | |||
| =========================== =========================== ================= | |||
| check_bprop print_file_path max_device_memory | |||
| device_id enable_dump enable_graph_kernel | |||
| device_target save_dump_path | |||
| device_target save_dump_path graph_kernel_flags | |||
| enable_sparse enable_graph_kernel | |||
| max_call_depth enable_reduce_precision | |||
| mode enable_profiling | |||
| reserve_class_name_in_scope profiling_options | |||
| save_graphs variable_memory_max_size | |||
| save_graphs_path auto_tune_mode | |||
| env_config_path | |||
| env_config_path | |||
| grad_for_scalar | |||
| =========================== =========================== ================= | |||
| @@ -566,6 +567,7 @@ def set_context(**kwargs): | |||
| `context.set_context(save_graphs_path="path/to/ir/files"+device_id)`. | |||
| enable_graph_kernel (bool): Whether to enable composition of basic primitives. These primitives would be | |||
| compiled into a fused kernel automatically. Default: False. | |||
| graph_kernel_flags (str): Optimization options of graph kernel, in the format of a string of "--key=value" pairs separated by spaces, e.g. "--opt_level=1 --dump_as_text". If an option is controlled both by this flag and its own context switch (e.g. enable_graph_kernel), the flag prevails. | |||
| reserve_class_name_in_scope (bool) : Whether to save the network class name in the scope. Default: True. | |||
| enable_reduce_precision (bool): Whether to enable precision reduction. Default: True. | |||
| enable_dump (bool): Whether to enable dump. Default: False. | |||
| @@ -39,6 +39,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) { | |||
| set_param<std::string>(MS_CTX_SAVE_DUMP_PATH, "."); | |||
| set_param<std::string>(MS_CTX_ENV_CONFIG_PATH, ""); | |||
| set_param<std::string>(MS_CTX_TUNE_MODE, "NO_TUNE"); | |||
| set_param<std::string>(MS_CTX_GRAPH_KERNEL_FLAGS, ""); | |||
| set_param<uint32_t>(MS_CTX_TSD_REF, 0); | |||
| set_param<uint32_t>(MS_CTX_GE_REF, 0); | |||
| @@ -112,6 +112,7 @@ enum MsCtxParam : unsigned { | |||
| MS_CTX_PYTHON_EXE_PATH, | |||
| MS_CTX_ENV_CONFIG_PATH, | |||
| MS_CTX_TUNE_MODE, | |||
| MS_CTX_GRAPH_KERNEL_FLAGS, | |||
| MS_CTX_TYPE_STRING_END, | |||
| // parameter numbers of each type | |||