Browse Source

Add PS context.

tags/v1.0.0
ZPaC 5 years ago
parent
commit
87bf2a7dcd
24 changed files with 381 additions and 94 deletions
  1. +1
    -2
      mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc
  2. +0
    -5
      mindspore/ccsrc/frontend/parallel/ps/common.h
  3. +2
    -1
      mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
  4. +86
    -0
      mindspore/ccsrc/frontend/parallel/ps/ps_context.cc
  5. +61
    -0
      mindspore/ccsrc/frontend/parallel/ps/ps_context.h
  6. +6
    -29
      mindspore/ccsrc/frontend/parallel/ps/util.cc
  7. +0
    -2
      mindspore/ccsrc/frontend/parallel/ps/util.h
  8. +2
    -1
      mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h
  9. +11
    -3
      mindspore/ccsrc/pipeline/jit/init.cc
  10. +2
    -2
      mindspore/common/api.py
  11. +8
    -2
      mindspore/common/parameter.py
  12. +4
    -5
      mindspore/communication/_comm_helper.py
  13. +2
    -3
      mindspore/communication/management.py
  14. +58
    -1
      mindspore/context.py
  15. +115
    -0
      mindspore/parallel/_ps_context.py
  16. +0
    -23
      mindspore/parallel/_ps_utils.py
  17. +2
    -2
      mindspore/train/callback/_checkpoint.py
  18. +3
    -3
      mindspore/train/model.py
  19. +1
    -0
      model_zoo/official/cv/resnet/train.py
  20. +0
    -4
      model_zoo/official/nlp/bert_thor/src/model_thor.py
  21. +1
    -0
      model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server.py
  22. +6
    -4
      tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py
  23. +1
    -0
      tests/st/ps/full_ps/test_full_ps_lenet.py
  24. +9
    -2
      tests/st/ps/multi_full_ps/test_multi_full_ps.py

+ 1
- 2
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc View File

@@ -41,8 +41,7 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) {
MS_LOG(INFO) << "Init embedding lookup proxy kernel, input shape:" << input_shape
<< ", indices_shape:" << indices_shape << ", output_shape:" << output_shape;
std::vector<int> lens{SizeToInt(input_shape.size()), SizeToInt(indices_shape.size()), SizeToInt(output_shape.size())};
const char *env_role = getenv(mindspore::parallel::ps::kEnvRole);
if (env_role != nullptr && strcmp(env_role, mindspore::parallel::ps::kEnvRoleOfWorker) == 0) {
if (mindspore::parallel::ps::Util::IsRoleOfWorker()) {
parallel::ps::Worker<float>::GetInstance().AddEmbeddingTable(key_, input_shape[axis]);
parallel::ps::Worker<float>::GetInstance().InitPSEmbeddingTable(keys, values, lens);
}


+ 0
- 5
mindspore/ccsrc/frontend/parallel/ps/common.h View File

@@ -32,11 +32,6 @@ constexpr char kEnvWorkerNum[] = "MS_WORKER_NUM";
constexpr char kEnvSchedulerHost[] = "MS_SCHED_HOST";
constexpr char kEnvSchedulerPort[] = "MS_SCHED_PORT";

constexpr char kEnvRole[] = "MS_ROLE";
constexpr char kEnvRoleOfPServer[] = "MS_PSERVER";
constexpr char kEnvRoleOfWorker[] = "MS_WORKER";
constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";

constexpr char kDmlcCommType[] = "DMLC_PS_VAN_TYPE";
constexpr char kDmlcInterface[] = "DMLC_INTERFACE";
constexpr char kDmlcPServerNum[] = "DMLC_NUM_SERVER";


+ 2
- 1
mindspore/ccsrc/frontend/parallel/ps/parameter_server.h View File

@@ -39,6 +39,7 @@
#include "frontend/parallel/ps/optimizer_info.h"
#include "frontend/parallel/ps/optimizer_info_builder.h"
#include "frontend/parallel/ps/util.h"
#include "frontend/parallel/ps/ps_context.h"
#include "runtime/device/cpu/kernel_select_cpu.h"
#include "utils/ms_context.h"
#include "backend/kernel_compiler/kernel.h"
@@ -741,7 +742,7 @@ void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
return;
}
Init(func_graph);
Util::SetRankId(rank_id_);
PSContext::instance()->SetPSRankId(rank_id_);
thread_->join();
::ps::Finalize(0, true);
}


