From: @zyli2020 Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -1085,10 +1085,10 @@ void ClearResAtexit() { | |||||
| session::ClearPythonParasMap(); | session::ClearPythonParasMap(); | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | ||||
| if (ps::Util::IsParamServerMode() && ps::Util::IsRoleOfWorker()) { | if (ps::Util::IsParamServerMode() && ps::Util::IsRoleOfWorker()) { | ||||
| ps::worker.Finalize(); | |||||
| if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | ||||
| ps::ps_cache_instance.Finalize(); | ps::ps_cache_instance.Finalize(); | ||||
| } | } | ||||
| ps::worker.Finalize(); | |||||
| } | } | ||||
| #endif | #endif | ||||
| ad::g_k_prims.clear(); | ad::g_k_prims.clear(); | ||||
| @@ -552,7 +552,6 @@ template <typename T> | |||||
| void ParameterServer<T>::Finalize() { | void ParameterServer<T>::Finalize() { | ||||
| running_ = false; | running_ = false; | ||||
| apply_grads_cv_.notify_one(); | apply_grads_cv_.notify_one(); | ||||
| SyncEmbeddingTables(); | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| @@ -774,7 +773,7 @@ void ParameterServer<T>::GetEmbeddingTableParamPtr() { | |||||
| for (auto cnode : cnodes) { | for (auto cnode : cnodes) { | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| std::string cnode_name = AnfAlgo::GetCNodeName(cnode); | std::string cnode_name = AnfAlgo::GetCNodeName(cnode); | ||||
| if (cnode_name == kEmbeddingLookupOpName) { | |||||
| if (cnode_name == kEmbeddingLookupOpName || cnode_name == kGatherV2OpName) { | |||||
| auto embedding_table = AnfAlgo::GetInputNode(cnode, 0); | auto embedding_table = AnfAlgo::GetInputNode(cnode, 0); | ||||
| MS_EXCEPTION_IF_NULL(embedding_table); | MS_EXCEPTION_IF_NULL(embedding_table); | ||||
| MS_LOG(INFO) << "Embedding table name is " << embedding_table->fullname_with_scope() << ", key is " << count; | MS_LOG(INFO) << "Embedding table name is " << embedding_table->fullname_with_scope() << ", key is " << count; | ||||
| @@ -832,6 +831,7 @@ void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) { | |||||
| Init(func_graph); | Init(func_graph); | ||||
| PSContext::instance()->SetPSRankId(rank_id_); | PSContext::instance()->SetPSRankId(rank_id_); | ||||
| thread_->join(); | thread_->join(); | ||||
| SyncEmbeddingTables(); | |||||
| MS_LOG(INFO) << "PServer finished updating models, starts finalizing..."; | MS_LOG(INFO) << "PServer finished updating models, starts finalizing..."; | ||||
| ::ps::Finalize(0, true); | ::ps::Finalize(0, true); | ||||
| MS_LOG(INFO) << "PServer finalized successfully."; | MS_LOG(INFO) << "PServer finalized successfully."; | ||||
| @@ -30,21 +30,21 @@ int EmbeddingHashMap::ParseData(const int id, int *swap_out_index, int *swap_out | |||||
| if (loop++ == hash_capacity_) { | if (loop++ == hash_capacity_) { | ||||
| return INVALID_INDEX_VALUE; | return INVALID_INDEX_VALUE; | ||||
| } | } | ||||
| if (hash_map_unit_[hash_index].IsEmpty()) { | |||||
| if (hash_map_elements_[hash_index].IsEmpty()) { | |||||
| hash_count_++; | hash_count_++; | ||||
| (void)hash_id_to_index_.emplace(id, hash_index); | (void)hash_id_to_index_.emplace(id, hash_index); | ||||
| hash_map_unit_[hash_index].set_id(id); | |||||
| hash_map_unit_[hash_index].set_step(data_step); | |||||
| hash_map_elements_[hash_index].set_id(id); | |||||
| hash_map_elements_[hash_index].set_step(data_step); | |||||
| return hash_index; | return hash_index; | ||||
| } else if (need_swap && hash_map_unit_[hash_index].IsExpired(graph_running_step)) { | |||||
| } else if (need_swap && hash_map_elements_[hash_index].IsExpired(graph_running_step)) { | |||||
| // Need swap out from the hash table. | // Need swap out from the hash table. | ||||
| swap_out_index[*swap_out_size] = hash_index; | swap_out_index[*swap_out_size] = hash_index; | ||||
| swap_out_ids[*swap_out_size] = hash_map_unit_[hash_index].id_; | |||||
| swap_out_ids[*swap_out_size] = hash_map_elements_[hash_index].id_; | |||||
| (*swap_out_size)++; | (*swap_out_size)++; | ||||
| (void)hash_id_to_index_.erase(hash_map_unit_[hash_index].id_); | |||||
| (void)hash_id_to_index_.erase(hash_map_elements_[hash_index].id_); | |||||
| (void)hash_id_to_index_.emplace(id, hash_index); | (void)hash_id_to_index_.emplace(id, hash_index); | ||||
| hash_map_unit_[hash_index].set_id(id); | |||||
| hash_map_unit_[hash_index].set_step(data_step); | |||||
| hash_map_elements_[hash_index].set_id(id); | |||||
| hash_map_elements_[hash_index].set_step(data_step); | |||||
| return hash_index; | return hash_index; | ||||
| } | } | ||||
| hash_index = (hash_index + 1) % hash_capacity_; | hash_index = (hash_index + 1) % hash_capacity_; | ||||
| @@ -58,9 +58,10 @@ void EmbeddingHashMap::DumpHashMap() { | |||||
| MS_LOG(INFO) << " id: " << iter->first << " index: " << iter->second; | MS_LOG(INFO) << " id: " << iter->first << " index: " << iter->second; | ||||
| } | } | ||||
| MS_LOG(INFO) << "Dump hash_map_unit: "; | MS_LOG(INFO) << "Dump hash_map_unit: "; | ||||
| for (size_t i = 0; i < hash_map_unit_.size(); i++) { | |||||
| if (!hash_map_unit_[i].IsEmpty()) { | |||||
| MS_LOG(INFO) << " index: " << i << " id: " << hash_map_unit_[i].id_ << " step: " << hash_map_unit_[i].step_; | |||||
| for (size_t i = 0; i < hash_map_elements_.size(); i++) { | |||||
| if (!hash_map_elements_[i].IsEmpty()) { | |||||
| MS_LOG(INFO) << " index: " << i << " id: " << hash_map_elements_[i].id_ | |||||
| << " step: " << hash_map_elements_[i].step_; | |||||
| } | } | ||||
| } | } | ||||
| MS_LOG(INFO) << "Dump hash map info end."; | MS_LOG(INFO) << "Dump hash map info end."; | ||||
| @@ -30,8 +30,8 @@ static const size_t INVALID_STEP_VALUE = 0; | |||||
| static const int INVALID_INDEX_VALUE = -1; | static const int INVALID_INDEX_VALUE = -1; | ||||
| struct HashMapElement { | struct HashMapElement { | ||||
| int id_; | |||||
| size_t step_; | |||||
| int id_{INVALID_INDEX_VALUE}; | |||||
| size_t step_{INVALID_STEP_VALUE}; | |||||
| bool IsEmpty() const { return step_ == INVALID_STEP_VALUE; } | bool IsEmpty() const { return step_ == INVALID_STEP_VALUE; } | ||||
| bool IsExpired(size_t graph_running_step) const { return graph_running_step > step_; } | bool IsExpired(size_t graph_running_step) const { return graph_running_step > step_; } | ||||
| void set_id(int id) { id_ = id; } | void set_id(int id) { id_ = id; } | ||||
| @@ -42,7 +42,7 @@ struct HashMapElement { | |||||
| class EmbeddingHashMap { | class EmbeddingHashMap { | ||||
| public: | public: | ||||
| EmbeddingHashMap(size_t hash_count, size_t hash_capacity) : hash_count_(hash_count), hash_capacity_(hash_capacity) { | EmbeddingHashMap(size_t hash_count, size_t hash_capacity) : hash_count_(hash_count), hash_capacity_(hash_capacity) { | ||||
| hash_map_unit_.resize(hash_capacity); | |||||
| hash_map_elements_.resize(hash_capacity); | |||||
| } | } | ||||
| virtual ~EmbeddingHashMap() = default; | virtual ~EmbeddingHashMap() = default; | ||||
| int ParseData(const int id, int *swap_out_index, int *swap_out_ids, const size_t data_step, | int ParseData(const int id, int *swap_out_index, int *swap_out_ids, const size_t data_step, | ||||
| @@ -51,8 +51,10 @@ class EmbeddingHashMap { | |||||
| bool IsIdExist(const std::unordered_map<int, int>::const_iterator iter) const { | bool IsIdExist(const std::unordered_map<int, int>::const_iterator iter) const { | ||||
| return iter != hash_id_to_index_.end(); | return iter != hash_id_to_index_.end(); | ||||
| } | } | ||||
| size_t hash_step(const int hash_index) const { return hash_map_unit_[hash_index].step_; } | |||||
| void set_hash_step(const int hash_index, const size_t step) { hash_map_unit_[hash_index].set_step(step); } | |||||
| size_t hash_step(const int hash_index) const { return hash_map_elements_[hash_index].step_; } | |||||
| void set_hash_step(const int hash_index, const size_t step) { hash_map_elements_[hash_index].set_step(step); } | |||||
| const std::unordered_map<int, int> &hash_id_to_index() const { return hash_id_to_index_; } | |||||
| size_t hash_capacity() const { return hash_capacity_; } | |||||
| void DumpHashMap(); | void DumpHashMap(); | ||||
| private: | private: | ||||
| @@ -60,7 +62,7 @@ class EmbeddingHashMap { | |||||
| bool NeedSwap() const { return hash_count_ > FloatToSize(hash_capacity_ * 0.9); } | bool NeedSwap() const { return hash_count_ > FloatToSize(hash_capacity_ * 0.9); } | ||||
| size_t hash_count_; | size_t hash_count_; | ||||
| size_t hash_capacity_; | size_t hash_capacity_; | ||||
| std::vector<HashMapElement> hash_map_unit_; | |||||
| std::vector<HashMapElement> hash_map_elements_; | |||||
| std::unordered_map<int, int> hash_id_to_index_; | std::unordered_map<int, int> hash_id_to_index_; | ||||
| }; | }; | ||||
| } // namespace ps | } // namespace ps | ||||
| @@ -226,9 +226,9 @@ void PsCacheManager::AllocMemForHashTable() { | |||||
| device_address.addr = addr; | device_address.addr = addr; | ||||
| auto &host_address = item.second.host_address; | auto &host_address = item.second.host_address; | ||||
| auto host_address_ptr = new int[host_cache_vocab_size_ * embedding_size]; | |||||
| auto host_address_ptr = new float[host_cache_vocab_size_ * embedding_size]; | |||||
| MS_EXCEPTION_IF_NULL(host_address_ptr); | MS_EXCEPTION_IF_NULL(host_address_ptr); | ||||
| host_address = std::shared_ptr<int[]>(host_address_ptr, std::default_delete<int[]>()); | |||||
| host_address = std::shared_ptr<float[]>(host_address_ptr, std::default_delete<float[]>()); | |||||
| MS_EXCEPTION_IF_NULL(host_address); | MS_EXCEPTION_IF_NULL(host_address); | ||||
| max_embedding_size = (embedding_size > max_embedding_size) ? embedding_size : max_embedding_size; | max_embedding_size = (embedding_size > max_embedding_size) ? embedding_size : max_embedding_size; | ||||
| @@ -330,6 +330,14 @@ void PsCacheManager::ProcessDataTask(uint32_t device_id, void *context) { | |||||
| } | } | ||||
| void PsCacheManager::Finalize() { | void PsCacheManager::Finalize() { | ||||
| if (running_) { | |||||
| if (!SyncHostEmbeddingTable()) { | |||||
| MS_LOG(ERROR) << "SyncHostEmbeddingTable failed."; | |||||
| } | |||||
| if (!SyncDeviceEmbeddingTable()) { | |||||
| MS_LOG(ERROR) << "SyncDeviceEmbeddingTable failed."; | |||||
| } | |||||
| } | |||||
| running_ = false; | running_ = false; | ||||
| PsDataPrefetch::GetInstance().NotifyFinalize(); | PsDataPrefetch::GetInstance().NotifyFinalize(); | ||||
| insert_init_info_.notify_all(); | insert_init_info_.notify_all(); | ||||
| @@ -838,6 +846,99 @@ bool PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_da | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool PsCacheManager::SyncHostEmbeddingTable() { | |||||
| MS_ERROR_IF_NULL(embedding_host_cache_); | |||||
| const auto &hash_id_to_index = embedding_host_cache_->host_hash_map_->hash_id_to_index(); | |||||
| size_t swap_indices_lens = hash_id_to_index.size(); | |||||
| if (swap_indices_lens == 0) { | |||||
| return true; | |||||
| } | |||||
| std::unique_ptr<int[]> host_to_server_ids_ptr = std::make_unique<int[]>(swap_indices_lens); | |||||
| MS_ERROR_IF_NULL(host_to_server_ids_ptr); | |||||
| std::unique_ptr<int[]> host_to_server_indices_ptr = std::make_unique<int[]>(swap_indices_lens); | |||||
| MS_ERROR_IF_NULL(host_to_server_indices_ptr); | |||||
| size_t idx = 0; | |||||
| for (const auto &item : hash_id_to_index) { | |||||
| host_to_server_ids_ptr[idx] = item.first; | |||||
| host_to_server_indices_ptr[idx++] = item.second; | |||||
| } | |||||
| for (const auto &item : hash_tables_) { | |||||
| const auto &hash_info = item.second; | |||||
| if (hash_info.param_init_info_.param_type_ != kWeight) { | |||||
| continue; | |||||
| } | |||||
| auto key = worker.GetParamKey(item.first); | |||||
| ::ps::SArray<int> lookup_ids(swap_indices_lens, 0); | |||||
| ::ps::SArray<float> swap_out_data; | |||||
| auto embedding_size = hash_info.embedding_size; | |||||
| swap_out_data.resize(swap_indices_lens * embedding_size); | |||||
| auto host_hash_table_addr = hash_info.host_address.get(); | |||||
| MS_ERROR_IF_NULL(host_hash_table_addr); | |||||
| RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_lens, host_hash_table_addr, | |||||
| host_to_server_indices_ptr.get(), swap_out_data.data())); | |||||
| auto copy_len = swap_indices_lens * sizeof(int); | |||||
| auto ret = memcpy_s(lookup_ids.data(), copy_len, host_to_server_ids_ptr.get(), copy_len); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(ERROR) << "Lookup id memcpy failed."; | |||||
| return false; | |||||
| } | |||||
| worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool PsCacheManager::SyncDeviceEmbeddingTable() { | |||||
| MS_ERROR_IF_NULL(embedding_device_cache_); | |||||
| const auto &device_hash_map = embedding_device_cache_->device_hash_map_; | |||||
| const auto &hash_id_to_index = device_hash_map->hash_id_to_index(); | |||||
| size_t swap_indices_lens = hash_id_to_index.size(); | |||||
| if (swap_indices_lens == 0) { | |||||
| return true; | |||||
| } | |||||
| std::unique_ptr<int[]> device_to_server_ids_ptr = std::make_unique<int[]>(swap_indices_lens); | |||||
| MS_ERROR_IF_NULL(device_to_server_ids_ptr); | |||||
| std::unique_ptr<int[]> device_to_server_indices_ptr = std::make_unique<int[]>(swap_indices_lens); | |||||
| MS_ERROR_IF_NULL(device_to_server_indices_ptr); | |||||
| size_t idx = 0; | |||||
| for (const auto &item : hash_id_to_index) { | |||||
| device_to_server_ids_ptr[idx] = item.first; | |||||
| device_to_server_indices_ptr[idx++] = item.second; | |||||
| } | |||||
| for (const auto &item : hash_tables_) { | |||||
| const auto &hash_info = item.second; | |||||
| if (hash_info.param_init_info_.param_type_ != kWeight) { | |||||
| continue; | |||||
| } | |||||
| auto key = worker.GetParamKey(item.first); | |||||
| ::ps::SArray<int> lookup_ids(swap_indices_lens, 0); | |||||
| ::ps::SArray<float> swap_out_data; | |||||
| auto embedding_size = hash_info.embedding_size; | |||||
| swap_out_data.resize(swap_indices_lens * embedding_size); | |||||
| std::unique_ptr<float[]> device_hash_table_addr_tmp = | |||||
| std::make_unique<float[]>(device_hash_map->hash_capacity() * embedding_size); | |||||
| MS_ERROR_IF_NULL(device_hash_table_addr_tmp); | |||||
| auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr); | |||||
| MS_ERROR_IF_NULL(hash_table_addr); | |||||
| auto hash_table_size = hash_info.device_address.size; | |||||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyDeviceMemToHost(device_hash_table_addr_tmp.get(), | |||||
| hash_table_addr, hash_table_size)); | |||||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeStream()); | |||||
| RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_lens, device_hash_table_addr_tmp.get(), | |||||
| device_to_server_indices_ptr.get(), swap_out_data.data())); | |||||
| auto copy_len = swap_indices_lens * sizeof(int); | |||||
| auto ret = memcpy_s(lookup_ids.data(), copy_len, device_to_server_ids_ptr.get(), copy_len); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(ERROR) << "Lookup id memcpy failed."; | |||||
| return false; | |||||
| } | |||||
| worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); | |||||
| } | |||||
| return true; | |||||
| } | |||||
| void PsCacheManager::DumpHashTables(bool dump_device_tables) const { | void PsCacheManager::DumpHashTables(bool dump_device_tables) const { | ||||
| for (const auto &item : hash_tables_) { | for (const auto &item : hash_tables_) { | ||||
| const auto ¶m_name = item.first; | const auto ¶m_name = item.first; | ||||
| @@ -48,7 +48,7 @@ struct HashTableInfo { | |||||
| size_t embedding_size{0}; | size_t embedding_size{0}; | ||||
| size_t vocab_size{0}; | size_t vocab_size{0}; | ||||
| Address device_address{nullptr, 0}; | Address device_address{nullptr, 0}; | ||||
| std::shared_ptr<int[]> host_address{nullptr}; | |||||
| std::shared_ptr<float[]> host_address{nullptr}; | |||||
| ParamInitInfo param_init_info_; | ParamInitInfo param_init_info_; | ||||
| }; | }; | ||||
| @@ -166,6 +166,8 @@ class PsCacheManager { | |||||
| bool CheckFinishInsertInitInfo() const; | bool CheckFinishInsertInitInfo() const; | ||||
| void AddEmbeddingTable() const; | void AddEmbeddingTable() const; | ||||
| void DumpStatisticsInfo(size_t each_print_step = 1000); | void DumpStatisticsInfo(size_t each_print_step = 1000); | ||||
| bool SyncHostEmbeddingTable(); | |||||
| bool SyncDeviceEmbeddingTable(); | |||||
| bool initialized_ps_cache_{false}; | bool initialized_ps_cache_{false}; | ||||
| std::string channel_name_; | std::string channel_name_; | ||||
| @@ -205,6 +205,7 @@ constexpr auto kPushOpName = "Push"; | |||||
| constexpr auto kPullOpName = "Pull"; | constexpr auto kPullOpName = "Pull"; | ||||
| constexpr auto kEmbeddingLookupOpName = "EmbeddingLookup"; | constexpr auto kEmbeddingLookupOpName = "EmbeddingLookup"; | ||||
| constexpr auto kEmbeddingLookupProxyOpName = "EmbeddingLookupProxy"; | constexpr auto kEmbeddingLookupProxyOpName = "EmbeddingLookupProxy"; | ||||
| constexpr auto kGatherV2OpName = "GatherV2"; | |||||
| constexpr auto kPaddingOpName = "Padding"; | constexpr auto kPaddingOpName = "Padding"; | ||||
| constexpr auto kAvgPoolOpName = "AvgPool"; | constexpr auto kAvgPoolOpName = "AvgPool"; | ||||
| constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu"; | constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu"; | ||||
| @@ -292,7 +292,8 @@ class EmbeddingLookup(Cell): | |||||
| "in 'full_batch' and 'table_row_slice' parallel strategy.") | "in 'full_batch' and 'table_row_slice' parallel strategy.") | ||||
| self.vocab_cache_size = self.vocab_cache_size * device_num | self.vocab_cache_size = self.vocab_cache_size * device_num | ||||
| self.cache_enable = True | self.cache_enable = True | ||||
| self.vocab_size = self.vocab_cache_size | |||||
| if _is_role_worker(): | |||||
| self.vocab_size = self.vocab_cache_size | |||||
| def _set_voacb_cache_enable(self, vocab_cache_size, embedding_size, vocab_size): | def _set_voacb_cache_enable(self, vocab_cache_size, embedding_size, vocab_size): | ||||
| """PS embeddingLookup cache enable set.""" | """PS embeddingLookup cache enable set.""" | ||||
| @@ -24,6 +24,7 @@ from mindspore.context import ParallelMode | |||||
| from mindspore.communication.management import get_rank, get_group_size, init | from mindspore.communication.management import get_rank, get_group_size, init | ||||
| from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple | from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple | ||||
| from mindspore.common import set_seed | from mindspore.common import set_seed | ||||
| from mindspore.parallel._ps_context import _is_role_worker | |||||
| from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel | from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel | ||||
| from src.callbacks import LossCallBack, EvalCallBack | from src.callbacks import LossCallBack, EvalCallBack | ||||
| @@ -117,11 +118,14 @@ def train_and_eval(config): | |||||
| eval_callback = EvalCallBack(model, ds_eval, auc_metric, config) | eval_callback = EvalCallBack(model, ds_eval, auc_metric, config) | ||||
| callback = LossCallBack(config=config) | callback = LossCallBack(config=config) | ||||
| if cache_enable: | |||||
| ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs, | |||||
| keep_checkpoint_max=5, integrated_save=False) | |||||
| if _is_role_worker(): | |||||
| if cache_enable: | |||||
| ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs, | |||||
| keep_checkpoint_max=1, integrated_save=False) | |||||
| else: | |||||
| ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5) | |||||
| else: | else: | ||||
| ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5) | |||||
| ckptconfig = CheckpointConfig(save_checkpoint_steps=1, keep_checkpoint_max=1) | |||||
| ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', | ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', | ||||
| directory=config.ckpt_path + '/ckpt_' + str(get_rank()) + '/', | directory=config.ckpt_path + '/ckpt_' + str(get_rank()) + '/', | ||||
| config=ckptconfig) | config=ckptconfig) | ||||
| @@ -20,6 +20,7 @@ import sys | |||||
| from mindspore import Model, context | from mindspore import Model, context | ||||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor | from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor | ||||
| from mindspore.common import set_seed | from mindspore.common import set_seed | ||||
| from mindspore.parallel._ps_context import _is_role_worker | |||||
| from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel | from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel | ||||
| from src.callbacks import LossCallBack, EvalCallBack | from src.callbacks import LossCallBack, EvalCallBack | ||||
| @@ -99,7 +100,14 @@ def train_and_eval(config): | |||||
| eval_callback = EvalCallBack(model, ds_eval, auc_metric, config) | eval_callback = EvalCallBack(model, ds_eval, auc_metric, config) | ||||
| callback = LossCallBack(config=config) | callback = LossCallBack(config=config) | ||||
| ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5) | |||||
| if _is_role_worker(): | |||||
| if cache_enable: | |||||
| ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size() * epochs, | |||||
| keep_checkpoint_max=1) | |||||
| else: | |||||
| ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5) | |||||
| else: | |||||
| ckptconfig = CheckpointConfig(save_checkpoint_steps=1, keep_checkpoint_max=1) | |||||
| ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=config.ckpt_path, config=ckptconfig) | ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=config.ckpt_path, config=ckptconfig) | ||||
| callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb] | callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb] | ||||