Browse Source

!12496 fix bugs for embedding cache

From: @fangzehua
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
8d323792b2
4 changed files with 83 additions and 33 deletions
  1. +69
    -30
      mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc
  2. +1
    -1
      mindspore/common/tensor.py
  3. +12
    -1
      mindspore/train/callback/_checkpoint.py
  4. +1
    -1
      mindspore/train/model.py

+ 69
- 30
mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc View File

@@ -138,9 +138,36 @@ CNodePtrList FindUniqueCacheEnable(const CNodePtrList &cnodes) {
return unique_cache_enable;
}

template <typename T>
void MemCopyFromHostToCache(void *hashmap_addr, void *host_addr, void *cache_addr, size_t host_max, size_t cache_max,
size_t hashmap_size, size_t col_size) {
auto host_data = static_cast<char *>(host_addr);
auto cache_data = static_cast<char *>(cache_addr);
auto hashmap_data = static_cast<HashmapEntry<T> *>(hashmap_addr);
// default param type float
size_t param_type_size = 4;
size_t single_col_bytes = param_type_size * col_size;
for (size_t i = 0; i < hashmap_size; ++i) {
if (!hashmap_data[i].IsEmpty()) {
size_t host_offset = single_col_bytes * hashmap_data[i].key_;
size_t cache_offset = single_col_bytes * hashmap_data[i].value_;
if (host_offset + single_col_bytes <= host_max) {
auto ret =
memcpy_s(cache_data + cache_offset, cache_max - cache_offset, host_data + host_offset, single_col_bytes);
if (ret != 0) {
MS_LOG(EXCEPTION) << "Memcpy failed.";
}
}
}
}
MS_LOG(INFO) << "Memcpy from cache to host success!";
}

void BindAndInitCacheTensor(const ParamMap &param_pair_list, const ParameterPtr &hashmap) {
auto hashmap_tensor_value = hashmap->default_param();
auto hashmap_tensor = hashmap_tensor_value->cast<std::shared_ptr<tensor::Tensor>>();
auto hashmap_size = hashmap_tensor->shape_c()[0];
auto hashmap_data_type = hashmap_tensor->data_type();
for (auto &ele : param_pair_list) {
auto host_tensor_value = ele.second->default_param();
auto host_tensor = host_tensor_value->cast<std::shared_ptr<tensor::Tensor>>();
@@ -151,11 +178,24 @@ void BindAndInitCacheTensor(const ParamMap &param_pair_list, const ParameterPtr
host_tensor->set_cache_enable(true);
host_tensor->set_hashmap_tensor_ptr(hashmap_tensor);
host_tensor->set_cache_tensor_ptr(cache_tensor);

// init cache tensor data
auto cache_byte_size = cache_tensor->Size();
int ret = memcpy_s(cache_tensor->data_c(), cache_byte_size, host_tensor->data_c(), cache_byte_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "Memcpy failed.";
auto host_shape = host_tensor->shape_c();
auto cache_shape = cache_tensor->shape_c();
if (host_shape.size() != 2 && host_shape.size() != 2 && host_shape[1] != cache_shape[1]) {
MS_LOG(EXCEPTION) << "Got host shape and cache shape invalid."
<< "host shape:" << host_shape << ", cache shape:" << cache_shape;
}
auto host_data_max_size = host_tensor->Size();
auto cache_data_max_size = cache_tensor->Size();
if (hashmap_data_type == TypeId::kNumberTypeInt32) {
MemCopyFromHostToCache<int32_t>(hashmap_tensor->data_c(), host_tensor->data_c(), cache_tensor->data_c(),
host_data_max_size, cache_data_max_size, hashmap_size, host_shape[1]);
} else if (hashmap_data_type == TypeId::kNumberTypeInt64) {
MemCopyFromHostToCache<int32_t>(hashmap_tensor->data_c(), host_tensor->data_c(), cache_tensor->data_c(),
host_data_max_size, cache_data_max_size, hashmap_size, host_shape[1]);
} else {
MS_LOG(ERROR) << "Hashmap dtype only suppotr int32, in64.";
}
}
}
@@ -320,7 +360,7 @@ void CreateTupleGetItems(const FuncGraphPtr &func_graph, const AnfNodePtr &input
MS_EXCEPTION_IF_NULL(outputs);
}

AnfNodePtr CreateEmbeddingLookup(const FuncGraphPtr &graph, ParameterPtr params, AnfNodePtr indices) {
AnfNodePtr CreateEmbeddingLookup(const FuncGraphPtr &graph, AnfNodePtr params, AnfNodePtr indices) {
MS_EXCEPTION_IF_NULL(graph);
PrimitivePtr emb_lookup_primitive = prim::kPrimEmbeddingLookup;
emb_lookup_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
@@ -376,13 +416,16 @@ NodePairList CreateEmbSwapUpdate(const FuncGraphPtr &graph, ParamMap param_pair_
return node_pair_list;
}

AnfNodePtr CreateControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &prior_node,
const AnfNodePtr &behind_node) {
void CreateControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &prior_node, const AnfNodePtr &behind_node) {
// Create control depend
MS_EXCEPTION_IF_NULL(main_graph);
AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimControlDepend), prior_node, behind_node};
auto control_depend_cnode = main_graph->NewCNode(cd_inputs);
return control_depend_cnode;
auto manager = main_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimDepend), behind_node, prior_node};
auto depend_cnode = main_graph->NewCNode(cd_inputs);
if (!manager->Replace(behind_node, depend_cnode)) {
MS_LOG(EXCEPTION) << behind_node->DebugString() << ", replace node failed.";
}
}