+ 86
- 0
mindspore/ccsrc/frontend/parallel/ps/ps_context.cc View File

@@ -0,0 +1,86 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "frontend/parallel/ps/ps_context.h"
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"

namespace mindspore {
namespace parallel {
namespace ps {
std::shared_ptr<PSContext> PSContext::instance() {
  // Meyers-style singleton: C++11 guarantees the function-local static below is
  // initialized exactly once, even under concurrent first calls. The original
  // "static = nullptr; if (nullptr) reset(...)" pattern left the lazy
  // construction outside that guarantee and could race on first use.
  // make_shared is not usable here because the constructor is private.
  static std::shared_ptr<PSContext> ps_instance(new (std::nothrow) PSContext());
  return ps_instance;
}

void PSContext::SetPSEnable(bool enabled) {
  // Record the enable flag, then derive the role flags for this process.
  ps_enabled_ = enabled;
  if (!ps_enabled_) {
    // Disabling PS mode clears every role flag.
    MS_LOG(INFO) << "PS mode is disabled.";
    is_worker_ = false;
    is_pserver_ = false;
    is_sched_ = false;
    return;
  }
  // The role of this process comes from the MS_ROLE environment variable.
  const std::string role = common::GetEnv(kEnvRole);
  MS_LOG(INFO) << "PS mode is enabled. MS_ROLE is " << role;
  // Raise exactly one role flag; anything unrecognized is logged as invalid.
  if (role == kEnvRoleOfWorker) {
    is_worker_ = true;
  } else if (role == kEnvRoleOfPServer) {
    is_pserver_ = true;
  } else if (role == kEnvRoleOfScheduler) {
    is_sched_ = true;
  } else {
    MS_LOG(WARNING) << "MS_ROLE is " << role << ", which is invalid.";
  }
}

// Whether parameter server training mode has been enabled via SetPSEnable.
bool PSContext::is_ps_enabled() const { return ps_enabled_; }

// Restore every PS-mode attribute to its disabled/default state.
void PSContext::Reset() {
ps_enabled_ = false;
is_worker_ = false;
is_pserver_ = false;
is_sched_ = false;
}

// Return the MS_ROLE string matching the single role flag that is set,
// or kEnvRoleOfNotPS when this process has no PS role.
std::string PSContext::ms_role() const {
if (is_worker_) {
return kEnvRoleOfWorker;
} else if (is_pserver_) {
return kEnvRoleOfPServer;
} else if (is_sched_) {
return kEnvRoleOfScheduler;
} else {
return kEnvRoleOfNotPS;
}
}

// Role predicates: at most one of the three is true (see SetPSEnable).
bool PSContext::is_role_worker() const { return is_worker_; }

bool PSContext::is_role_pserver() const { return is_pserver_; }

bool PSContext::is_role_sched() const { return is_sched_; }

// Rank id assigned by the PS framework; -1 until SetPSRankId is called.
void PSContext::SetPSRankId(int rank_id) { rank_id_ = rank_id; }

int PSContext::ps_rank_id() const { return rank_id_; }
} // namespace ps
} // namespace parallel
} // namespace mindspore

+ 61
- 0
mindspore/ccsrc/frontend/parallel/ps/ps_context.h View File

@@ -0,0 +1,61 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_CONTEXT_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_CONTEXT_H_

#include <string>
#include <memory>

