| @@ -41,8 +41,7 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_LOG(INFO) << "Init embedding lookup proxy kernel, input shape:" << input_shape | |||
| << ", indices_shape:" << indices_shape << ", output_shape:" << output_shape; | |||
| std::vector<int> lens{SizeToInt(input_shape.size()), SizeToInt(indices_shape.size()), SizeToInt(output_shape.size())}; | |||
| const char *env_role = getenv(mindspore::parallel::ps::kEnvRole); | |||
| if (env_role != nullptr && strcmp(env_role, mindspore::parallel::ps::kEnvRoleOfWorker) == 0) { | |||
| if (mindspore::parallel::ps::Util::IsRoleOfWorker()) { | |||
| parallel::ps::Worker<float>::GetInstance().AddEmbeddingTable(key_, input_shape[axis]); | |||
| parallel::ps::Worker<float>::GetInstance().InitPSEmbeddingTable(keys, values, lens); | |||
| } | |||
| @@ -32,11 +32,6 @@ constexpr char kEnvWorkerNum[] = "MS_WORKER_NUM"; | |||
| constexpr char kEnvSchedulerHost[] = "MS_SCHED_HOST"; | |||
| constexpr char kEnvSchedulerPort[] = "MS_SCHED_PORT"; | |||
| constexpr char kEnvRole[] = "MS_ROLE"; | |||
| constexpr char kEnvRoleOfPServer[] = "MS_PSERVER"; | |||
| constexpr char kEnvRoleOfWorker[] = "MS_WORKER"; | |||
| constexpr char kEnvRoleOfScheduler[] = "MS_SCHED"; | |||
| constexpr char kDmlcCommType[] = "DMLC_PS_VAN_TYPE"; | |||
| constexpr char kDmlcInterface[] = "DMLC_INTERFACE"; | |||
| constexpr char kDmlcPServerNum[] = "DMLC_NUM_SERVER"; | |||
| @@ -39,6 +39,7 @@ | |||
| #include "frontend/parallel/ps/optimizer_info.h" | |||
| #include "frontend/parallel/ps/optimizer_info_builder.h" | |||
| #include "frontend/parallel/ps/util.h" | |||
| #include "frontend/parallel/ps/ps_context.h" | |||
| #include "runtime/device/cpu/kernel_select_cpu.h" | |||
| #include "utils/ms_context.h" | |||
| #include "backend/kernel_compiler/kernel.h" | |||
| @@ -741,7 +742,7 @@ void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) { | |||
| return; | |||
| } | |||
| Init(func_graph); | |||
| Util::SetRankId(rank_id_); | |||
| PSContext::instance()->SetPSRankId(rank_id_); | |||
| thread_->join(); | |||
| ::ps::Finalize(0, true); | |||
| } | |||
| @@ -0,0 +1,86 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "frontend/parallel/ps/ps_context.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/ms_utils.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| namespace ps { | |||
| std::shared_ptr<PSContext> PSContext::instance() { | |||
| // Thread-safe: C++11 guarantees a function-local static is initialized exactly once. | |||
| static std::shared_ptr<PSContext> ps_instance(new (std::nothrow) PSContext()); | |||
| return ps_instance; | |||
| } | |||
| void PSContext::SetPSEnable(bool enabled) { | |||
| ps_enabled_ = enabled; | |||
| if (ps_enabled_) { | |||
| std::string ms_role = common::GetEnv(kEnvRole); | |||
| MS_LOG(INFO) << "PS mode is enabled. MS_ROLE is " << ms_role; | |||
| if (ms_role == kEnvRoleOfWorker) { | |||
| is_worker_ = true; | |||
| } else if (ms_role == kEnvRoleOfPServer) { | |||
| is_pserver_ = true; | |||
| } else if (ms_role == kEnvRoleOfScheduler) { | |||
| is_sched_ = true; | |||
| } else { | |||
| MS_LOG(WARNING) << "MS_ROLE is " << ms_role << ", which is invalid."; | |||
| } | |||
| } else { | |||
| MS_LOG(INFO) << "PS mode is disabled."; | |||
| is_worker_ = false; | |||
| is_pserver_ = false; | |||
| is_sched_ = false; | |||
| } | |||
| } | |||
| bool PSContext::is_ps_enabled() const { return ps_enabled_; } | |||
| void PSContext::Reset() { | |||
| ps_enabled_ = false; | |||
| is_worker_ = false; | |||
| is_pserver_ = false; | |||
| is_sched_ = false; | |||
| } | |||
| std::string PSContext::ms_role() const { | |||
| if (is_worker_) { | |||
| return kEnvRoleOfWorker; | |||
| } else if (is_pserver_) { | |||
| return kEnvRoleOfPServer; | |||
| } else if (is_sched_) { | |||
| return kEnvRoleOfScheduler; | |||
| } else { | |||
| return kEnvRoleOfNotPS; | |||
| } | |||
| } | |||
| bool PSContext::is_role_worker() const { return is_worker_; } | |||
| bool PSContext::is_role_pserver() const { return is_pserver_; } | |||
| bool PSContext::is_role_sched() const { return is_sched_; } | |||
| void PSContext::SetPSRankId(int rank_id) { rank_id_ = rank_id; } | |||
| int PSContext::ps_rank_id() const { return rank_id_; } | |||
| } // namespace ps | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,61 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_CONTEXT_H_ | |||
| #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_CONTEXT_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| namespace ps { | |||
| constexpr char kEnvRole[] = "MS_ROLE"; | |||
| constexpr char kEnvRoleOfPServer[] = "MS_PSERVER"; | |||
| constexpr char kEnvRoleOfWorker[] = "MS_WORKER"; | |||
| constexpr char kEnvRoleOfScheduler[] = "MS_SCHED"; | |||
| constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS"; | |||
| class PSContext { | |||
| public: | |||
| ~PSContext() = default; | |||
| PSContext(PSContext const &) = delete; | |||
| PSContext &operator=(const PSContext &) = delete; | |||
| static std::shared_ptr<PSContext> instance(); | |||
| void SetPSEnable(bool enabled); | |||
| bool is_ps_enabled() const; | |||
| void Reset(); | |||
| std::string ms_role() const; | |||
| bool is_role_worker() const; | |||
| bool is_role_pserver() const; | |||
| bool is_role_sched() const; | |||
| void SetPSRankId(int rank_id); | |||
| int ps_rank_id() const; | |||
| private: | |||
| PSContext() : ps_enabled_(false), is_worker_(false), is_pserver_(false), is_sched_(false), rank_id_(-1) {} | |||
| bool ps_enabled_; | |||
| bool is_worker_; | |||
| bool is_pserver_; | |||
| bool is_sched_; | |||
| int rank_id_; | |||
| }; | |||
| } // namespace ps | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_CONTEXT_H_ | |||
| @@ -16,7 +16,9 @@ | |||
| #include "frontend/parallel/ps/util.h" | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "frontend/parallel/ps/common.h" | |||
| #include "frontend/parallel/ps/ps_context.h" | |||
| #include "utils/ms_utils.h" | |||
| namespace mindspore { | |||
| @@ -45,34 +47,13 @@ std::unordered_map<int, std::string> Util::id_to_optimizer_nodes{ | |||
| {3, kSparseFtrlOp}, | |||
| }; | |||
| bool Util::IsParamServerMode() { return IsRoleOfWorker() || IsRoleOfPServer() || IsRoleOfScheduler(); } | |||
| bool Util::IsParamServerMode() { return PSContext::instance()->is_ps_enabled(); } | |||
| bool Util::IsRoleOfWorker() { | |||
| auto role = common::GetEnv(kEnvRole); | |||
| if (strcmp(role.c_str(), kEnvRoleOfWorker) == 0) { | |||
| return true; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| bool Util::IsRoleOfWorker() { return PSContext::instance()->is_role_worker(); } | |||
| bool Util::IsRoleOfPServer() { | |||
| auto role = common::GetEnv(kEnvRole); | |||
| if (strcmp(role.c_str(), kEnvRoleOfPServer) == 0) { | |||
| return true; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| bool Util::IsRoleOfPServer() { return PSContext::instance()->is_role_pserver(); } | |||
| bool Util::IsRoleOfScheduler() { | |||
| auto role = common::GetEnv(kEnvRole); | |||
| if (strcmp(role.c_str(), kEnvRoleOfScheduler) == 0) { | |||
| return true; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| bool Util::IsRoleOfScheduler() { return PSContext::instance()->is_role_sched(); } | |||
| void Util::SetInternalEnvVar() { | |||
| if (IsParamServerMode()) { | |||
| @@ -163,10 +144,6 @@ std::map<int, int> Util::AllRankLocalShard(int first_dim, int rank_id, int serve | |||
| return shard_dims; | |||
| } | |||
| void Util::SetRankId(int rank_id) { rank_id_ = rank_id; } | |||
| int Util::GetRankId() { return rank_id_; } | |||
| void Util::ReduceSparseGradient(float *gradients, int *indices, const size_t indices_size, size_t segment_size, | |||
| const size_t first_dim_size, const size_t outer_dim_size, | |||
| mindspore::kernel::SparseGradient<int> *unique_sparse_grad) { | |||
| @@ -40,8 +40,6 @@ class Util { | |||
| static bool is_optimizer(std::string name); | |||
| static int LocalShard(int first_dim, int rank_id, int server_num); | |||
| static std::map<int, int> AllRankLocalShard(int first_dim, int rank_id, int server_num); | |||
| static void SetRankId(int rank_id); | |||
| static int GetRankId(); | |||
| static void ReduceSparseGradient(float *gradients, int *indices, const size_t indices_size, size_t segment_size, | |||
| const size_t first_dim_size, const size_t outer_dim_size, | |||
| mindspore::kernel::SparseGradient<int> *unique_sparse_grad); | |||
| @@ -27,6 +27,7 @@ | |||
| #include "ps/ps.h" | |||
| #include "frontend/parallel/ps/util.h" | |||
| #include "backend/kernel_compiler/common_utils.h" | |||
| #include "frontend/parallel/ps/ps_context.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| @@ -43,7 +44,7 @@ class WorkerProxy : public ::ps::KVWorker<T> { | |||
| explicit WorkerProxy(int app_id, int customer_id, int lookup_customer_id, int general_customer_id) | |||
| : Worker(app_id, customer_id) { | |||
| server_num_ = ::ps::NumServers(); | |||
| Util::SetRankId(::ps::MyRank()); | |||
| PSContext::instance()->SetPSRankId(::ps::MyRank()); | |||
| using std::placeholders::_1; | |||
| using std::placeholders::_2; | |||
| using std::placeholders::_3; | |||
| @@ -36,6 +36,7 @@ | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #include "frontend/parallel/ps/util.h" | |||
| #endif | |||
| #include "frontend/parallel/ps/ps_context.h" | |||
| namespace py = pybind11; | |||
| using EnvInstance = mindspore::EnvInstance; | |||
| @@ -49,6 +50,7 @@ using OpInfoLoaderPy = mindspore::kernel::OpInfoLoaderPy; | |||
| using ParallelContext = mindspore::parallel::ParallelContext; | |||
| using CostModelContext = mindspore::parallel::CostModelContext; | |||
| using mindspore::MsCtxParam; | |||
| using PSContext = mindspore::parallel::ps::PSContext; | |||
| // Interface with python | |||
| PYBIND11_MODULE(_c_expression, m) { | |||
| @@ -276,9 +278,15 @@ PYBIND11_MODULE(_c_expression, m) { | |||
| "Finalize gpu collective communication mode."); | |||
| #endif | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| (void)m.def("get_ps_mode_rank", &mindspore::parallel::ps::Util::GetRankId, "Get Worker and PServer rank id."); | |||
| #endif | |||
| (void)py::class_<PSContext, std::shared_ptr<PSContext>>(m, "PSContext") | |||
| .def_static("get_instance", &PSContext::instance, "Get PS context instance.") | |||
| .def("set_ps_enable", &PSContext::SetPSEnable, "Set PS mode enabled or disabled.") | |||
| .def("is_ps_enabled", &PSContext::is_ps_enabled, "Get PS mode enable-disable status.") | |||
| .def("reset", &PSContext::Reset, "Reset PS context attributes.") | |||
| .def("is_role_worker", &PSContext::is_role_worker, "Get whether the role of this process is Worker.") | |||
| .def("is_role_pserver", &PSContext::is_role_pserver, "Get whether the role of this process is PServer.") | |||
| .def("is_role_sched", &PSContext::is_role_sched, "Get whether the role of this process is Scheduler.") | |||
| .def("ps_rank_id", &PSContext::ps_rank_id, "Get Worker and PServer rank id."); | |||
| (void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy") | |||
| .def(py::init()) | |||
| @@ -15,7 +15,6 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Providing interface methods.""" | |||
| import os | |||
| import types | |||
| from collections import OrderedDict | |||
| from functools import wraps | |||
| @@ -25,6 +24,7 @@ from .._c_expression import generate_key, Executor_, Tensor, MetaTensor, Pynativ | |||
| from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend | |||
| from .tensor import Tensor as MsTensor | |||
| from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_tensor | |||
| from ..parallel._ps_context import _is_role_pserver | |||
| # store ms_function class compiled pipeline cache | |||
| ms_compile_cache = {} | |||
| @@ -469,7 +469,7 @@ class _Executor: | |||
| return self._executor.has_compiled(phase) | |||
| def __call__(self, obj, *args, phase='predict'): | |||
| if context.get_context("precompile_only") or os.getenv("MS_ROLE") == "MS_PSERVER": | |||
| if context.get_context("precompile_only") or _is_role_pserver(): | |||
| return None | |||
| return self.run(obj, *args, phase=phase) | |||
| @@ -22,6 +22,7 @@ from .tensor import Tensor, MetaTensor | |||
| from .._checkparam import _check_str_by_regular | |||
| from ..parallel._tensor import _get_slice_index | |||
| from ..parallel._auto_parallel_context import auto_parallel_context | |||
| from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched | |||
| __all__ = ['Parameter', 'ParameterTuple'] | |||
| @@ -168,8 +169,13 @@ class Parameter(MetaTensor): | |||
| """For parse check.""" | |||
| def set_param_ps(self, init_in_server=False): | |||
| self.is_param_ps = True | |||
| self.init_in_server = init_in_server | |||
| if _is_role_worker() or _is_role_pserver() or _is_role_sched(): | |||
| self.is_param_ps = True | |||
| self.init_in_server = init_in_server | |||
| else: | |||
| raise RuntimeError("Must complete following two steps before calling set_param_ps: " | |||
| "1. set_ps_context(enable_ps=True) " | |||
| "2. export MS_ROLE environment variable.") | |||
| @property | |||
| @@ -14,7 +14,7 @@ | |||
| # ============================================================================ | |||
| """comm_helper""" | |||
| import os | |||
| from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched | |||
| from ._hccl_management import load_lib as hccl_load_lib | |||
| _HCCL_AVAILABLE = False | |||
| @@ -44,7 +44,6 @@ else: | |||
| HCCL_WORLD_COMM_GROUP = "hccl_world_group" | |||
| NCCL_WORLD_COMM_GROUP = "nccl_world_group" | |||
| MS_ROLE = os.getenv("MS_ROLE") | |||
| class Backend: | |||
| """ | |||
| @@ -113,7 +112,7 @@ def check_parameter_available(func): | |||
| Wrapper. If not available, raise Error. | |||
| """ | |||
| def wrapper(*args, **kargs): | |||
| if MS_ROLE in ("MS_PSERVER", "MS_SCHED"): | |||
| if _is_role_pserver() or _is_role_sched(): | |||
| return func(*args, **kargs) | |||
| group = None | |||
| if "group" in kargs.keys(): | |||
| @@ -154,7 +153,7 @@ def _get_rank_helper(group, backend): | |||
| Integer. The local rank id of the calling process. | |||
| """ | |||
| rank_id = None | |||
| if MS_ROLE in ("MS_PSERVER", "MS_SCHED"): | |||
| if _is_role_pserver() or _is_role_sched(): | |||
| rank_id = 0 | |||
| return rank_id | |||
| if backend == Backend.HCCL: | |||
| @@ -213,7 +212,7 @@ def _get_size_helper(group, backend): | |||
| Integer. The rank size of specified group. | |||
| """ | |||
| size = None | |||
| if MS_ROLE in ("MS_PSERVER", "MS_SCHED"): | |||
| if _is_role_pserver() or _is_role_sched(): | |||
| size = 1 | |||
| return size | |||
| if backend == Backend.HCCL: | |||
| @@ -13,8 +13,8 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Communication management API""" | |||
| import os | |||
| from mindspore import context | |||
| from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched | |||
| from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \ | |||
| _get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \ | |||
| _create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \ | |||
| @@ -29,7 +29,6 @@ __all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size", | |||
| DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP | |||
| DEFAULT_BACKEND = Backend("hccl") | |||
| MS_ROLE = os.getenv("MS_ROLE") | |||
| def _get_group(group): | |||
| @@ -61,7 +60,7 @@ def init(backend_name=None): | |||
| RuntimeError: If device target is invalid. | |||
| RuntimeError: If backend is invalid or distributed init fails. | |||
| """ | |||
| if MS_ROLE in ("MS_PSERVER", "MS_SCHED"): | |||
| if _is_role_pserver() or _is_role_sched(): | |||
| return | |||
| if backend_name is None: | |||
| device_target = context.get_context("device_target") | |||
| @@ -26,9 +26,11 @@ from mindspore._c_expression import MSContext, ms_ctx_param | |||
| from mindspore._checkparam import args_type_check | |||
| from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \ | |||
| _reset_auto_parallel_context | |||
| from mindspore.parallel._ps_context import _set_ps_context, _get_ps_context, _reset_ps_context | |||
| __all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_auto_parallel_context', | |||
| 'get_auto_parallel_context', 'reset_auto_parallel_context', 'ParallelMode'] | |||
| 'get_auto_parallel_context', 'reset_auto_parallel_context', 'ParallelMode', 'set_ps_context', | |||
| 'get_ps_context', 'reset_ps_context'] | |||
| GRAPH_MODE = 0 | |||
| PYNATIVE_MODE = 1 | |||
| @@ -569,3 +571,58 @@ class ParallelMode: | |||
| SEMI_AUTO_PARALLEL = "semi_auto_parallel" | |||
| AUTO_PARALLEL = "auto_parallel" | |||
| MODE_LIST = [STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, AUTO_PARALLEL] | |||
| @args_type_check(enable_ps=bool) | |||
| def set_ps_context(**kwargs): | |||
| """ | |||
| Set parameter server training mode context. | |||
| Note: | |||
| Some other environment variables should also be set for parameter server training mode. | |||
| These environment variables are listed below: | |||
| MS_SERVER_NUM # Server number | |||
| MS_WORKER_NUM # Worker number | |||
| MS_SCHED_HOST # Scheduler IP address | |||
| MS_SCHED_PORT # Scheduler port | |||
| MS_ROLE # The role of this process: | |||
| MS_SCHED represents the scheduler, | |||
| MS_WORKER represents the worker, | |||
| MS_PSERVER represents the Server | |||
| Args: | |||
| enable_ps (bool): Whether to enable parameter server training mode. | |||
| Only after enable_ps is set True, the environment variables will be effective. | |||
| Default: False. | |||
| Raises: | |||
| ValueError: If input key is not the attribute in parameter server training mode context. | |||
| Examples: | |||
| >>> context.set_ps_context(enable_ps=True) | |||
| """ | |||
| _set_ps_context(**kwargs) | |||
| def get_ps_context(attr_key): | |||
| """ | |||
| Get parameter server training mode context attribute value according to the key. | |||
| Args: | |||
| attr_key (str): The key of the attribute. | |||
| Returns: | |||
| Returns attribute value according to the key. | |||
| Raises: | |||
| ValueError: If input key is not an attribute in parameter server training mode context. | |||
| """ | |||
| return _get_ps_context(attr_key) | |||
| def reset_ps_context(): | |||
| """ | |||
| Reset parameter server training mode context attributes to the default values: | |||
| - enable_ps: False. | |||
| """ | |||
| _reset_ps_context() | |||
| @@ -0,0 +1,115 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Context for parameter server training mode""" | |||
| from mindspore._c_expression import PSContext | |||
| _ps_context = None | |||
| def ps_context(): | |||
| """ | |||
| Get the global _ps_context, if it is not created, create a new one. | |||
| Returns: | |||
| _ps_context, the global parameter server training mode context. | |||
| """ | |||
| global _ps_context | |||
| if _ps_context is None: | |||
| _ps_context = PSContext.get_instance() | |||
| return _ps_context | |||
| _set_ps_context_func_map = { | |||
| "enable_ps": ps_context().set_ps_enable | |||
| } | |||
| _get_ps_context_func_map = { | |||
| "enable_ps": ps_context().is_ps_enabled | |||
| } | |||
| def _get_ps_mode_rank(): | |||
| ps_rank = ps_context().ps_rank_id() | |||
| if ps_rank == -1: | |||
| raise RuntimeError("The parameter server mode training is not enabled yet.") | |||
| return ps_rank | |||
| def _set_ps_context(**kwargs): | |||
| """ | |||
| Set parameter server training mode context. | |||
| Note: | |||
| Some other environment variables should also be set for parameter server training mode. | |||
| These environment variables are listed below: | |||
| MS_SERVER_NUM # Server number | |||
| MS_WORKER_NUM # Worker number | |||
| MS_SCHED_HOST # Scheduler IP address | |||
| MS_SCHED_PORT # Scheduler port | |||
| MS_ROLE # The role of this process: | |||
| MS_SCHED represents the scheduler, | |||
| MS_WORKER represents the worker, | |||
| MS_PSERVER represents the Server | |||
| Args: | |||
| enable_ps (bool): Whether to enable parameter server training mode. | |||
| Only after enable_ps is set True, the environment variables will be effective. | |||
| Default: False. | |||
| Raises: | |||
| ValueError: If input key is not the attribute in parameter server training mode context. | |||
| Examples: | |||
| >>> context.set_ps_context(enable_ps=True) | |||
| """ | |||
| for key, value in kwargs.items(): | |||
| if key not in _set_ps_context_func_map: | |||
| raise ValueError("Set PS context keyword %s is not recognized!" % key) | |||
| set_func = _set_ps_context_func_map[key] | |||
| set_func(value) | |||
| def _get_ps_context(attr_key): | |||
| """ | |||
| Get parameter server training mode context attribute value according to the key. | |||
| Args: | |||
| attr_key (str): The key of the attribute. | |||
| Returns: | |||
| Returns attribute value according to the key. | |||
| Raises: | |||
| ValueError: If input key is not an attribute in parameter server training mode context. | |||
| """ | |||
| if attr_key not in _get_ps_context_func_map: | |||
| raise ValueError("Get PS context keyword %s is not recognized!" % attr_key) | |||
| get_func = _get_ps_context_func_map[attr_key] | |||
| return get_func() | |||
| def _reset_ps_context(): | |||
| """ | |||
| Reset parameter server training mode context attributes to the default values: | |||
| - enable_ps: False. | |||
| """ | |||
| ps_context().reset() | |||
| def _is_role_worker(): | |||
| return ps_context().is_role_worker() | |||
| def _is_role_pserver(): | |||
| return ps_context().is_role_pserver() | |||
| def _is_role_sched(): | |||
| return ps_context().is_role_sched() | |||
| @@ -1,23 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Utils for parameter server training mode""" | |||
| from mindspore._c_expression import get_ps_mode_rank | |||
| def _get_ps_mode_rank(): | |||
| ps_rank = get_ps_mode_rank() | |||
| if ps_rank == -1: | |||
| raise RuntimeError("The parameter server mode training is not launched yet.") | |||
| return ps_rank | |||
| @@ -24,6 +24,7 @@ from mindspore import log as logger | |||
| from mindspore._checkparam import check_bool, check_int_non_negative | |||
| from mindspore.train._utils import _make_directory | |||
| from mindspore.train.serialization import save_checkpoint, _save_graph | |||
| from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank | |||
| from ._callback import Callback, set_cur_net | |||
| @@ -280,8 +281,7 @@ class ModelCheckpoint(Callback): | |||
| if save_ckpt: | |||
| cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \ | |||
| + str(step_num_in_epoch) + ".ckpt" | |||
| if os.getenv("MS_ROLE") == "MS_PSERVER": | |||
| from mindspore.parallel._ps_utils import _get_ps_mode_rank | |||
| if _is_role_pserver(): | |||
| cur_ckpoint_file = "PServer_" + str(_get_ps_mode_rank()) + "_" + cur_ckpoint_file | |||
| # update checkpoint file list. | |||
| self._manager.update_ckpoint_filelist(self._directory, self._prefix) | |||
| @@ -27,6 +27,7 @@ from .callback import _InternalCallbackParam, RunContext, _CallbackManager | |||
| from .. import context | |||
| from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ | |||
| _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check | |||
| from ..parallel._ps_context import _is_role_pserver, _is_role_sched | |||
| from ..nn.metrics import Loss | |||
| from .. import nn | |||
| from ..nn.wrap.cell_wrapper import _VirtualDatasetCell | |||
| @@ -378,8 +379,7 @@ class Model: | |||
| cb_params.list_callback = self._transform_callbacks(callbacks) | |||
| cb_params.train_dataset_element = None | |||
| cb_params.network = self._network | |||
| ms_role = os.getenv("MS_ROLE") | |||
| if ms_role in ("MS_PSERVER", "MS_SCHED"): | |||
| if _is_role_pserver() or _is_role_sched(): | |||
| epoch = 1 | |||
| # build callback list | |||
| @@ -516,7 +516,7 @@ class Model: | |||
| self._loss_scale_manager.update_loss_scale(overflow) | |||
| list_callback.step_end(run_context) | |||
| if os.getenv("MS_ROLE") == "MS_PSERVER": | |||
| if _is_role_pserver(): | |||
| os._exit(0) | |||
| should_stop = should_stop or run_context.get_stop_requested() | |||
| if should_stop: | |||
| @@ -70,6 +70,7 @@ if __name__ == '__main__': | |||
| # init context | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) | |||
| context.set_ps_context(enable_ps=True) | |||
| if args_opt.run_distribute: | |||
| if target == "Ascend": | |||
| device_id = int(os.getenv('DEVICE_ID')) | |||
| @@ -14,7 +14,6 @@ | |||
| # ============================================================================ | |||
| """Model.""" | |||
| import math | |||
| import os | |||
| from collections.abc import Iterable | |||
| import numpy as np | |||
| @@ -405,9 +404,6 @@ class Model: | |||
| cb_params.list_callback = self._transform_callbacks(callbacks) | |||
| cb_params.train_dataset_element = None | |||
| cb_params.network = self._network | |||
| ms_role = os.getenv("MS_ROLE") | |||
| if ms_role in ("MS_PSERVER", "MS_SCHED"): | |||
| epoch = 1 | |||
| # build callback list | |||
| with _CallbackManager(callbacks) as list_callback: | |||
| @@ -118,6 +118,7 @@ if __name__ == "__main__": | |||
| wide_deep_config.argparse_init() | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target) | |||
| context.set_ps_context(enable_ps=True) | |||
| init() | |||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, | |||
| device_num=get_group_size()) | |||
| @@ -26,6 +26,7 @@ from mindspore.nn import TrainOneStepCell, WithLossCell | |||
| from mindspore.nn.optim import Adam | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common.initializer import TruncatedNormal | |||
| from mindspore.parallel._ps_context import _is_role_pserver, _is_role_worker | |||
| parser = argparse.ArgumentParser(description="test_sparse_embedding") | |||
| parser.add_argument("--device_target", type=str, default="Ascend") | |||
| @@ -34,6 +35,7 @@ device_target = args.device_target | |||
| context.set_context( | |||
| mode=context.GRAPH_MODE, device_target=device_target, enable_sparse=True | |||
| ) | |||
| context.set_ps_context(enable_ps=True) | |||
| def fc_with_initialize(input_channels, out_channels): | |||
| @@ -81,7 +83,7 @@ def do_sparse_embedding(ps=False): | |||
| for _ in range(epoch): | |||
| data = Tensor(np.random.randint(0, 15, (32, 3), np.int32)) | |||
| label = Tensor(np.random.randint(0, 9, (32), np.int32)) | |||
| if envs.get("MS_ROLE") == "MS_PSERVER": | |||
| if _is_role_pserver(): | |||
| train_network(data, label) | |||
| sys.exit() | |||
| else: | |||
| @@ -96,10 +98,10 @@ if __name__ == "__main__": | |||
| np.random.seed(0) | |||
| ps_loss = do_sparse_embedding(True) | |||
| if envs.get("MS_ROLE") == "MS_WORKER": | |||
| envs["MS_ROLE"] = "" | |||
| if _is_role_worker(): | |||
| context.reset_ps_context() | |||
| np.random.seed(0) | |||
| no_ps_loss = do_sparse_embedding() | |||
| envs["MS_ROLE"] = "MS_WORKER" | |||
| context.set_ps_context(enable_ps=True) | |||
| assert np.allclose(ps_loss, no_ps_loss, rtol=1.0e-6, atol=1.0e-6) | |||
| @@ -35,6 +35,7 @@ args, _ = parser.parse_known_args() | |||
| device_target = args.device_target | |||
| dataset_path = args.dataset_path | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=device_target) | |||
| context.set_ps_context(enable_ps=True) | |||
| def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): | |||
| """weight initial for conv layer""" | |||
| @@ -13,6 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import sys | |||
| import argparse | |||
| import numpy as np | |||
| @@ -22,6 +23,7 @@ from mindspore.common.initializer import TruncatedNormal | |||
| from mindspore import Tensor | |||
| from mindspore.nn import TrainOneStepCell, WithLossCell | |||
| from mindspore.communication.management import init, get_group_size | |||
| from mindspore.parallel._ps_context import _is_role_pserver | |||
| # from resnet import resnet50 | |||
| parser = argparse.ArgumentParser(description="test_ps_lenet") | |||
| @@ -29,6 +31,7 @@ parser.add_argument("--device_target", type=str, default="Ascend") | |||
| args, _ = parser.parse_known_args() | |||
| device_target = args.device_target | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=device_target) | |||
| context.set_ps_context(enable_ps=True) | |||
| if device_target == "GPU": | |||
| init() | |||
| @@ -106,6 +109,10 @@ if __name__ == "__main__": | |||
| for _ in range(epoch): | |||
| data = Tensor(np.random.rand(32, 3, 32, 32).astype(np.float32)) | |||
| label = Tensor(np.random.randint(0, 9, (32)).astype(np.int32)) | |||
| loss = train_network(data, label).asnumpy() | |||
| losses.append(loss) | |||
| if _is_role_pserver(): | |||
| train_network(data, label) | |||
| sys.exit() | |||
| else: | |||
| loss = train_network(data, label).asnumpy() | |||
| losses.append(loss) | |||
| print(losses) | |||