|
|
|
@@ -47,7 +47,8 @@ class Worker { |
|
|
|
void SetOptimInputShapes(size_t key, const std::vector<int> &shape); |
|
|
|
void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count); |
|
|
|
void InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<size_t> shapes, const std::vector<int> &sizes); |
|
|
|
void InitPSParamAndOptim(const std::string ¶m_name, void *param_data, size_t param_size); |
|
|
|
void InitPSParamAndOptim(const std::string ¶m_name, void *param_data, size_t param_size, |
|
|
|
bool init_in_server = false); |
|
|
|
void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids, |
|
|
|
const ::ps::SArray<int> &lens, ::ps::SArray<T> *lookup_result, int cmd); |
|
|
|
void Finalize(); |
|
|
|
@@ -240,7 +241,8 @@ void Worker<T>::InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vecto |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
// Initialize parameters and optimizer kernels of Parameter Server. |
|
|
|
void Worker<T>::InitPSParamAndOptim(const std::string ¶m_name, void *param_data, size_t param_size) { |
|
|
|
void Worker<T>::InitPSParamAndOptim(const std::string ¶m_name, void *param_data, size_t param_size, |
|
|
|
bool init_in_server) { |
|
|
|
size_t param_key = GetParamKey(param_name); |
|
|
|
if (param_key == kInvalidKey) { |
|
|
|
MS_LOG(INFO) << "Parameter " << param_name << " has no key assigned."; |
|
|
|
@@ -248,9 +250,9 @@ void Worker<T>::InitPSParamAndOptim(const std::string ¶m_name, void *param_d |
|
|
|
} |
|
|
|
bool init = IsKeyInit(param_key); |
|
|
|
if (!init) { |
|
|
|
MS_LOG(INFO) << "Init paramter and optimizer in parameter server side for " << param_name; |
|
|
|
// No need to push embedding table data to Parameter Server. |
|
|
|
if (param_name.find("embedding_table") == std::string::npos && param_name.find("wide_w") == std::string::npos) { |
|
|
|
MS_LOG(INFO) << "Init paramter and optimizer in parameter server side for " << param_name |
|
|
|
<< ", whether init in server: " << init_in_server; |
|
|
|
if (!init_in_server) { |
|
|
|
InitPSParamData({param_key}, param_data, param_size); |
|
|
|
} |
|
|
|
InitPSOptimId(param_key); |
|
|
|
|