namespace mindspore {
namespace parallel {
namespace ps {
constexpr char kEnvRole[] = "MS_ROLE";
constexpr char kEnvRoleOfPServer[] = "MS_PSERVER";
constexpr char kEnvRoleOfWorker[] = "MS_WORKER";
constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";
constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS";

// Process-wide context for parameter server (PS) training mode.
// Holds the enable flag, the role of this process (worker / pserver /
// scheduler) and its PS rank id. Non-copyable singleton; obtain via instance().
class PSContext {
public:
~PSContext() = default;
PSContext(PSContext const &) = delete;
PSContext &operator=(const PSContext &) = delete;
// Returns the shared singleton instance.
static std::shared_ptr<PSContext> instance();

// Enable/disable PS mode; when enabling, the role is read from MS_ROLE.
void SetPSEnable(bool enabled);
bool is_ps_enabled() const;
// Reset all attributes to their disabled defaults.
void Reset();
// MS_ROLE string for the current role, or kEnvRoleOfNotPS when none is set.
std::string ms_role() const;
bool is_role_worker() const;
bool is_role_pserver() const;
bool is_role_sched() const;
// Rank id within the PS cluster; -1 means not assigned yet.
void SetPSRankId(int rank_id);
int ps_rank_id() const;

private:
PSContext() : ps_enabled_(false), is_worker_(false), is_pserver_(false), is_sched_(false), rank_id_(-1) {}
bool ps_enabled_;
bool is_worker_;
bool is_pserver_;
bool is_sched_;
int rank_id_;
};
} // namespace ps
} // namespace parallel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_CONTEXT_H_

+ 6
- 29
mindspore/ccsrc/frontend/parallel/ps/util.cc View File

@@ -16,7 +16,9 @@

#include "frontend/parallel/ps/util.h"
#include <unordered_map>
#include <vector>
#include "frontend/parallel/ps/common.h"
#include "frontend/parallel/ps/ps_context.h"
#include "utils/ms_utils.h"

namespace mindspore {
@@ -45,34 +47,13 @@ std::unordered_map<int, std::string> Util::id_to_optimizer_nodes{
{3, kSparseFtrlOp},
};

bool Util::IsParamServerMode() { return IsRoleOfWorker() || IsRoleOfPServer() || IsRoleOfScheduler(); }
bool Util::IsParamServerMode() { return PSContext::instance()->is_ps_enabled(); }

bool Util::IsRoleOfWorker() {
auto role = common::GetEnv(kEnvRole);
if (strcmp(role.c_str(), kEnvRoleOfWorker) == 0) {
return true;
} else {
return false;
}
}
bool Util::IsRoleOfWorker() { return PSContext::instance()->is_role_worker(); }

bool Util::IsRoleOfPServer() {
auto role = common::GetEnv(kEnvRole);
if (strcmp(role.c_str(), kEnvRoleOfPServer) == 0) {
return true;
} else {
return false;
}
}
bool Util::IsRoleOfPServer() { return PSContext::instance()->is_role_pserver(); }

bool Util::IsRoleOfScheduler() {
auto role = common::GetEnv(kEnvRole);
if (strcmp(role.c_str(), kEnvRoleOfScheduler) == 0) {
return true;
} else {
return false;
}
}
bool Util::IsRoleOfScheduler() { return PSContext::instance()->is_role_sched(); }

void Util::SetInternalEnvVar() {
if (IsParamServerMode()) {
@@ -163,10 +144,6 @@ std::map<int, int> Util::AllRankLocalShard(int first_dim, int rank_id, int serve
return shard_dims;
}

void Util::SetRankId(int rank_id) { rank_id_ = rank_id; }

int Util::GetRankId() { return rank_id_; }

void Util::ReduceSparseGradient(float *gradients, int *indices, const size_t indices_size, size_t segment_size,
const size_t first_dim_size, const size_t outer_dim_size,
mindspore::kernel::SparseGradient<int> *unique_sparse_grad) {


+ 0
- 2
mindspore/ccsrc/frontend/parallel/ps/util.h View File

@@ -40,8 +40,6 @@ class Util {
static bool is_optimizer(std::string name);
static int LocalShard(int first_dim, int rank_id, int server_num);
static std::map<int, int> AllRankLocalShard(int first_dim, int rank_id, int server_num);
static void SetRankId(int rank_id);
static int GetRankId();
static void ReduceSparseGradient(float *gradients, int *indices, const size_t indices_size, size_t segment_size,
const size_t first_dim_size, const size_t outer_dim_size,
mindspore::kernel::SparseGradient<int> *unique_sparse_grad);


+ 2
- 1
mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h View File

@@ -27,6 +27,7 @@
#include "ps/ps.h"
#include "frontend/parallel/ps/util.h"
#include "backend/kernel_compiler/common_utils.h"
#include "frontend/parallel/ps/ps_context.h"

namespace mindspore {
namespace parallel {
@@ -43,7 +44,7 @@ class WorkerProxy : public ::ps::KVWorker<T> {
explicit WorkerProxy(int app_id, int customer_id, int lookup_customer_id, int general_customer_id)
: Worker(app_id, customer_id) {
server_num_ = ::ps::NumServers();
Util::SetRankId(::ps::MyRank());
PSContext::instance()->SetPSRankId(::ps::MyRank());
using std::placeholders::_1;
using std::placeholders::_2;
using std::placeholders::_3;


+ 11
- 3
mindspore/ccsrc/pipeline/jit/init.cc View File

@@ -36,6 +36,7 @@
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "frontend/parallel/ps/util.h"
#endif
#include "frontend/parallel/ps/ps_context.h"
namespace py = pybind11;

using EnvInstance = mindspore::EnvInstance;
@@ -49,6 +50,7 @@ using OpInfoLoaderPy = mindspore::kernel::OpInfoLoaderPy;
using ParallelContext = mindspore::parallel::ParallelContext;
using CostModelContext = mindspore::parallel::CostModelContext;
using mindspore::MsCtxParam;
using PSContext = mindspore::parallel::ps::PSContext;

// Interface with python
PYBIND11_MODULE(_c_expression, m) {
@@ -276,9 +278,15 @@ PYBIND11_MODULE(_c_expression, m) {
"Finalize gpu collective communication mode.");
#endif

#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
(void)m.def("get_ps_mode_rank", &mindspore::parallel::ps::Util::GetRankId, "Get Worker and PServer rank id.");
#endif
(void)py::class_<PSContext, std::shared_ptr<PSContext>>(m, "PSContext")
.def_static("get_instance", &PSContext::instance, "Get PS context instance.")
.def("set_ps_enable", &PSContext::SetPSEnable, "Set PS mode enabled or disabled.")
.def("is_ps_enabled", &PSContext::is_ps_enabled, "Get PS mode enable-disable status.")
.def("reset", &PSContext::Reset, "Reset PS context attributes.")
.def("is_role_worker", &PSContext::is_role_worker, "Get whether the role of this process is Worker.")
.def("is_role_pserver", &PSContext::is_role_pserver, "Get whether the role of this process is PServer.")
.def("is_role_sched", &PSContext::is_role_sched, "Get whether the role of this process is Scheduler.")
.def("ps_rank_id", &PSContext::ps_rank_id, "Get Worker and PServer rank id.");

(void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")
.def(py::init())


+ 2
- 2
mindspore/common/api.py View File

@@ -15,7 +15,6 @@
# limitations under the License.
# ============================================================================
"""Providing interface methods."""
import os
import types
from collections import OrderedDict
from functools import wraps
@@ -25,6 +24,7 @@ from .._c_expression import generate_key, Executor_, Tensor, MetaTensor, Pynativ
from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend
from .tensor import Tensor as MsTensor
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_tensor
from ..parallel._ps_context import _is_role_pserver
# store ms_function class compiled pipeline cache
ms_compile_cache = {}

@@ -469,7 +469,7 @@ class _Executor:
return self._executor.has_compiled(phase)

def __call__(self, obj, *args, phase='predict'):
if context.get_context("precompile_only") or os.getenv("MS_ROLE") == "MS_PSERVER":
if context.get_context("precompile_only") or _is_role_pserver():
return None
return self.run(obj, *args, phase=phase)



+ 8
- 2
mindspore/common/parameter.py View File

@@ -22,6 +22,7 @@ from .tensor import Tensor, MetaTensor
from .._checkparam import _check_str_by_regular
from ..parallel._tensor import _get_slice_index
from ..parallel._auto_parallel_context import auto_parallel_context
from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched

__all__ = ['Parameter', 'ParameterTuple']

@@ -168,8 +169,13 @@ class Parameter(MetaTensor):
"""For parse check."""

def set_param_ps(self, init_in_server=False):
self.is_param_ps = True
self.init_in_server = init_in_server
if _is_role_worker() or _is_role_pserver() or _is_role_sched():
self.is_param_ps = True
self.init_in_server = init_in_server
else:
raise RuntimeError("Must complete following two steps before calling set_param_ps: \
1. set_ps_context(enable_ps=True) \
2. export MS_ROLE environment variable.")


@property


+ 4
- 5
mindspore/communication/_comm_helper.py View File

@@ -14,7 +14,7 @@
# ============================================================================
"""comm_helper"""

import os
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
from ._hccl_management import load_lib as hccl_load_lib

_HCCL_AVAILABLE = False
@@ -44,7 +44,6 @@ else:

HCCL_WORLD_COMM_GROUP = "hccl_world_group"
NCCL_WORLD_COMM_GROUP = "nccl_world_group"
MS_ROLE = os.getenv("MS_ROLE")

class Backend:
"""
@@ -113,7 +112,7 @@ def check_parameter_available(func):
Wrapper. If not available, raise Error.
"""
def wrapper(*args, **kargs):
if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
if _is_role_pserver() or _is_role_sched():
return func(*args, **kargs)
group = None
if "group" in kargs.keys():
@@ -154,7 +153,7 @@ def _get_rank_helper(group, backend):
Integer. The local rank id of the calling process.
"""
rank_id = None
if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
if _is_role_pserver() or _is_role_sched():
rank_id = 0
return rank_id
if backend == Backend.HCCL:
@@ -213,7 +212,7 @@ def _get_size_helper(group, backend):
Integer. The rank size of specified group.
"""
size = None
if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
if _is_role_pserver() or _is_role_sched():
size = 1
return size
if backend == Backend.HCCL:


+ 2
- 3
mindspore/communication/management.py View File

@@ -13,8 +13,8 @@
# limitations under the License.
# ============================================================================
"""Communication management API"""
import os
from mindspore import context
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \
_get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \
_create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \
@@ -29,7 +29,6 @@ __all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size",

DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
DEFAULT_BACKEND = Backend("hccl")
MS_ROLE = os.getenv("MS_ROLE")


def _get_group(group):
@@ -61,7 +60,7 @@ def init(backend_name=None):
RuntimeError: If device target is invalid.
RuntimeError: If backend is invalid or distributed init fails.
"""
if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
if _is_role_pserver() or _is_role_sched():
return
if backend_name is None:
device_target = context.get_context("device_target")


+ 58
- 1
mindspore/context.py View File

@@ -26,9 +26,11 @@ from mindspore._c_expression import MSContext, ms_ctx_param
from mindspore._checkparam import args_type_check
from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
_reset_auto_parallel_context
from mindspore.parallel._ps_context import _set_ps_context, _get_ps_context, _reset_ps_context

__all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_auto_parallel_context',
'get_auto_parallel_context', 'reset_auto_parallel_context', 'ParallelMode']
'get_auto_parallel_context', 'reset_auto_parallel_context', 'ParallelMode', 'set_ps_context',
'get_ps_context', 'reset_ps_context']

GRAPH_MODE = 0
PYNATIVE_MODE = 1
@@ -569,3 +571,58 @@ class ParallelMode:
SEMI_AUTO_PARALLEL = "semi_auto_parallel"
AUTO_PARALLEL = "auto_parallel"
MODE_LIST = [STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, AUTO_PARALLEL]

@args_type_check(enable_ps=bool)
def set_ps_context(**kwargs):
"""
Set parameter server training mode context.

Note:
Some other environment variables should also be set for parameter server training mode.
These environment variables are listed below:
MS_SERVER_NUM # Server number
MS_WORKER_NUM # Worker number
MS_SCHED_HOST # Scheduler IP address
MS_SCHED_PORT # Scheduler port
MS_ROLE # The role of this process:
MS_SCHED represents the scheduler,
MS_WORKER represents the worker,
MS_PSERVER represents the Server


Args:
enable_ps (bool): Whether to enable parameter server training mode.
Only after enable_ps is set True, the environment variables will be effective.
Default: False.

Raises:
ValueError: If input key is not the attribute in parameter server training mode context.

Examples:
>>> context.set_ps_context(enable_ps=True)
"""
_set_ps_context(**kwargs)


def get_ps_context(attr_key):
"""
Get parameter server training mode context attribute value according to the key.

Args:
attr_key (str): The key of the attribute.

Returns:
Returns attribute value according to the key.

Raises:
ValueError: If input key is not an attribute of the parameter server training mode context.
"""
return _get_ps_context(attr_key)

def reset_ps_context():
"""
Reset parameter server training mode context attributes to the default values:

- enable_ps: False.
"""
_reset_ps_context()

+ 115
- 0
mindspore/parallel/_ps_context.py View File

@@ -0,0 +1,115 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Context for parameter server training mode"""

from mindspore._c_expression import PSContext

_ps_context = None


def ps_context():
    """
    Return the global parameter server training mode context.

    The underlying ``PSContext`` singleton is fetched lazily on first use and
    cached in the module-level ``_ps_context`` variable for later calls.
    """
    global _ps_context
    if _ps_context is not None:
        return _ps_context
    _ps_context = PSContext.get_instance()
    return _ps_context

# Dispatch tables mapping a public context keyword to the bound PSContext
# method that sets/gets it. Note the methods are bound at import time, which
# also instantiates the PSContext singleton when this module is imported.
_set_ps_context_func_map = {
"enable_ps": ps_context().set_ps_enable
}

_get_ps_context_func_map = {
"enable_ps": ps_context().is_ps_enabled
}

def _get_ps_mode_rank():
    """
    Return this process's rank id under parameter server training mode.

    Raises:
        RuntimeError: If PS mode has not been launched (rank id is still -1).
    """
    rank = ps_context().ps_rank_id()
    if rank == -1:
        raise RuntimeError("The parameter server mode training is not enabled yet.")
    return rank

def _set_ps_context(**kwargs):
    """
    Set parameter server training mode context attributes.

    Note:
        Enabling PS mode only takes full effect together with the related
        environment variables:
        MS_SERVER_NUM  # Server number
        MS_WORKER_NUM  # Worker number
        MS_SCHED_HOST  # Scheduler IP address
        MS_SCHED_PORT  # Scheduler port
        MS_ROLE        # Role of this process: MS_SCHED (scheduler),
                       # MS_WORKER (worker) or MS_PSERVER (server)

    Args:
        enable_ps (bool): Whether to enable parameter server training mode.
            The environment variables above become effective only after
            enable_ps is set True. Default: False.

    Raises:
        ValueError: If input key is not an attribute of the parameter server
            training mode context.

    Examples:
        >>> context.set_ps_context(enable_ps=True)
    """
    for key, value in kwargs.items():
        # Unknown keywords are rejected before any setter runs.
        if key not in _set_ps_context_func_map:
            raise ValueError("Set PS context keyword %s is not recognized!" % key)
        _set_ps_context_func_map[key](value)

def _get_ps_context(attr_key):
    """
    Get parameter server training mode context attribute value according to the key.

    Args:
        attr_key (str): The key of the attribute.

    Returns:
        Returns attribute value according to the key.

    Raises:
        ValueError: If input key is not an attribute of the parameter server
            training mode context.
    """
    # Fixes vs. the original: the membership test referenced an undefined name
    # `key` (NameError on any unknown key), the zero-argument getter was
    # mistakenly called with `attr_key`, and the result was never returned.
    if attr_key not in _get_ps_context_func_map:
        raise ValueError("Get PS context keyword %s is not recognized!" % attr_key)
    get_func = _get_ps_context_func_map[attr_key]
    return get_func()

def _reset_ps_context():
"""
Reset parameter server training mode context attributes to the default values:

- enable_ps: False.
"""
ps_context().reset()

# Role predicates for this process, delegating to the PSContext singleton.
# At most one of worker/pserver/sched is True (set when PS mode is enabled
# from the MS_ROLE environment variable).
def _is_role_worker():
return ps_context().is_role_worker()

def _is_role_pserver():
return ps_context().is_role_pserver()

def _is_role_sched():
return ps_context().is_role_sched()

+ 0
- 23
mindspore/parallel/_ps_utils.py View File

@@ -1,23 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils for parameter server training mode"""

from mindspore._c_expression import get_ps_mode_rank

def _get_ps_mode_rank():
ps_rank = get_ps_mode_rank()
if ps_rank == -1:
raise RuntimeError("The parameter server mode training is not launched yet.")
return ps_rank

+ 2
- 2
mindspore/train/callback/_checkpoint.py View File

@@ -24,6 +24,7 @@ from mindspore import log as logger
from mindspore._checkparam import check_bool, check_int_non_negative
from mindspore.train._utils import _make_directory
from mindspore.train.serialization import save_checkpoint, _save_graph
from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank
from ._callback import Callback, set_cur_net


@@ -280,8 +281,7 @@ class ModelCheckpoint(Callback):
if save_ckpt:
cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
+ str(step_num_in_epoch) + ".ckpt"
if os.getenv("MS_ROLE") == "MS_PSERVER":
from mindspore.parallel._ps_utils import _get_ps_mode_rank
if _is_role_pserver():
cur_ckpoint_file = "PServer_" + str(_get_ps_mode_rank()) + "_" + cur_ckpoint_file
# update checkpoint file list.
self._manager.update_ckpoint_filelist(self._directory, self._prefix)


+ 3
- 3
mindspore/train/model.py View File

@@ -27,6 +27,7 @@ from .callback import _InternalCallbackParam, RunContext, _CallbackManager
from .. import context
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
from ..parallel._ps_context import _is_role_pserver, _is_role_sched
from ..nn.metrics import Loss
from .. import nn
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
@@ -378,8 +379,7 @@ class Model:
cb_params.list_callback = self._transform_callbacks(callbacks)
cb_params.train_dataset_element = None
cb_params.network = self._network
ms_role = os.getenv("MS_ROLE")
if ms_role in ("MS_PSERVER", "MS_SCHED"):
if _is_role_pserver() or _is_role_sched():
epoch = 1

# build callback list
@@ -516,7 +516,7 @@ class Model:
self._loss_scale_manager.update_loss_scale(overflow)

list_callback.step_end(run_context)
if os.getenv("MS_ROLE") == "MS_PSERVER":
if _is_role_pserver():
os._exit(0)
should_stop = should_stop or run_context.get_stop_requested()
if should_stop:


+ 1
- 0
model_zoo/official/cv/resnet/train.py View File

@@ -70,6 +70,7 @@ if __name__ == '__main__':

# init context
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
context.set_ps_context(enable_ps=True)
if args_opt.run_distribute:
if target == "Ascend":
device_id = int(os.getenv('DEVICE_ID'))


+ 0
- 4
model_zoo/official/nlp/bert_thor/src/model_thor.py View File

@@ -14,7 +14,6 @@
# ============================================================================
"""Model."""
import math
import os
from collections.abc import Iterable

import numpy as np
@@ -405,9 +404,6 @@ class Model:
cb_params.list_callback = self._transform_callbacks(callbacks)
cb_params.train_dataset_element = None
cb_params.network = self._network
ms_role = os.getenv("MS_ROLE")
if ms_role in ("MS_PSERVER", "MS_SCHED"):
epoch = 1

# build callback list
with _CallbackManager(callbacks) as list_callback:


+ 1
- 0
model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server.py View File

@@ -118,6 +118,7 @@ if __name__ == "__main__":
wide_deep_config.argparse_init()

context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target)
context.set_ps_context(enable_ps=True)
init()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
device_num=get_group_size())


+ 6
- 4
tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py View File

@@ -26,6 +26,7 @@ from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Adam
from mindspore.ops import operations as P
from mindspore.common.initializer import TruncatedNormal
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_worker

parser = argparse.ArgumentParser(description="test_sparse_embedding")
parser.add_argument("--device_target", type=str, default="Ascend")
@@ -34,6 +35,7 @@ device_target = args.device_target
context.set_context(
mode=context.GRAPH_MODE, device_target=device_target, enable_sparse=True
)
context.set_ps_context(enable_ps=True)


def fc_with_initialize(input_channels, out_channels):
@@ -81,7 +83,7 @@ def do_sparse_embedding(ps=False):
for _ in range(epoch):
data = Tensor(np.random.randint(0, 15, (32, 3), np.int32))
label = Tensor(np.random.randint(0, 9, (32), np.int32))
if envs.get("MS_ROLE") == "MS_PSERVER":
if _is_role_pserver():
train_network(data, label)
sys.exit()
else:
@@ -96,10 +98,10 @@ if __name__ == "__main__":
np.random.seed(0)
ps_loss = do_sparse_embedding(True)

if envs.get("MS_ROLE") == "MS_WORKER":
envs["MS_ROLE"] = ""
if _is_role_worker():
context.reset_ps_context()
np.random.seed(0)
no_ps_loss = do_sparse_embedding()
envs["MS_ROLE"] = "MS_WORKER"
context.set_ps_context(enable_ps=True)

assert np.allclose(ps_loss, no_ps_loss, rtol=1.0e-6, atol=1.0e-6)

+ 1
- 0
tests/st/ps/full_ps/test_full_ps_lenet.py View File

@@ -35,6 +35,7 @@ args, _ = parser.parse_known_args()
device_target = args.device_target
dataset_path = args.dataset_path
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
context.set_ps_context(enable_ps=True)

def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
"""weight initial for conv layer"""


+ 9
- 2
tests/st/ps/multi_full_ps/test_multi_full_ps.py View File

@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================

import sys
import argparse
import numpy as np

@@ -22,6 +23,7 @@ from mindspore.common.initializer import TruncatedNormal
from mindspore import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.communication.management import init, get_group_size
from mindspore.parallel._ps_context import _is_role_pserver
# from resnet import resnet50

parser = argparse.ArgumentParser(description="test_ps_lenet")
@@ -29,6 +31,7 @@ parser.add_argument("--device_target", type=str, default="Ascend")
args, _ = parser.parse_known_args()
device_target = args.device_target
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
context.set_ps_context(enable_ps=True)
if device_target == "GPU":
init()

@@ -106,6 +109,10 @@ if __name__ == "__main__":
for _ in range(epoch):
data = Tensor(np.random.rand(32, 3, 32, 32).astype(np.float32))
label = Tensor(np.random.randint(0, 9, (32)).astype(np.int32))
loss = train_network(data, label).asnumpy()
losses.append(loss)
if _is_role_pserver():
train_network(data, label)
sys.exit()
else:
loss = train_network(data, label).asnumpy()
losses.append(loss)
print(losses)

Loading…
Cancel
Save