AnfNodePtr CreateDepend(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &invalid_nodes,
@@ -402,9 +445,12 @@ CNodePtrList FindSparseGatherV2WithCache(const CNodePtrList &cnodes, const Param
CNodePtrList sparse_gather_v2_with_cache;
for (size_t i = 0; i < cnodes_size; ++i) {
if (IsPrimitiveCNode(cnodes[i], prim::kPrimSparseGatherV2)) {
auto param_node = cnodes[i]->input(1)->cast<ParameterPtr>();
if (param_set.find(param_node) != param_set.end()) {
sparse_gather_v2_with_cache.push_back(cnodes[i]);
auto load_node = cnodes[i]->input(1);
if (IsPrimitiveCNode(load_node, prim::kPrimLoad)) {
auto param_node = load_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>();
if (param_set.find(param_node) != param_set.end()) {
sparse_gather_v2_with_cache.push_back(cnodes[i]);
}
}
}
}
@@ -534,9 +580,9 @@ AnfNodePtr FindCNodeOutput(const FuncGraphPtr &graph, const AnfNodePtr &node, in
MS_LOG(EXCEPTION) << "Can't not find " << node->DebugString() << ", outputs:" << index;
}

AnfNodePtrList ReplaceNoRefToParams(const FuncGraphPtr &graph, const AnfMap &no_ref_pipe_param_map,
const AnfNodePtr &cache_idx_param, const AnfNodePtr &cache_idx,
const AnfNodePtr &sparse_gatherv2_indices) {
void ReplaceNoRefToParams(const FuncGraphPtr &graph, const AnfMap &no_ref_pipe_param_map,
const AnfNodePtr &cache_idx_param, const AnfNodePtr &cache_idx,
const AnfNodePtr &sparse_gatherv2_indices) {
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto node_users = manager->node_users();
@@ -546,8 +592,7 @@ AnfNodePtrList ReplaceNoRefToParams(const FuncGraphPtr &graph, const AnfMap &no_
auto user_set = node_users[ele.first];
auto assign_status = CreateAssign(graph, ele.second, ele.first);
for (auto user_node : user_set) {
auto control_depend = CreateControlDepend(graph, user_node.first, assign_status);
control_depend_list.emplace_back(control_depend);
CreateControlDepend(graph, user_node.first, assign_status);
}
if (!manager->Replace(ele.first, ele.second)) {
MS_LOG(EXCEPTION) << "pipe param: " << ele.first->DebugString() << ", replace node failed.";
@@ -558,13 +603,11 @@ AnfNodePtrList ReplaceNoRefToParams(const FuncGraphPtr &graph, const AnfMap &no_
auto dynamic_assgin_status = CreateAssign(graph, cache_idx_param, cache_idx, true);
auto indices_user_set = node_users[sparse_gatherv2_indices];
for (auto &user_node : indices_user_set) {
auto control_depend = CreateControlDepend(graph, user_node.first, dynamic_assgin_status);
control_depend_list.emplace_back(control_depend);
CreateControlDepend(graph, user_node.first, dynamic_assgin_status);
}
if (!manager->Replace(sparse_gatherv2_indices, cache_idx_param)) {
MS_LOG(EXCEPTION) << "cache idx param: " << cache_idx_param->DebugString() << ", replace node failed.";
}
return control_depend_list;
}

void CacheEmbeddingForTrain(const FuncGraphPtr &graph, bool is_pipe, const CNodePtrList &cnodes,
@@ -594,10 +637,9 @@ void CacheEmbeddingForTrain(const FuncGraphPtr &graph, bool is_pipe, const CNode
MS_LOG(EXCEPTION) << "MapCacheIdx output[0] replace node failed";
}
for (auto &ele : node_pair_list) {
std::transform(sparse_gatherv2_with_cache.begin(), sparse_gatherv2_with_cache.end(),
std::back_inserter(invalid_nodes), [&graph, &ele](const AnfNodePtr &sparse_gatherv2) {
return CreateControlDepend(graph, ele.first, sparse_gatherv2);
});
for (auto &gather_op : sparse_gatherv2_with_cache) {
CreateControlDepend(graph, ele.first, gather_op);
}
invalid_nodes.emplace_back(ele.second);
}
} else {
@@ -608,9 +650,7 @@ void CacheEmbeddingForTrain(const FuncGraphPtr &graph, bool is_pipe, const CNode
RemoveOriginParamFromSet(unique_node, &no_ref_params);
auto no_ref_param_map = CreateOtherPipeParams(graph, no_ref_params);
no_ref_param_map[unique_index_reverse] = unique_index_param;
auto control_depend_list = ReplaceNoRefToParams(graph, no_ref_param_map, cache_idx_param, cache_idx,
sparse_gatherv2_with_cache[0]->input(2));
std::copy(control_depend_list.begin(), control_depend_list.end(), std::back_inserter(invalid_nodes));
ReplaceNoRefToParams(graph, no_ref_param_map, cache_idx_param, cache_idx, sparse_gatherv2_with_cache[0]->input(2));
std::transform(node_pair_list.begin(), node_pair_list.end(), std::back_inserter(invalid_nodes),
[](const std::pair<AnfNodePtr, AnfNodePtr> &pair) { return pair.second; });
}
@@ -646,8 +686,7 @@ void CacheEmbeddingForEval(const FuncGraphPtr &graph, const CNodePtrList &cnodes
for (auto &ele : sparse_gatherv2_with_cache) {
auto anf_ele = ele->cast<AnfNodePtr>();
auto gatherv2 = FindGatherV2FromSparseGatherV2(graph, anf_ele);
auto param = ele->input(1)->cast<ParameterPtr>();
auto embedding_lookup = CreateEmbeddingLookup(graph, param, indices);
auto embedding_lookup = CreateEmbeddingLookup(graph, ele->input(1), indices);
if (!manager->Replace(gatherv2, embedding_lookup)) {
MS_LOG(EXCEPTION) << "Depend replace node failed";
}


+ 1
- 1
mindspore/common/tensor.py View File

@@ -338,7 +338,7 @@ class Tensor(Tensor_):
self.init_check()
return Tensor_.asnumpy(self)

def _flush_from_cache(self):
def flush_from_cache(self):
"""Flush cache data to host if tensor is cache enable."""
self.init_check()
Tensor_._flush_from_cache(self)


+ 12
- 1
mindspore/train/callback/_checkpoint.py View File

@@ -27,7 +27,7 @@ from mindspore.train._utils import _make_directory
from mindspore.train.serialization import save_checkpoint, _save_graph
from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank
from ._callback import Callback, set_cur_net
from ...common.tensor import Tensor

_cur_dir = os.getcwd()
_save_dir = _cur_dir
@@ -295,6 +295,10 @@ class ModelCheckpoint(Callback):
"""
cb_params = run_context.original_args()
_to_save_last_ckpt = True

# if param is cache enable, flush data from cache to host before epoch end
self._flush_from_cache(cb_params)

self._save_ckpt(cb_params, _to_save_last_ckpt)

thread_list = threading.enumerate()
@@ -359,6 +363,13 @@ class ModelCheckpoint(Callback):

self._latest_ckpt_file_name = cur_file

def _flush_from_cache(self, cb_params):
"""Flush cache data to host if tensor is cache enable."""
params = cb_params.train_network.get_parameters()
for param in params:
if param.cache_enable:
Tensor(param).flush_from_cache()

@property
def latest_ckpt_file_name(self):
"""Return the latest checkpoint path and file name."""


+ 1
- 1
mindspore/train/model.py View File

@@ -799,6 +799,6 @@ class Model:
params = cb_params.train_network.get_parameters()
for param in params:
if param.cache_enable:
Tensor(param)._flush_from_cache()
Tensor(param).flush_from_cache()

__all__ = ["Model"]

Loading…
Cancel
Save