| @@ -0,0 +1,559 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_PARAMETER_SERVER_H_ | |||
| #define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_PARAMETER_SERVER_H_ | |||
#include <unistd.h>
#include <cmath>
#include <condition_variable>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <mutex>
#include <random>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>
| #include "ir/func_graph.h" | |||
| #include "session/session_basic.h" | |||
| #include "session/kernel_graph.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "session/session_factory.h" | |||
| #include "parallel/ps/common.h" | |||
| #include "parallel/ps/optimizer_info.h" | |||
| #include "parallel/ps/optimizer_info_builder.h" | |||
| #include "parallel/ps/util.h" | |||
| #include "device/cpu/kernel_select_cpu.h" | |||
| #include "utils/context/ms_context.h" | |||
| #include "kernel/kernel.h" | |||
| #include "kernel/ps/pserver_kernel.h" | |||
| #include "kernel/cpu/cpu_kernel_factory.h" | |||
| #include "kernel/ps/sparse_apply_adam_ps_kernel.h" | |||
| #include "kernel/ps/sparse_apply_ftrl_ps_kernel.h" | |||
| #include "kernel/ps/apply_momentum_ps_kernel.h" | |||
| #include "kernel/ps/embedding_look_up_ps_kernel.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| namespace ps { | |||
| using mindspore::kernel::ps::PServerKernel; | |||
| template <typename T> | |||
| class ParameterServer { | |||
| public: | |||
| static ParameterServer &GetInstance() { | |||
| static ParameterServer instance; | |||
| return instance; | |||
| } | |||
| void Run(const FuncGraphPtr &func_graph); | |||
| private: | |||
| ParameterServer() | |||
| : pserver_num_(0), | |||
| worker_num_(0), | |||
| rank_id_(0), | |||
| grad_accum_count_(0), | |||
| ps_(new ::ps::KVServer<T>(0)), | |||
| handler_(nullptr), | |||
| func_graph_(nullptr), | |||
| kernel_graph_(nullptr), | |||
| sess_(nullptr), | |||
| thread_(nullptr) {} | |||
| ~ParameterServer() = default; | |||
| ParameterServer(const ParameterServer &) = delete; | |||
| ParameterServer &operator=(const ParameterServer &) = delete; | |||
| struct ServerHandler { | |||
| explicit ServerHandler(ParameterServer *ps) : ps_(ps) {} | |||
| void operator()(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVServer<T> *server); | |||
| void HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data); | |||
| void HandlePullReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res); | |||
| void HandleInitWeights(const ::ps::KVPairs<T> &req_data); | |||
| void HandleInitWeightToOptimId(const ::ps::KVPairs<T> &req_data); | |||
| void HandleInitInputsShape(const ::ps::KVPairs<T> &req_data); | |||
| void HandleInitEmbeddings(const ::ps::KVPairs<T> &req_data); | |||
| void HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res); | |||
| ParameterServer *ps_; | |||
| }; | |||
| bool Init(const FuncGraphPtr &func_graph); | |||
| void InitOptimInfoBuilders(); | |||
| void InitWeightKeyToOptims(const Key &key, const int &optim_id); | |||
| void InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths); | |||
| void InitWeight(const Key &key, const WeightPtr &weight); | |||
| void InitGrad(const Key &key, const GradPtr &grad); | |||
| void InitEmbeddingTable(const Key &key, | |||
| const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes); | |||
| void UpdateWeights(); | |||
| void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths); | |||
| WeightPtr weight(const Key &key); | |||
| void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs<T> *res); | |||
| int SumOfShapes(const std::vector<int> &shapes) const; | |||
| size_t PreComputeCapacity(const Keys &keys, const Lengths &lens); | |||
| bool ReadyForUpdateWeights(); | |||
| bool ReadyForAccumGrads(); | |||
| void ResetGradAccumCount(); | |||
| size_t pserver_num_; | |||
| size_t worker_num_; | |||
| size_t rank_id_; | |||
| size_t grad_accum_count_; | |||
| std::unique_ptr<::ps::KVServer<T>> ps_; | |||
| std::unique_ptr<ServerHandler> handler_; | |||
| FuncGraphPtr func_graph_; | |||
| std::shared_ptr<session::KernelGraph> kernel_graph_; | |||
| std::shared_ptr<session::SessionBasic> sess_; | |||
| std::unordered_map<std::string, std::shared_ptr<PServerKernel>> optimizers_; | |||
| std::unordered_map<Key, InputsShapePtr> optim_inputs_shape_; | |||
| std::unordered_map<Key, std::shared_ptr<OptimizerInfo>> optim_infos_; | |||
| std::unordered_map<std::string, std::shared_ptr<OptimizerInfoBuilder>> optim_info_builders_; | |||
| std::unordered_map<Key, std::string> weight_key_to_optims_; | |||
| std::unordered_map<Key, WeightPtr> weights_; | |||
| std::unordered_map<Key, WeightPtr> grads_; | |||
| std::unordered_map<Key, size_t> grads_accum_counter_; | |||
| // std::unordered_map<Key, EmbeddingTablePtr> embeddings_; | |||
| std::unordered_map<Key, std::shared_ptr<PServerKernel>> embedding_lookup_ops_; | |||
| std::unordered_map<Key, size_t> embedding_row_lens_; | |||
| T learning_rate_; | |||
| T momentum_; | |||
| std::mutex mutex_; | |||
| std::condition_variable apply_grads_cv_; | |||
| std::condition_variable accum_grads_cv_; | |||
| std::unique_ptr<std::thread> thread_; | |||
| friend struct ServerHandler; | |||
| }; | |||
| class FuncGraph; | |||
| template <typename T> | |||
| void ParameterServer<T>::ServerHandler::operator()(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, | |||
| ::ps::KVServer<T> *server) { | |||
| ::ps::KVPairs<T> res; | |||
| if (req_meta.cmd == kInitWeightsCmd) { | |||
| MS_LOG(ERROR) << "handle init weights cmd" << std::endl; | |||
| HandleInitWeights(req_data); | |||
| } else if (req_meta.cmd == kInitWeightToOptimIdCmd) { | |||
| MS_LOG(ERROR) << "handle init weight optim id mapping cmd" << std::endl; | |||
| HandleInitWeightToOptimId(req_data); | |||
| } else if (req_meta.cmd == kInitOptimInputsShapeCmd) { | |||
| MS_LOG(ERROR) << "handle init inputs shape cmd" << std::endl; | |||
| HandleInitInputsShape(req_data); | |||
| } else if (req_meta.cmd == kInitEmbeddingsCmd) { | |||
| MS_LOG(ERROR) << "handle init embedding cmd" << std::endl; | |||
| HandleInitEmbeddings(req_data); | |||
| } else if (req_meta.cmd == kEmbeddingLookupCmd) { | |||
| MS_LOG(ERROR) << "handle embedding lookup cmd" << std::endl; | |||
| HandleEmbeddingLookup(req_meta, req_data, &res); | |||
| } else if (req_meta.push) { | |||
| MS_LOG(ERROR) << "handle push req cmd" << std::endl; | |||
| HandlePushReq(req_meta, req_data); | |||
| } else { | |||
| MS_LOG(ERROR) << "handle pull req cmd" << std::endl; | |||
| HandlePullReq(req_meta, req_data, &res); | |||
| } | |||
| server->Response(req_meta, res); | |||
| } | |||
| template <typename T> | |||
| void ParameterServer<T>::ServerHandler::HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data) { | |||
| ps_->AccumGrad(req_data.keys, req_data.vals, req_data.lens); | |||
| } | |||
| template <typename T> | |||
| void ParameterServer<T>::ServerHandler::HandlePullReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, | |||
| ::ps::KVPairs<T> *res) { | |||
| res->keys = req_data.keys; | |||
| ::ps::Key key = req_data.keys[0]; | |||
| res->vals = *(ps_->weight(key)); | |||
| } | |||
| template <typename T> | |||
| void ParameterServer<T>::ServerHandler::HandleInitWeights(const ::ps::KVPairs<T> &req_data) { | |||
| size_t key_num = req_data.keys.size(); | |||
| T *data_ptr = req_data.vals.data(); | |||
| size_t pos = 0; | |||
| for (size_t i = 0; i < key_num; i++) { | |||
| Key key = req_data.keys[i]; | |||
| size_t data_len = req_data.lens.size() != key_num ? req_data.vals.size() / key_num : req_data.lens[i]; | |||
| WeightPtr weight_ptr = std::make_shared<::ps::SArray<T>>(); | |||
| weight_ptr->CopyFrom(data_ptr + pos, data_len); | |||
| ps_->InitWeight(key, weight_ptr); | |||
| GradPtr grad_ptr = std::make_shared<::ps::SArray<T>>(data_len, 0); | |||
| ps_->InitGrad(key, grad_ptr); | |||
| pos += data_len; | |||
| } | |||
| } | |||
| template <typename T> | |||
| void ParameterServer<T>::ServerHandler::HandleInitWeightToOptimId(const ::ps::KVPairs<T> &req_data) { | |||
| size_t key_num = req_data.keys.size(); | |||
| for (size_t i = 0; i < key_num; i++) { | |||
| Key key = req_data.keys[i]; | |||
| T val = req_data.vals[i]; | |||
| ps_->InitWeightKeyToOptims(key, val); | |||
| } | |||
| } | |||
| template <typename T> | |||
| void ParameterServer<T>::ServerHandler::HandleInitInputsShape(const ::ps::KVPairs<T> &req_data) { | |||
| ps_->InitOptimInputsShape(req_data.keys, req_data.vals, req_data.lens); | |||
| } | |||
| template <typename T> | |||
| void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVPairs<T> &req_data) { | |||
| std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> shapes = | |||
| std::make_shared<std::vector<std::shared_ptr<std::vector<size_t>>>>(); | |||
| std::shared_ptr<std::vector<size_t>> input_shape = std::make_shared<std::vector<size_t>>(); | |||
| std::shared_ptr<std::vector<size_t>> indices_shape = std::make_shared<std::vector<size_t>>(); | |||
| std::shared_ptr<std::vector<size_t>> output_shape = std::make_shared<std::vector<size_t>>(); | |||
| shapes->push_back(input_shape); | |||
| shapes->push_back(indices_shape); | |||
| shapes->push_back(output_shape); | |||
| const Key &key = req_data.keys[0]; | |||
| const Lengths &lens = req_data.lens; | |||
| size_t index = 0; | |||
| for (int i = 0; i < lens[0]; i++) { | |||
| input_shape->push_back(static_cast<size_t>(req_data.vals[index++])); | |||
| } | |||
| for (int j = 0; j < lens[1]; j++) { | |||
| indices_shape->push_back(static_cast<size_t>(req_data.vals[index++])); | |||
| } | |||
| for (int k = 0; k < lens[2]; k++) { | |||
| output_shape->push_back(static_cast<size_t>(req_data.vals[index++])); | |||
| } | |||
| ps_->InitEmbeddingTable(key, shapes); | |||
| } | |||
| template <typename T> | |||
| void ParameterServer<T>::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, | |||
| const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) { | |||
| const Key &key = req_data.keys[0]; | |||
| ps_->DoEmbeddingLookup(key, req_data.vals, res); | |||
| for (size_t i = 0; i < req_data.vals.size(); i++) { | |||
| res->keys->push_back(req_data.vals[i]); | |||
| } | |||
| } | |||
| template <typename T> | |||
| bool ParameterServer<T>::Init(const FuncGraphPtr &func_graph) { | |||
| const char *server_num = getenv(kEnvPServerNum); | |||
| const char *worker_num = getenv(kEnvWorkerNum); | |||
| if (server_num != nullptr) { | |||
| pserver_num_ = *server_num - '0'; | |||
| } | |||
| if (worker_num != nullptr) { | |||
| worker_num_ = *worker_num - '0'; | |||
| } | |||
| func_graph_ = func_graph; | |||
| rank_id_ = ::ps::MyRank(); | |||
| handler_.reset(new ServerHandler(this)); | |||
| InitOptimInfoBuilders(); | |||
| ps_->set_request_handle(*handler_); | |||
| thread_.reset(new std::thread(&ParameterServer::UpdateWeights, this)); | |||
| return true; | |||
| } | |||
| template <typename T> | |||
| void ParameterServer<T>::InitOptimInfoBuilders() { | |||
| std::shared_ptr<OptimizerInfoBuilder> momentum_info_builder = std::make_shared<MomentumOptimInfoBuilder>(); | |||
| std::shared_ptr<OptimizerInfoBuilder> sparse_adam_info_builder = std::make_shared<SparseAdamOptimInfoBuilder>(); | |||
| std::shared_ptr<OptimizerInfoBuilder> sparse_ftrl_info_builder = std::make_shared<SparseFtrlOptimInfoBuilder>(); | |||
| optim_info_builders_[kApplyMomentum] = momentum_info_builder; | |||
| optim_info_builders_[kSparseAdam] = sparse_adam_info_builder; | |||
| optim_info_builders_[kSparseFtrl] = sparse_ftrl_info_builder; | |||
| } | |||
| template <typename T> | |||
| void ParameterServer<T>::InitWeightKeyToOptims(const Key &key, const int &optim_id) { | |||
| if (weight_key_to_optims_.count(key) > 0 || Util::optimizer_name(key) == "") { | |||
| return; | |||
| } | |||
| weight_key_to_optims_[key] = Util::optimizer_name(optim_id); | |||
| } | |||
| template <typename T> | |||
| void ParameterServer<T>::InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths) { | |||
| InputsShapePtr inputs_shape = std::make_shared<InputsShape>(); | |||
| int val_idx = 0; | |||
| const Key &key = keys[0]; | |||
| if (optim_inputs_shape_.count(key) == 0) { | |||
| optim_inputs_shape_[key] = inputs_shape; | |||
| } | |||
| for (size_t i = 0; i < keys.size(); i++) { | |||
| auto shape = std::make_shared<std::vector<size_t>>(); | |||
| inputs_shape->push_back(shape); | |||
| int len = lengths[i]; | |||
| for (int j = 0; j < len; j++) { | |||
| shape->push_back(values[val_idx++]); | |||
| } | |||
| } | |||
| if (weight_key_to_optims_.count(key) > 0) { | |||
| const std::string &optim_name = weight_key_to_optims_[key]; | |||
| if (optimizers_.count(optim_name) == 0 && optim_inputs_shape_.count(key) > 0) { | |||
| if (optim_name == kSparseAdam) { | |||
| std::shared_ptr<PServerKernel> optimizer = | |||
| std::make_shared<kernel::ps::SparseApplyAdamPSKernel>(rank_id_, pserver_num_); | |||
| optimizer->InitKernel(optim_inputs_shape_[key]); | |||
| optimizers_[optim_name] = optimizer; | |||
| } else if (optim_name == kApplyMomentum) { | |||
| std::shared_ptr<PServerKernel> optimizer = | |||
| std::make_shared<kernel::ps::ApplyMomentumPSKernel>(rank_id_, pserver_num_); | |||
| optimizer->InitKernel(optim_inputs_shape_[key]); | |||
| optimizers_[optim_name] = optimizer; | |||
| } else if (optim_name == kSparseFtrl) { | |||
| std::shared_ptr<PServerKernel> optimizer = | |||
| std::make_shared<kernel::ps::SparseApplyFtrlPSKernel>(rank_id_, pserver_num_); | |||
| optimizer->InitKernel(optim_inputs_shape_[key]); | |||
| optimizers_[optim_name] = optimizer; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| template <typename T> | |||
| void ParameterServer<T>::InitWeight(const Key &key, const WeightPtr &weight) { | |||
| if (weights_.count(key) == 0) { | |||
| weights_[key] = weight; | |||
| } | |||
| } | |||
| template <typename T> | |||
| void ParameterServer<T>::InitGrad(const Key &key, const GradPtr &grad) { | |||
| if (grads_.count(key) == 0) { | |||
| grads_[key] = grad; | |||
| grads_accum_counter_[key] = 0; | |||
| } | |||
| } | |||
| template <typename T> | |||
| void ParameterServer<T>::InitEmbeddingTable( | |||
| const Key &key, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) { | |||
| // Init embedding lookup kernel | |||
| std::shared_ptr<PServerKernel> lookup = std::make_shared<kernel::ps::EmbeddingLookUpPSKernel>(rank_id_, pserver_num_); | |||
| lookup->InitKernel(shapes); | |||
| embedding_lookup_ops_[key] = lookup; | |||
| // Init embedding weight | |||
| const std::vector<size_t> &input_shapes = lookup->input_sizes(); | |||
| size_t total_dims = 1; | |||
| for (auto shape : input_shapes) { | |||
| total_dims *= shape; | |||
| } | |||
| WeightPtr embedding = std::make_shared<Weight>(total_dims, 0.01); | |||
| weights_[key] = embedding; | |||
| grads_accum_counter_[key] = 0; | |||
| } | |||
| template <typename T> | |||
| void ParameterServer<T>::UpdateWeights() { | |||
| while (true) { | |||
| std::unique_lock<std::mutex> lock(mutex_); | |||
| apply_grads_cv_.wait(lock, [this] { return this->ReadyForUpdateWeights(); }); | |||
| for (auto iter = weights_.begin(); iter != weights_.end(); iter++) { | |||
| Key key = iter->first; | |||
| WeightPtr weight_ptr = iter->second; | |||
| std::shared_ptr<PServerKernel> optimizer = nullptr; | |||
| if (weight_key_to_optims_.count(key) > 0) { | |||
| const std::string &optim_name = weight_key_to_optims_[key]; | |||
| optimizer = optimizers_[optim_name]; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(optimizer); | |||
| std::shared_ptr<OptimizerInfo> optim_info = optim_infos_[key]; | |||
| if (optim_info == nullptr) { | |||
| continue; | |||
| } | |||
| const WeightPtr &weight = weights_[key]; | |||
| optim_info->UpdateWeight(weight); | |||
| const std::vector<kernel::AddressPtr> &inputs = optim_info->inputs(); | |||
| const std::vector<kernel::AddressPtr> &workspaces = optim_info->workspaces(); | |||
| const std::vector<kernel::AddressPtr> &outputs = optim_info->outputs(); | |||
| optimizer->Execute(inputs, workspaces, outputs); | |||
| optim_info->Reset(); | |||
| } | |||
| ResetGradAccumCount(); | |||
| accum_grads_cv_.notify_all(); | |||
| } | |||
| } | |||
| template <typename T> | |||
| void ParameterServer<T>::AccumGrad(const Keys &keys, const Values &values, const Lengths &lengths) { | |||
| std::unique_lock<std::mutex> lock(mutex_); | |||
| accum_grads_cv_.wait(lock, [this] { return this->ReadyForAccumGrads(); }); | |||
| const Key &key = keys[0]; | |||
| std::shared_ptr<OptimizerInfo> optim_info = optim_infos_[key]; | |||
| // Create or update the optimizer info | |||
| if (optim_info == nullptr) { | |||
| const std::shared_ptr<OptimizerInfoBuilder> &builder = optim_info_builders_[weight_key_to_optims_[key]]; | |||
| std::shared_ptr<kernel::ps::PServerKernel> pserver_kernel = optimizers_[weight_key_to_optims_[key]]; | |||
| if (pserver_kernel == nullptr) { | |||
| MS_LOG(EXCEPTION) << "no optimizer found for key " << key << " optim name " << weight_key_to_optims_[key]; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(pserver_kernel); | |||
| OptimizerInfo *optim = | |||
| builder->Build(pserver_kernel, weights_[key], keys, values, lengths, optim_inputs_shape_[key], worker_num_); | |||
| optim_info.reset(optim); | |||
| optim_infos_[key] = optim_info; | |||
| } else { | |||
| optim_info->Update(values, lengths); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(optim_info); | |||
| optim_info->Accumulate(values, lengths); | |||
| grads_accum_counter_[key] += 1; | |||
| if (grads_accum_counter_[key] == worker_num_) { | |||
| grad_accum_count_++; | |||
| } | |||
| if (ReadyForUpdateWeights()) { | |||
| apply_grads_cv_.notify_one(); | |||
| } | |||
| } | |||
| template <typename T> | |||
| WeightPtr ParameterServer<T>::weight(const Key &key) { | |||
| std::unique_lock<std::mutex> lock(mutex_); | |||
| if (weights_.count(key) == 0) { | |||
| MS_LOG(ERROR) << "Invalid weight key " << key; | |||
| return nullptr; | |||
| } | |||
| WeightPtr weight_ptr = weights_[key]; | |||
| WeightPtr copy_weight_ptr = std::make_shared<::ps::SArray<T>>(weight_ptr->size(), 0); | |||
| copy_weight_ptr->CopyFrom(weight_ptr->data(), weight_ptr->size()); | |||
| return copy_weight_ptr; | |||
| } | |||
| template <typename T> | |||
| void ParameterServer<T>::DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs<T> *res) { | |||
| std::unique_lock<std::mutex> lock(mutex_); | |||
| if (weights_.count(key) == 0) { | |||
| MS_LOG(ERROR) << "Invalid embedding table key " << key; | |||
| return; | |||
| } | |||
| if (embedding_lookup_ops_.count(key) == 0) { | |||
| MS_LOG(ERROR) << "Invalid embedding lookup op key " << key; | |||
| return; | |||
| } | |||
| WeightPtr table_ptr = weights_[key]; | |||
| std::shared_ptr<PServerKernel> table_lookup_op = embedding_lookup_ops_[key]; | |||
| // Update shapes of lookup operator | |||
| std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> shapes = | |||
| std::make_shared<std::vector<std::shared_ptr<std::vector<size_t>>>>(); | |||
| std::shared_ptr<std::vector<size_t>> indices_shape = std::make_shared<std::vector<size_t>>(); | |||
| indices_shape->emplace_back(lookup_ids.size()); | |||
| shapes->push_back(indices_shape); | |||
| table_lookup_op->ReInit(shapes); | |||
| const std::vector<size_t> output_shapes = table_lookup_op->output_sizes(); | |||
| std::vector<kernel::AddressPtr> inputs; | |||
| AddressPtr embedding_table = std::make_shared<kernel::Address>(); | |||
| AddressPtr indices = std::make_shared<kernel::Address>(); | |||
| inputs.push_back(embedding_table); | |||
| inputs.push_back(indices); | |||
| embedding_table->addr = table_ptr->data(); | |||
| embedding_table->size = table_ptr->size() * sizeof(T); | |||
| indices->addr = lookup_ids.data(); | |||
| indices->size = lookup_ids.size() * sizeof(T); | |||
| std::vector<kernel::AddressPtr> workspaces; | |||
| std::vector<kernel::AddressPtr> outputs; | |||
| AddressPtr output = std::make_shared<kernel::Address>(); | |||
| std::shared_ptr<Values> addr = std::make_shared<Values>(output_shapes[0] / sizeof(T), 0); | |||
| output->addr = addr->data(); | |||
| output->size = output_shapes[0]; | |||
| outputs.push_back(output); | |||
| table_lookup_op->Execute(inputs, workspaces, outputs); | |||
| res->vals = *addr; | |||
| res->lens.push_back(res.vals.size()); | |||
| } | |||
| template <typename T> | |||
| int ParameterServer<T>::SumOfShapes(const std::vector<int> &shapes) const { | |||
| int sum = 1; | |||
| for (auto shape : shapes) { | |||
| sum *= shape; | |||
| } | |||
| return sum; | |||
| } | |||
| template <typename T> | |||
| size_t ParameterServer<T>::PreComputeCapacity(const Keys &keys, const Lengths &lens) { | |||
| size_t capacity = 0; | |||
| for (size_t i = 0; i < keys.size(); i++) { | |||
| Key key = keys[i]; | |||
| if (embedding_row_lens_.count(key) > 0) { | |||
| capacity += embedding_row_lens_[key] * lens[i]; | |||
| } else { | |||
| MS_LOG(ERROR) << "Invalid embedding lookup id " << key; | |||
| } | |||
| } | |||
| return capacity; | |||
| } | |||
| template <typename T> | |||
| inline bool ParameterServer<T>::ReadyForUpdateWeights() { | |||
| return grads_accum_counter_.size() > 0 && grad_accum_count_ == grads_accum_counter_.size(); | |||
| } | |||
| template <typename T> | |||
| inline bool ParameterServer<T>::ReadyForAccumGrads() { | |||
| return grad_accum_count_ < weights_.size(); | |||
| } | |||
| template <typename T> | |||
| inline void ParameterServer<T>::ResetGradAccumCount() { | |||
| grad_accum_count_ = 0; | |||
| for (auto iter = grads_accum_counter_.begin(); iter != grads_accum_counter_.end(); iter++) { | |||
| grads_accum_counter_[iter->first] = 0; | |||
| } | |||
| } | |||
| template <typename T> | |||
| void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) { | |||
| ::ps::Start(0); | |||
| if (!::ps::IsServer()) { | |||
| std::cout << "This is not ther Server" << std::endl; | |||
| return; | |||
| } | |||
| Init(func_graph); | |||
| thread_->join(); | |||
| } | |||
| } // namespace ps | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_PARAMETER_SERVER_H_ | |||