|
|
|
@@ -138,9 +138,36 @@ CNodePtrList FindUniqueCacheEnable(const CNodePtrList &cnodes) { |
|
|
|
return unique_cache_enable; |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
void MemCopyFromHostToCache(void *hashmap_addr, void *host_addr, void *cache_addr, size_t host_max, size_t cache_max, |
|
|
|
size_t hashmap_size, size_t col_size) { |
|
|
|
auto host_data = static_cast<char *>(host_addr); |
|
|
|
auto cache_data = static_cast<char *>(cache_addr); |
|
|
|
auto hashmap_data = static_cast<HashmapEntry<T> *>(hashmap_addr); |
|
|
|
// default param type float |
|
|
|
size_t param_type_size = 4; |
|
|
|
size_t single_col_bytes = param_type_size * col_size; |
|
|
|
for (size_t i = 0; i < hashmap_size; ++i) { |
|
|
|
if (!hashmap_data[i].IsEmpty()) { |
|
|
|
size_t host_offset = single_col_bytes * hashmap_data[i].key_; |
|
|
|
size_t cache_offset = single_col_bytes * hashmap_data[i].value_; |
|
|
|
if (host_offset + single_col_bytes <= host_max) { |
|
|
|
auto ret = |
|
|
|
memcpy_s(cache_data + cache_offset, cache_max - cache_offset, host_data + host_offset, single_col_bytes); |
|
|
|
if (ret != 0) { |
|
|
|
MS_LOG(EXCEPTION) << "Memcpy failed."; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "Memcpy from cache to host success!"; |
|
|
|
} |
|
|
|
|
|
|
|
void BindAndInitCacheTensor(const ParamMap ¶m_pair_list, const ParameterPtr &hashmap) { |
|
|
|
auto hashmap_tensor_value = hashmap->default_param(); |
|
|
|
auto hashmap_tensor = hashmap_tensor_value->cast<std::shared_ptr<tensor::Tensor>>(); |
|
|
|
auto hashmap_size = hashmap_tensor->shape_c()[0]; |
|
|
|
auto hashmap_data_type = hashmap_tensor->data_type(); |
|
|
|
for (auto &ele : param_pair_list) { |
|
|
|
auto host_tensor_value = ele.second->default_param(); |
|
|
|
auto host_tensor = host_tensor_value->cast<std::shared_ptr<tensor::Tensor>>(); |
|
|
|
@@ -151,11 +178,24 @@ void BindAndInitCacheTensor(const ParamMap ¶m_pair_list, const ParameterPtr |
|
|
|
host_tensor->set_cache_enable(true); |
|
|
|
host_tensor->set_hashmap_tensor_ptr(hashmap_tensor); |
|
|
|
host_tensor->set_cache_tensor_ptr(cache_tensor); |
|
|
|
|
|
|
|
// init cache tensor data |
|
|
|
auto cache_byte_size = cache_tensor->Size(); |
|
|
|
int ret = memcpy_s(cache_tensor->data_c(), cache_byte_size, host_tensor->data_c(), cache_byte_size); |
|
|
|
if (ret != 0) { |
|
|
|
MS_LOG(EXCEPTION) << "Memcpy failed."; |
|
|
|
auto host_shape = host_tensor->shape_c(); |
|
|
|
auto cache_shape = cache_tensor->shape_c(); |
|
|
|
if (host_shape.size() != 2 && host_shape.size() != 2 && host_shape[1] != cache_shape[1]) { |
|
|
|
MS_LOG(EXCEPTION) << "Got host shape and cache shape invalid." |
|
|
|
<< "host shape:" << host_shape << ", cache shape:" << cache_shape; |
|
|
|
} |
|
|
|
auto host_data_max_size = host_tensor->Size(); |
|
|
|
auto cache_data_max_size = cache_tensor->Size(); |
|
|
|
if (hashmap_data_type == TypeId::kNumberTypeInt32) { |
|
|
|
MemCopyFromHostToCache<int32_t>(hashmap_tensor->data_c(), host_tensor->data_c(), cache_tensor->data_c(), |
|
|
|
host_data_max_size, cache_data_max_size, hashmap_size, host_shape[1]); |
|
|
|
} else if (hashmap_data_type == TypeId::kNumberTypeInt64) { |
|
|
|
MemCopyFromHostToCache<int32_t>(hashmap_tensor->data_c(), host_tensor->data_c(), cache_tensor->data_c(), |
|
|
|
host_data_max_size, cache_data_max_size, hashmap_size, host_shape[1]); |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "Hashmap dtype only suppotr int32, in64."; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -320,7 +360,7 @@ void CreateTupleGetItems(const FuncGraphPtr &func_graph, const AnfNodePtr &input |
|
|
|
MS_EXCEPTION_IF_NULL(outputs); |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr CreateEmbeddingLookup(const FuncGraphPtr &graph, ParameterPtr params, AnfNodePtr indices) { |
|
|
|
AnfNodePtr CreateEmbeddingLookup(const FuncGraphPtr &graph, AnfNodePtr params, AnfNodePtr indices) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
PrimitivePtr emb_lookup_primitive = prim::kPrimEmbeddingLookup; |
|
|
|
emb_lookup_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU")); |
|
|
|
@@ -376,13 +416,16 @@ NodePairList CreateEmbSwapUpdate(const FuncGraphPtr &graph, ParamMap param_pair_ |
|
|
|
return node_pair_list; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr CreateControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &prior_node, |
|
|
|
const AnfNodePtr &behind_node) { |
|
|
|
void CreateControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &prior_node, const AnfNodePtr &behind_node) { |
|
|
|
// Create control depend |
|
|
|
MS_EXCEPTION_IF_NULL(main_graph); |
|
|
|
AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimControlDepend), prior_node, behind_node}; |
|
|
|
auto control_depend_cnode = main_graph->NewCNode(cd_inputs); |
|
|
|
return control_depend_cnode; |
|
|
|
auto manager = main_graph->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimDepend), behind_node, prior_node}; |
|
|
|
auto depend_cnode = main_graph->NewCNode(cd_inputs); |
|
|
|
if (!manager->Replace(behind_node, depend_cnode)) { |
|
|
|
MS_LOG(EXCEPTION) << behind_node->DebugString() << ", replace node failed."; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr CreateDepend(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &invalid_nodes, |
|
|
|
@@ -402,9 +445,12 @@ CNodePtrList FindSparseGatherV2WithCache(const CNodePtrList &cnodes, const Param |
|
|
|
CNodePtrList sparse_gather_v2_with_cache; |
|
|
|
for (size_t i = 0; i < cnodes_size; ++i) { |
|
|
|
if (IsPrimitiveCNode(cnodes[i], prim::kPrimSparseGatherV2)) { |
|
|
|
auto param_node = cnodes[i]->input(1)->cast<ParameterPtr>(); |
|
|
|
if (param_set.find(param_node) != param_set.end()) { |
|
|
|
sparse_gather_v2_with_cache.push_back(cnodes[i]); |
|
|
|
auto load_node = cnodes[i]->input(1); |
|
|
|
if (IsPrimitiveCNode(load_node, prim::kPrimLoad)) { |
|
|
|
auto param_node = load_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>(); |
|
|
|
if (param_set.find(param_node) != param_set.end()) { |
|
|
|
sparse_gather_v2_with_cache.push_back(cnodes[i]); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -534,9 +580,9 @@ AnfNodePtr FindCNodeOutput(const FuncGraphPtr &graph, const AnfNodePtr &node, in |
|
|
|
MS_LOG(EXCEPTION) << "Can't not find " << node->DebugString() << ", outputs:" << index; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtrList ReplaceNoRefToParams(const FuncGraphPtr &graph, const AnfMap &no_ref_pipe_param_map, |
|
|
|
const AnfNodePtr &cache_idx_param, const AnfNodePtr &cache_idx, |
|
|
|
const AnfNodePtr &sparse_gatherv2_indices) { |
|
|
|
void ReplaceNoRefToParams(const FuncGraphPtr &graph, const AnfMap &no_ref_pipe_param_map, |
|
|
|
const AnfNodePtr &cache_idx_param, const AnfNodePtr &cache_idx, |
|
|
|
const AnfNodePtr &sparse_gatherv2_indices) { |
|
|
|
auto manager = graph->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
auto node_users = manager->node_users(); |
|
|
|
@@ -546,8 +592,7 @@ AnfNodePtrList ReplaceNoRefToParams(const FuncGraphPtr &graph, const AnfMap &no_ |
|
|
|
auto user_set = node_users[ele.first]; |
|
|
|
auto assign_status = CreateAssign(graph, ele.second, ele.first); |
|
|
|
for (auto user_node : user_set) { |
|
|
|
auto control_depend = CreateControlDepend(graph, user_node.first, assign_status); |
|
|
|
control_depend_list.emplace_back(control_depend); |
|
|
|
CreateControlDepend(graph, user_node.first, assign_status); |
|
|
|
} |
|
|
|
if (!manager->Replace(ele.first, ele.second)) { |
|
|
|
MS_LOG(EXCEPTION) << "pipe param: " << ele.first->DebugString() << ", replace node failed."; |
|
|
|
@@ -558,13 +603,11 @@ AnfNodePtrList ReplaceNoRefToParams(const FuncGraphPtr &graph, const AnfMap &no_ |
|
|
|
auto dynamic_assgin_status = CreateAssign(graph, cache_idx_param, cache_idx, true); |
|
|
|
auto indices_user_set = node_users[sparse_gatherv2_indices]; |
|
|
|
for (auto &user_node : indices_user_set) { |
|
|
|
auto control_depend = CreateControlDepend(graph, user_node.first, dynamic_assgin_status); |
|
|
|
control_depend_list.emplace_back(control_depend); |
|
|
|
CreateControlDepend(graph, user_node.first, dynamic_assgin_status); |
|
|
|
} |
|
|
|
if (!manager->Replace(sparse_gatherv2_indices, cache_idx_param)) { |
|
|
|
MS_LOG(EXCEPTION) << "cache idx param: " << cache_idx_param->DebugString() << ", replace node failed."; |
|
|
|
} |
|
|
|
return control_depend_list; |
|
|
|
} |
|
|
|
|
|
|
|
void CacheEmbeddingForTrain(const FuncGraphPtr &graph, bool is_pipe, const CNodePtrList &cnodes, |
|
|
|
@@ -594,10 +637,9 @@ void CacheEmbeddingForTrain(const FuncGraphPtr &graph, bool is_pipe, const CNode |
|
|
|
MS_LOG(EXCEPTION) << "MapCacheIdx output[0] replace node failed"; |
|
|
|
} |
|
|
|
for (auto &ele : node_pair_list) { |
|
|
|
std::transform(sparse_gatherv2_with_cache.begin(), sparse_gatherv2_with_cache.end(), |
|
|
|
std::back_inserter(invalid_nodes), [&graph, &ele](const AnfNodePtr &sparse_gatherv2) { |
|
|
|
return CreateControlDepend(graph, ele.first, sparse_gatherv2); |
|
|
|
}); |
|
|
|
for (auto &gather_op : sparse_gatherv2_with_cache) { |
|
|
|
CreateControlDepend(graph, ele.first, gather_op); |
|
|
|
} |
|
|
|
invalid_nodes.emplace_back(ele.second); |
|
|
|
} |
|
|
|
} else { |
|
|
|
@@ -608,9 +650,7 @@ void CacheEmbeddingForTrain(const FuncGraphPtr &graph, bool is_pipe, const CNode |
|
|
|
RemoveOriginParamFromSet(unique_node, &no_ref_params); |
|
|
|
auto no_ref_param_map = CreateOtherPipeParams(graph, no_ref_params); |
|
|
|
no_ref_param_map[unique_index_reverse] = unique_index_param; |
|
|
|
auto control_depend_list = ReplaceNoRefToParams(graph, no_ref_param_map, cache_idx_param, cache_idx, |
|
|
|
sparse_gatherv2_with_cache[0]->input(2)); |
|
|
|
std::copy(control_depend_list.begin(), control_depend_list.end(), std::back_inserter(invalid_nodes)); |
|
|
|
ReplaceNoRefToParams(graph, no_ref_param_map, cache_idx_param, cache_idx, sparse_gatherv2_with_cache[0]->input(2)); |
|
|
|
std::transform(node_pair_list.begin(), node_pair_list.end(), std::back_inserter(invalid_nodes), |
|
|
|
[](const std::pair<AnfNodePtr, AnfNodePtr> &pair) { return pair.second; }); |
|
|
|
} |
|
|
|
@@ -646,8 +686,7 @@ void CacheEmbeddingForEval(const FuncGraphPtr &graph, const CNodePtrList &cnodes |
|
|
|
for (auto &ele : sparse_gatherv2_with_cache) { |
|
|
|
auto anf_ele = ele->cast<AnfNodePtr>(); |
|
|
|
auto gatherv2 = FindGatherV2FromSparseGatherV2(graph, anf_ele); |
|
|
|
auto param = ele->input(1)->cast<ParameterPtr>(); |
|
|
|
auto embedding_lookup = CreateEmbeddingLookup(graph, param, indices); |
|
|
|
auto embedding_lookup = CreateEmbeddingLookup(graph, ele->input(1), indices); |
|
|
|
if (!manager->Replace(gatherv2, embedding_lookup)) { |
|
|
|
MS_LOG(EXCEPTION) << "Depend replace node failed"; |
|
|
|
} |
|
|
|
|