@@ -18,6 +18,7 @@
#include <random>
#include <vector>
#include <list>
#include <queue>
#include <utility>
#include <memory>
#include <unordered_map>
@@ -33,6 +34,8 @@ namespace parallel {
using ParamMap = std::unordered_map<ParameterPtr, ParameterPtr>;
using ParamSet = std::unordered_set<ParameterPtr>;
using NodePairList = std::vector<std::pair<AnfNodePtr, AnfNodePtr>>;
using AnfMap = std::unordered_map<AnfNodePtr, AnfNodePtr>;
using AnfSet = std::unordered_set<AnfNodePtr>;

ParamMap AddCacheParameters(const FuncGraphPtr &graph, const ParamSet &parameter_cache_enable_set) {
  ParamMap cache_host_params_map;
@@ -408,6 +411,7 @@ CNodePtrList FindSparseGatherV2WithCache(const CNodePtrList &cnodes, const Param
  if (sparse_gather_v2_with_cache.empty()) {
    MS_LOG(EXCEPTION) << "Can not find SparseGatherV2 with cache param.";
  }

  auto indices = sparse_gather_v2_with_cache[0]->input(2);
  for (auto &ele : sparse_gather_v2_with_cache) {
    if (ele->input(2) != indices) {
@@ -433,13 +437,227 @@ AnfNodePtr FindGatherV2FromSparseGatherV2(const FuncGraphPtr &graph, const AnfNo
  return gatherv2_nodes[0];
}
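
// Collect the graph parameters that have no default value (no-ref params),
// i.e. data inputs such as features and labels rather than trainable weights.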
AnfSet FindNoRefParams(const FuncGraphPtr &graph) {
  AnfSet no_ref_params;
  auto params = graph->parameters();
  for (auto &anf_param : params) {
    auto param = anf_param->cast<ParameterPtr>();
    if (!param->has_default()) {
      MS_LOG(INFO) << param->DebugString() << " has no default";
      no_ref_params.insert(anf_param);
    }
  }
  return no_ref_params;
}

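// Walk up from the Unique node through its CNode inputs and erase the first
// parameter it reaches (the original indices input) from no_ref_params.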
void RemoveOriginParamFromSet(const CNodePtr &unique_node, AnfSet *no_ref_params) {
  std::queue<CNodePtr> que;
  que.push(unique_node);
  while (!que.empty()) {
    auto node = que.front();
    que.pop();
    auto node_inputs = node->inputs();
    for (auto &input : node_inputs) {
      if (input->isa<CNode>()) {
        que.push(input->cast<CNodePtr>());
      } else if (input->isa<Parameter>()) {
        auto num = no_ref_params->erase(input);
        if (num > 0) {
          MS_LOG(INFO) << "Erase unique_node input from set success.";
          return;
        }
      }
    }
  }
  MS_LOG(EXCEPTION) << "Can not find any parameter used by Unique.";
}

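// Create a weight parameter named "<name>_pipe" with the same dtype and shape
// as ori_input, used to stage a copy of the previous step's value.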
AnfNodePtr CreateOutputNodeParam(const FuncGraphPtr &graph, const AnfNodePtr &ori_input, const std::string &name) {
  auto ori_input_type = ori_input->Type();
  auto ori_input_element_type = ori_input_type->cast<mindspore::TensorTypePtr>()->element();
  auto ori_input_type_id = ori_input_element_type->type_id();
  auto ori_input_shp = ori_input->Shape();
  auto input_shp = ori_input_shp->cast<abstract::ShapePtr>();
  auto input_shape = input_shp->shape();
  auto new_tensor = std::make_shared<tensor::Tensor>(ori_input_type_id, input_shape);
  ParamInfoPtr new_param_info = std::make_shared<ParamInfo>();
  auto new_param_name = name + "_pipe";
  new_param_info->set_name(new_param_name);
  new_tensor->set_param_info(new_param_info);
  auto new_param = graph->AddWeightParameter(new_param_name);
  new_param->set_default_param(MakeValue(new_tensor));
  auto abs_tensor = new_tensor->ToAbstract();
  new_param->set_abstract(abs_tensor);
  return new_param->cast<AnfNodePtr>();
}

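// Create a "_pipe" staging parameter for every no-ref parameter.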
AnfMap CreateOtherPipeParams(const FuncGraphPtr &graph, const AnfSet &no_ref_params) {
  AnfMap no_ref_pipe_param_map;
  for (auto &param : no_ref_params) {
    auto ori_param = param->cast<ParameterPtr>();
    auto ori_name = ori_param->name();
    auto new_param = CreateOutputNodeParam(graph, param, ori_name);
    no_ref_pipe_param_map[param] = new_param;
  }
  return no_ref_pipe_param_map;
}

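// Build an Assign(res_param, src_param) CNode; for the dynamic case use
// DynamicAssign and pin it to CPU via the primitive_target attr.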
AnfNodePtr CreateAssign(const FuncGraphPtr &graph, const AnfNodePtr &res_param, const AnfNodePtr &src_param,
                        bool is_dynamic = false) {
  auto assign_prim = prim::kPrimAssign;
  if (is_dynamic) {
    assign_prim = prim::kPrimDynamicAssign;
    assign_prim->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
  }
  std::vector<AnfNodePtr> assign_nodes{NewValueNode(assign_prim), res_param, src_param};
  auto assign_status = graph->NewCNode(assign_nodes);
  return assign_status;
}

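// Find the TupleGetItem user of `node` that extracts output `index`.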
AnfNodePtr FindCNodeOutput(const FuncGraphPtr &graph, const AnfNodePtr &node, int64_t index) {
  auto manager = graph->manager();
  auto node_users = manager->node_users()[node];
  for (auto &node_user : node_users) {
    if (IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
      auto cnode = node_user.first->cast<CNodePtr>();
      auto node_index = cnode->input(2);
      if (node_index->isa<ValueNode>()) {
        auto value_node = node_index->cast<ValueNodePtr>();
        MS_EXCEPTION_IF_NULL(value_node);
        auto item_idx = GetValue<int64_t>(value_node->value());
        if (item_idx == index) {
          return node_user.first;
        }
      }
    }
  }
  MS_LOG(EXCEPTION) << "Can not find " << node->DebugString() << ", output index: " << index;
}

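// Redirect users of each no-ref parameter to its "_pipe" copy, then assign the
// fresh value back into the copy. The returned ControlDepend nodes are meant to
// make each refreshing Assign run only after this step's users have read the
// staged value; callers anchor them in the graph via a Depend on the output.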
AnfNodePtrList ReplaceNoRefToParams(const FuncGraphPtr &graph, const AnfMap &no_ref_pipe_param_map,
                                    const AnfNodePtr &cache_idx_param, const AnfNodePtr &cache_idx,
                                    const AnfNodePtr &sparse_gatherv2_indices) {
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto node_users = manager->node_users();
  AnfNodePtrList control_depend_list;
  // add other no ref pipe param and unique index dense
  for (auto &ele : no_ref_pipe_param_map) {
    auto user_set = node_users[ele.first];
    auto assign_status = CreateAssign(graph, ele.second, ele.first);
    for (auto user_node : user_set) {
      auto control_depend = CreateControlDepend(graph, user_node.first, assign_status);
      control_depend_list.emplace_back(control_depend);
    }
    if (!manager->Replace(ele.first, ele.second)) {
      MS_LOG(EXCEPTION) << "pipe param: " << ele.first->DebugString() << ", replace node failed.";
    }
  }

  // add cache idx param
  auto dynamic_assign_status = CreateAssign(graph, cache_idx_param, cache_idx, true);
  auto indices_user_set = node_users[sparse_gatherv2_indices];
  for (auto &user_node : indices_user_set) {
    auto control_depend = CreateControlDepend(graph, user_node.first, dynamic_assign_status);
    control_depend_list.emplace_back(control_depend);
  }
  if (!manager->Replace(sparse_gatherv2_indices, cache_idx_param)) {
    MS_LOG(EXCEPTION) << "cache idx param: " << cache_idx_param->DebugString() << ", replace node failed.";
  }
  return control_depend_list;
}

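// Training: swap the cache-enabled host params for device cache params, hang a
// MapCacheIdx off the Unique output, create the EmbeddingLookup/CacheSwapTable/
// UpdateCache swap nodes, and (in pipe mode) stage no-ref inputs in "_pipe"
// params so each step consumes the indices produced by the previous step.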
void CacheEmbeddingForTrain(const FuncGraphPtr &graph, bool is_pipe, const CNodePtrList &cnodes,
                            const CNodePtr &unique_node, const ParamSet &param_cache_enable_set) {
  MS_EXCEPTION_IF_NULL(graph);
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  size_t cnodes_size = cnodes.size();
  auto cache_host_params_map = AddCacheParameters(graph, param_cache_enable_set);
  auto param_set = MapKeysToSet(cache_host_params_map);
  ReplaceCacheParams(graph, cache_host_params_map);
  graph->set_flag(GRAPH_FLAG_CACHE_ENABLE, true);
  MS_LOG(INFO) << "Graph is set cache enable.";

  CNodePtrList sparse_gatherv2_with_cache = FindSparseGatherV2WithCache(cnodes, param_set);
  auto unique_node_output_0 = CreateTupleGetItem(graph, unique_node, 0);
  auto map_cache_idx = CreateMapCacheIdx(graph, unique_node_output_0, cache_host_params_map);

  AnfNodePtrList map_cache_idx_node_outputs;
  CreateTupleGetItems(graph, map_cache_idx, &map_cache_idx_node_outputs);

  auto node_pair_list = CreateEmbSwapUpdate(graph, cache_host_params_map, map_cache_idx_node_outputs);
  AnfNodePtrList invalid_nodes;
  auto cache_idx = map_cache_idx_node_outputs[0];
  if (!is_pipe) {
    if (!manager->Replace(sparse_gatherv2_with_cache[0]->input(2), cache_idx)) {
      MS_LOG(EXCEPTION) << "MapCacheIdx output[0] replace node failed.";
    }
    for (auto &ele : node_pair_list) {
      std::transform(sparse_gatherv2_with_cache.begin(), sparse_gatherv2_with_cache.end(),
                     std::back_inserter(invalid_nodes), [&graph, &ele](const AnfNodePtr &sparse_gatherv2) {
                       return CreateControlDepend(graph, ele.first, sparse_gatherv2);
                     });
      invalid_nodes.emplace_back(ele.second);
    }
  } else {
    auto cache_idx_param = CreateOutputNodeParam(graph, unique_node->input(1), std::string("cache_idx"));
    auto unique_index_reverse = FindCNodeOutput(graph, unique_node, 1);
    auto unique_index_param = CreateOutputNodeParam(graph, unique_index_reverse, std::string("index_dense"));
    auto no_ref_params = FindNoRefParams(graph);
    RemoveOriginParamFromSet(unique_node, &no_ref_params);
    auto no_ref_param_map = CreateOtherPipeParams(graph, no_ref_params);
    no_ref_param_map[unique_index_reverse] = unique_index_param;
    auto control_depend_list = ReplaceNoRefToParams(graph, no_ref_param_map, cache_idx_param, cache_idx,
                                                    sparse_gatherv2_with_cache[0]->input(2));
    std::copy(control_depend_list.begin(), control_depend_list.end(), std::back_inserter(invalid_nodes));
    std::transform(node_pair_list.begin(), node_pair_list.end(), std::back_inserter(invalid_nodes),
                   [](const std::pair<AnfNodePtr, AnfNodePtr> &pair) { return pair.second; });
  }
  AnfNodePtr last_node = cnodes[cnodes_size - 1];
  CNodePtr return_node;
  if (last_node->isa<CNode>()) {
    return_node = last_node->cast<CNodePtr>();
  }
  MS_EXCEPTION_IF_NULL(return_node);
  if (!IsPrimitiveCNode(return_node, prim::kPrimReturn)) {
    MS_LOG(EXCEPTION) << "The last cnode after sorting is not a Return cnode.";
  }
  if (return_node->inputs().size() < 2) {
    MS_LOG(EXCEPTION) << "Number of return node inputs should be greater than or equal to 2.";
  }

  auto depend_node = CreateDepend(graph, invalid_nodes, return_node->input(1));
  if (!manager->Replace(return_node->input(1), depend_node)) {
    MS_LOG(EXCEPTION) << "Depend replace node failed.";
  }
}

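// Evaluation: no cache swapping is needed, so each GatherV2 on a cache-enabled
// param is simply replaced with an EmbeddingLookup(CPU) on the host table.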
void CacheEmbeddingForEval(const FuncGraphPtr &graph, const CNodePtrList &cnodes, const CNodePtr &unique_node,
                           const ParamSet &param_cache_enable_set) {
  MS_EXCEPTION_IF_NULL(graph);
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  graph->set_flag(GRAPH_FLAG_CACHE_ENABLE, true);
  MS_LOG(INFO) << "Graph is set cache enable.";
  // Replace GatherV2 with EmbeddingLookup(CPU).
  auto indices = unique_node->input(1);
  auto sparse_gatherv2_with_cache = FindSparseGatherV2WithCache(cnodes, param_cache_enable_set);
  for (auto &ele : sparse_gatherv2_with_cache) {
    auto anf_ele = ele->cast<AnfNodePtr>();
    auto gatherv2 = FindGatherV2FromSparseGatherV2(graph, anf_ele);
    auto param = ele->input(1)->cast<ParameterPtr>();
    auto embedding_lookup = CreateEmbeddingLookup(graph, param, indices);
    if (!manager->Replace(gatherv2, embedding_lookup)) {
      MS_LOG(EXCEPTION) << "EmbeddingLookup replace node failed.";
    }
  }
}

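// Entry point of the pass: locate the cache-enabled params and the Unique op,
// then dispatch to the training or evaluation transformation.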
void AddCacheEmbedding(const FuncGraphPtr &graph, bool is_pipe) {
  MS_EXCEPTION_IF_NULL(graph);
  std::list<CNodePtr> orders = graph->GetOrderedCnodes();
  CNodePtrList cnodes(orders.begin(), orders.end());
  bool training = graph->has_flag("training");
  auto param_cache_enable_set = FindParamCacheEnable(graph);
  if (param_cache_enable_set.empty()) {
@@ -451,6 +669,12 @@ void AddCacheEmbedding(const FuncGraphPtr &graph) {
  if (!CheckHostCacheParamSize(param_cache_enable_set)) {
    return;
  }
  auto unique_cache_enable = FindUniqueCacheEnable(cnodes);
  if (unique_cache_enable.empty()) {
    MS_LOG(WARNING) << "Parameters have cache enable, but can not find a Unique op with cache enable.";
    return;
  }
  auto unique_node = unique_cache_enable[0];
  if (training) {
    // If training, create cache parameters corresponding to the host params that are cache enable.
    // Replace the host params. Create a hashmap, then insert a MapCacheIdx op after the Unique op that has
    // the 'cache_enable' attr.
@@ -460,75 +684,14 @@ void AddCacheEmbedding(const FuncGraphPtr &graph) {
    // flush miss values to cache params and write back old values to host params.
    // If pipe is not used in training, EmbeddingLookup and CacheSwapTable must execute before SparseGatherV2,
    // so add ControlDepend between them, and add Depend for the UpdateCache and ControlDepend ops to keep
    // these nodes in the graph.
    // If pipe is used in training, create parameters for the no-ref params (such as labels), MapCacheIdx
    // output[0] and Unique output[1]; each step then trains on the data saved from the previous step, which
    // hides the latency of Unique and the other CPU kernels. The first step therefore runs on fake data.
    CacheEmbeddingForTrain(graph, is_pipe, cnodes, unique_node, param_cache_enable_set);
  } else {
    // If eval, use the EmbeddingLookup(CPU) op to replace GatherV2.
    // The network is the same as in Host-Device mode.
    CacheEmbeddingForEval(graph, cnodes, unique_node, param_cache_enable_set);
  }
}
}  // namespace parallel