diff --git a/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc b/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc
index 7f3e280ece..bf4fb776ff 100644
--- a/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc
+++ b/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc
@@ -138,9 +138,36 @@ CNodePtrList FindUniqueCacheEnable(const CNodePtrList &cnodes) {
   return unique_cache_enable;
 }
 
+template <typename T>
+void MemCopyFromHostToCache(void *hashmap_addr, void *host_addr, void *cache_addr, size_t host_max, size_t cache_max,
+                            size_t hashmap_size, size_t col_size) {
+  auto host_data = static_cast<char *>(host_addr);
+  auto cache_data = static_cast<char *>(cache_addr);
+  auto hashmap_data = static_cast<HashmapEntry<T> *>(hashmap_addr);
+  // default param type float
+  size_t param_type_size = 4;
+  size_t single_col_bytes = param_type_size * col_size;
+  for (size_t i = 0; i < hashmap_size; ++i) {
+    if (!hashmap_data[i].IsEmpty()) {
+      size_t host_offset = single_col_bytes * hashmap_data[i].key_;
+      size_t cache_offset = single_col_bytes * hashmap_data[i].value_;
+      if (host_offset + single_col_bytes <= host_max) {
+        auto ret =
+          memcpy_s(cache_data + cache_offset, cache_max - cache_offset, host_data + host_offset, single_col_bytes);
+        if (ret != 0) {
+          MS_LOG(EXCEPTION) << "Memcpy failed.";
+        }
+      }
+    }
+  }
+  MS_LOG(INFO) << "Memcpy from host to cache success!";
+}
+
 void BindAndInitCacheTensor(const ParamMap &param_pair_list, const ParameterPtr &hashmap) {
   auto hashmap_tensor_value = hashmap->default_param();
   auto hashmap_tensor = hashmap_tensor_value->cast<std::shared_ptr<tensor::Tensor>>();
+  auto hashmap_size = hashmap_tensor->shape_c()[0];
+  auto hashmap_data_type = hashmap_tensor->data_type();
   for (auto &ele : param_pair_list) {
     auto host_tensor_value = ele.second->default_param();
     auto host_tensor = host_tensor_value->cast<std::shared_ptr<tensor::Tensor>>();
@@ -151,11 +178,24 @@ void BindAndInitCacheTensor(const ParamMap &param_pair_list, const ParameterPtr
     host_tensor->set_cache_enable(true);
     host_tensor->set_hashmap_tensor_ptr(hashmap_tensor);
     host_tensor->set_cache_tensor_ptr(cache_tensor);
+
     // init cache tensor data
-    auto cache_byte_size = cache_tensor->Size();
-    int ret = memcpy_s(cache_tensor->data_c(), cache_byte_size, host_tensor->data_c(), cache_byte_size);
-    if (ret != 0) {
-      MS_LOG(EXCEPTION) << "Memcpy failed.";
+    auto host_shape = host_tensor->shape_c();
+    auto cache_shape = cache_tensor->shape_c();
+    if (host_shape.size() != 2 || cache_shape.size() != 2 || host_shape[1] != cache_shape[1]) {
+      MS_LOG(EXCEPTION) << "Invalid host shape or cache shape,"
+ << "host shape:" << host_shape << ", cache shape:" << cache_shape; + } + auto host_data_max_size = host_tensor->Size(); + auto cache_data_max_size = cache_tensor->Size(); + if (hashmap_data_type == TypeId::kNumberTypeInt32) { + MemCopyFromHostToCache(hashmap_tensor->data_c(), host_tensor->data_c(), cache_tensor->data_c(), + host_data_max_size, cache_data_max_size, hashmap_size, host_shape[1]); + } else if (hashmap_data_type == TypeId::kNumberTypeInt64) { + MemCopyFromHostToCache(hashmap_tensor->data_c(), host_tensor->data_c(), cache_tensor->data_c(), + host_data_max_size, cache_data_max_size, hashmap_size, host_shape[1]); + } else { + MS_LOG(ERROR) << "Hashmap dtype only suppotr int32, in64."; } } } @@ -320,7 +360,7 @@ void CreateTupleGetItems(const FuncGraphPtr &func_graph, const AnfNodePtr &input MS_EXCEPTION_IF_NULL(outputs); } -AnfNodePtr CreateEmbeddingLookup(const FuncGraphPtr &graph, ParameterPtr params, AnfNodePtr indices) { +AnfNodePtr CreateEmbeddingLookup(const FuncGraphPtr &graph, AnfNodePtr params, AnfNodePtr indices) { MS_EXCEPTION_IF_NULL(graph); PrimitivePtr emb_lookup_primitive = prim::kPrimEmbeddingLookup; emb_lookup_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU")); @@ -376,13 +416,16 @@ NodePairList CreateEmbSwapUpdate(const FuncGraphPtr &graph, ParamMap param_pair_ return node_pair_list; } -AnfNodePtr CreateControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &prior_node, - const AnfNodePtr &behind_node) { +void CreateControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &prior_node, const AnfNodePtr &behind_node) { // Create control depend MS_EXCEPTION_IF_NULL(main_graph); - AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimControlDepend), prior_node, behind_node}; - auto control_depend_cnode = main_graph->NewCNode(cd_inputs); - return control_depend_cnode; + auto manager = main_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimDepend), behind_node, prior_node}; + auto depend_cnode = main_graph->NewCNode(cd_inputs); + if (!manager->Replace(behind_node, depend_cnode)) { + MS_LOG(EXCEPTION) << behind_node->DebugString() << ", replace node failed."; + } } AnfNodePtr CreateDepend(const FuncGraphPtr &graph, const std::vector &invalid_nodes, @@ -402,9 +445,12 @@ CNodePtrList FindSparseGatherV2WithCache(const CNodePtrList &cnodes, const Param CNodePtrList sparse_gather_v2_with_cache; for (size_t i = 0; i < cnodes_size; ++i) { if (IsPrimitiveCNode(cnodes[i], prim::kPrimSparseGatherV2)) { - auto param_node = cnodes[i]->input(1)->cast(); - if (param_set.find(param_node) != param_set.end()) { - sparse_gather_v2_with_cache.push_back(cnodes[i]); + auto load_node = cnodes[i]->input(1); + if (IsPrimitiveCNode(load_node, prim::kPrimLoad)) { + auto param_node = load_node->cast()->input(1)->cast(); + if (param_set.find(param_node) != param_set.end()) { + sparse_gather_v2_with_cache.push_back(cnodes[i]); + } } } } @@ -534,9 +580,9 @@ AnfNodePtr FindCNodeOutput(const FuncGraphPtr &graph, const AnfNodePtr &node, in MS_LOG(EXCEPTION) << "Can't not find " << node->DebugString() << ", outputs:" << index; } -AnfNodePtrList ReplaceNoRefToParams(const FuncGraphPtr &graph, const AnfMap &no_ref_pipe_param_map, - const AnfNodePtr &cache_idx_param, const AnfNodePtr &cache_idx, - const AnfNodePtr &sparse_gatherv2_indices) { +void ReplaceNoRefToParams(const FuncGraphPtr &graph, const AnfMap &no_ref_pipe_param_map, + const AnfNodePtr &cache_idx_param, const AnfNodePtr &cache_idx, + const AnfNodePtr 
@@ -402,9 +445,12 @@ CNodePtrList FindSparseGatherV2WithCache(const CNodePtrList &cnodes, const Param
   CNodePtrList sparse_gather_v2_with_cache;
   for (size_t i = 0; i < cnodes_size; ++i) {
     if (IsPrimitiveCNode(cnodes[i], prim::kPrimSparseGatherV2)) {
-      auto param_node = cnodes[i]->input(1)->cast<ParameterPtr>();
-      if (param_set.find(param_node) != param_set.end()) {
-        sparse_gather_v2_with_cache.push_back(cnodes[i]);
+      auto load_node = cnodes[i]->input(1);
+      if (IsPrimitiveCNode(load_node, prim::kPrimLoad)) {
+        auto param_node = load_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>();
+        if (param_set.find(param_node) != param_set.end()) {
+          sparse_gather_v2_with_cache.push_back(cnodes[i]);
+        }
       }
     }
   }
@@ -534,9 +580,9 @@ AnfNodePtr FindCNodeOutput(const FuncGraphPtr &graph, const AnfNodePtr &node, in
   MS_LOG(EXCEPTION) << "Can't not find " << node->DebugString() << ", outputs:" << index;
 }
 
-AnfNodePtrList ReplaceNoRefToParams(const FuncGraphPtr &graph, const AnfMap &no_ref_pipe_param_map,
-                                    const AnfNodePtr &cache_idx_param, const AnfNodePtr &cache_idx,
-                                    const AnfNodePtr &sparse_gatherv2_indices) {
+void ReplaceNoRefToParams(const FuncGraphPtr &graph, const AnfMap &no_ref_pipe_param_map,
+                          const AnfNodePtr &cache_idx_param, const AnfNodePtr &cache_idx,
+                          const AnfNodePtr &sparse_gatherv2_indices) {
   auto manager = graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
   auto node_users = manager->node_users();
@@ -546,8 +592,7 @@ AnfNodePtrList ReplaceNoRefToParams(const FuncGraphPtr &graph, const AnfMap &no_
     auto user_set = node_users[ele.first];
     auto assign_status = CreateAssign(graph, ele.second, ele.first);
     for (auto user_node : user_set) {
-      auto control_depend = CreateControlDepend(graph, user_node.first, assign_status);
-      control_depend_list.emplace_back(control_depend);
+      CreateControlDepend(graph, user_node.first, assign_status);
     }
     if (!manager->Replace(ele.first, ele.second)) {
       MS_LOG(EXCEPTION) << "pipe param: " << ele.first->DebugString() << ", replace node failed.";
@@ -558,13 +603,11 @@ AnfNodePtrList ReplaceNoRefToParams(const FuncGraphPtr &graph, const AnfMap &no_
   auto dynamic_assgin_status = CreateAssign(graph, cache_idx_param, cache_idx, true);
   auto indices_user_set = node_users[sparse_gatherv2_indices];
   for (auto &user_node : indices_user_set) {
-    auto control_depend = CreateControlDepend(graph, user_node.first, dynamic_assgin_status);
-    control_depend_list.emplace_back(control_depend);
+    CreateControlDepend(graph, user_node.first, dynamic_assgin_status);
   }
   if (!manager->Replace(sparse_gatherv2_indices, cache_idx_param)) {
     MS_LOG(EXCEPTION) << "cache idx param: " << cache_idx_param->DebugString() << ", replace node failed.";
   }
-  return control_depend_list;
 }
 
 void CacheEmbeddingForTrain(const FuncGraphPtr &graph, bool is_pipe, const CNodePtrList &cnodes,
@@ -594,10 +637,9 @@ void CacheEmbeddingForTrain(const FuncGraphPtr &graph, bool is_pipe, const CNode
       MS_LOG(EXCEPTION) << "MapCacheIdx output[0] replace node failed";
     }
     for (auto &ele : node_pair_list) {
-      std::transform(sparse_gatherv2_with_cache.begin(), sparse_gatherv2_with_cache.end(),
-                     std::back_inserter(invalid_nodes), [&graph, &ele](const AnfNodePtr &sparse_gatherv2) {
-                       return CreateControlDepend(graph, ele.first, sparse_gatherv2);
-                     });
+      for (auto &gather_op : sparse_gatherv2_with_cache) {
+        CreateControlDepend(graph, ele.first, gather_op);
+      }
       invalid_nodes.emplace_back(ele.second);
     }
   } else {
@@ -608,9 +650,7 @@ void CacheEmbeddingForTrain(const FuncGraphPtr &graph, bool is_pipe, const CNode
     RemoveOriginParamFromSet(unique_node, &no_ref_params);
     auto no_ref_param_map = CreateOtherPipeParams(graph, no_ref_params);
     no_ref_param_map[unique_index_reverse] = unique_index_param;
-    auto control_depend_list = ReplaceNoRefToParams(graph, no_ref_param_map, cache_idx_param, cache_idx,
-                                                    sparse_gatherv2_with_cache[0]->input(2));
-    std::copy(control_depend_list.begin(), control_depend_list.end(), std::back_inserter(invalid_nodes));
+    ReplaceNoRefToParams(graph, no_ref_param_map, cache_idx_param, cache_idx, sparse_gatherv2_with_cache[0]->input(2));
     std::transform(node_pair_list.begin(), node_pair_list.end(), std::back_inserter(invalid_nodes),
                    [](const std::pair<AnfNodePtr, AnfNodePtr> &pair) { return pair.second; });
   }
@@ -646,8 +686,7 @@ void CacheEmbeddingForEval(const FuncGraphPtr &graph, const CNodePtrList &cnodes
   for (auto &ele : sparse_gatherv2_with_cache) {
     auto anf_ele = ele->cast<AnfNodePtr>();
     auto gatherv2 = FindGatherV2FromSparseGatherV2(graph, anf_ele);
-    auto param = ele->input(1)->cast<ParameterPtr>();
-    auto embedding_lookup = CreateEmbeddingLookup(graph, param, indices);
+    auto embedding_lookup = CreateEmbeddingLookup(graph, ele->input(1), indices);
     if (!manager->Replace(gatherv2, embedding_lookup)) {
       MS_LOG(EXCEPTION) << "Depend replace node failed";
     }
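Taken together, the cache_embedding.cc changes initialize the device cache from the host embedding table by walking the hashmap: each non-empty slot maps a host row index (key_) to a cache row index (value_), and one row of col_size float values is copied per slot. A NumPy sketch of that copy with hypothetical names (the real HashmapEntry layout is defined elsewhere in the codebase, not in this patch):

    import numpy as np

    def copy_host_rows_to_cache(hashmap, host_table, cache_table):
        # Mirror host rows into the cache as directed by (key, value) slots.
        for slot in hashmap:
            if slot is None:  # empty hashmap slot, nothing cached here
                continue
            key, value = slot  # key: host row index, value: cache row index
            if key < host_table.shape[0]:  # mirrors the host_offset bound check
                cache_table[value] = host_table[key]

    host = np.arange(12, dtype=np.float32).reshape(6, 2)  # 6-row host table
    cache = np.zeros((3, 2), dtype=np.float32)            # 3-row device cache
    copy_host_rows_to_cache([(4, 0), None, (1, 2)], host, cache)
    print(cache)  # host rows 4 and 1 now occupy cache rows 0 and 2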
diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py
index 225b5c6d2d..24d0e654de 100644
--- a/mindspore/common/tensor.py
+++ b/mindspore/common/tensor.py
@@ -338,7 +338,7 @@ class Tensor(Tensor_):
         self.init_check()
         return Tensor_.asnumpy(self)
 
-    def _flush_from_cache(self):
+    def flush_from_cache(self):
         """Flush cache data to host if tensor is cache enable."""
         self.init_check()
         Tensor_._flush_from_cache(self)
diff --git a/mindspore/train/callback/_checkpoint.py b/mindspore/train/callback/_checkpoint.py
index 18cfec30d2..360260f348 100644
--- a/mindspore/train/callback/_checkpoint.py
+++ b/mindspore/train/callback/_checkpoint.py
@@ -27,7 +27,7 @@ from mindspore.train._utils import _make_directory
 from mindspore.train.serialization import save_checkpoint, _save_graph
 from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank
 from ._callback import Callback, set_cur_net
-
+from ...common.tensor import Tensor
 
 _cur_dir = os.getcwd()
 _save_dir = _cur_dir
@@ -295,6 +295,10 @@ class ModelCheckpoint(Callback):
         """
         cb_params = run_context.original_args()
         _to_save_last_ckpt = True
+
+        # If a parameter is cache-enabled, flush its data from cache to host before the epoch ends.
+        self._flush_from_cache(cb_params)
+
         self._save_ckpt(cb_params, _to_save_last_ckpt)
 
         thread_list = threading.enumerate()
@@ -359,6 +363,13 @@ class ModelCheckpoint(Callback):
 
         self._latest_ckpt_file_name = cur_file
 
+    def _flush_from_cache(self, cb_params):
+        """Flush cached data to host for parameters with cache enabled."""
+        params = cb_params.train_network.get_parameters()
+        for param in params:
+            if param.cache_enable:
+                Tensor(param).flush_from_cache()
+
     @property
     def latest_ckpt_file_name(self):
         """Return the latest checkpoint path and file name."""
diff --git a/mindspore/train/model.py b/mindspore/train/model.py
index 56096d8dac..1c07768225 100755
--- a/mindspore/train/model.py
+++ b/mindspore/train/model.py
@@ -799,6 +799,6 @@ class Model:
         params = cb_params.train_network.get_parameters()
         for param in params:
             if param.cache_enable:
-                Tensor(param)._flush_from_cache()
+                Tensor(param).flush_from_cache()
 
 __all__ = ["Model"]
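With the rename from the private Tensor._flush_from_cache to the public flush_from_cache, checkpoint code can pull device-cached embedding rows back to the host before serialization, which is what the new ModelCheckpoint._flush_from_cache helper does at epoch end. A minimal usage sketch (assumes a network whose embedding parameters were created with cache_enable set on a supported backend; flush_cached_params is a hypothetical helper name):

    from mindspore import Tensor

    def flush_cached_params(network):
        # Copy device-cached rows back into the host-side parameter data.
        for param in network.get_parameters():
            if param.cache_enable:
                Tensor(param).flush_from_cache()

    # Call this right before save_checkpoint(network, "net.ckpt") so the
    # saved values reflect the latest cached updates.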