
!11462 add pipe for cache embedding

From: @fangzehua
Committed by: mindspore-ci-bot (Gitee), 4 years ago
Tag: v1.2.0-rc1
Parent commit: e02b6852cb
3 changed files with 241 additions and 79 deletions:
  1. mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc (+232, -69)
  2. mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.h (+1, -1)
  3. mindspore/nn/layer/embedding.py (+8, -9)

mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc (+232, -69)

@@ -18,6 +18,7 @@
#include <random>
#include <vector>
#include <list>
#include <queue>
#include <utility>
#include <memory>
#include <unordered_map>
@@ -33,6 +34,8 @@ namespace parallel {
using ParamMap = std::unordered_map<ParameterPtr, ParameterPtr>;
using ParamSet = std::unordered_set<ParameterPtr>;
using NodePairList = std::vector<std::pair<AnfNodePtr, AnfNodePtr>>;
using AnfMap = std::unordered_map<AnfNodePtr, AnfNodePtr>;
using AnfSet = std::unordered_set<AnfNodePtr>;

ParamMap AddCacheParameters(const FuncGraphPtr &graph, const ParamSet &parameter_cache_enable_set) {
ParamMap cache_host_params_map;
@@ -408,6 +411,7 @@ CNodePtrList FindSparseGatherV2WithCache(const CNodePtrList &cnodes, const Param
if (sparse_gather_v2_with_cache.empty()) {
MS_LOG(EXCEPTION) << "Can not find SparseGatherV2 with cache param.";
}

auto indices = sparse_gather_v2_with_cache[0]->input(2);
for (auto &ele : sparse_gather_v2_with_cache) {
if (ele->input(2) != indices) {
@@ -433,13 +437,227 @@ AnfNodePtr FindGatherV2FromSparseGatherV2(const FuncGraphPtr &graph, const AnfNo
return gatherv2_nodes[0];
}

void AddCacheEmbedding(const FuncGraphPtr &graph) {
AnfSet FindNoRefParams(const FuncGraphPtr &graph) {
AnfSet no_ref_params;
auto params = graph->parameters();
for (auto &anf_param : params) {
auto param = anf_param->cast<ParameterPtr>();
if (!param->has_default()) {
MS_LOG(INFO) << param->DebugString() << " has no default";
no_ref_params.insert(anf_param);
}
}
return no_ref_params;
}
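
// A standalone sketch, with a hypothetical Param type (not MindSpore's
// ParameterPtr), of the filter above: "no-ref" params are the graph inputs
// without a default value, i.e. the per-step data tensors (such as labels)
// that the pipe mode must stage.
#include <unordered_set>
#include <vector>

struct Param {
  bool has_default;  // true for weights, false for per-step inputs
};

std::unordered_set<const Param *> FindNoRef(const std::vector<Param> &params) {
  std::unordered_set<const Param *> no_ref;
  for (const Param &p : params) {
    if (!p.has_default) no_ref.insert(&p);
  }
  return no_ref;
}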

void RemoveOriginParamFromSet(const CNodePtr &unique_node, AnfSet *no_ref_params) {
std::queue<CNodePtr> que;
que.push(unique_node);
while (!que.empty()) {
auto node = que.front();
que.pop();
auto node_inputs = node->inputs();
for (auto &input : node_inputs) {
if (input->isa<CNode>()) {
que.push(input->cast<CNodePtr>());
} else if (input->isa<Parameter>()) {
int num = no_ref_params->erase(input);
if (num > 0) {
MS_LOG(INFO) << "Erase unique_node input from set success.";
return;
}
}
}
}
MS_LOG(EXCEPTION) << "Can not find any parameter that use by Unique.";
}
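
// A standalone sketch of the worklist traversal above, assuming a hypothetical
// Node type rather than AnfNode: walk the root's inputs breadth-first and
// erase the first parameter leaf found in the candidate set (value-node inputs
// are simply treated as nodes with no children here).
#include <queue>
#include <unordered_set>
#include <vector>

struct Node {
  std::vector<Node *> inputs;  // operands, CNode-style
  bool is_param = false;       // leaf parameter flag
};

bool EraseFirstReachableParam(Node *root, std::unordered_set<Node *> *candidates) {
  std::queue<Node *> que;
  que.push(root);
  while (!que.empty()) {
    Node *node = que.front();
    que.pop();
    for (Node *input : node->inputs) {
      if (!input->is_param) {
        que.push(input);  // keep walking through compute nodes
      } else if (candidates->erase(input) > 0) {
        return true;  // first hit wins, mirroring the early return above
      }
    }
  }
  return false;  // the pass raises an exception in this case
}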

AnfNodePtr CreateOutputNodeParam(const FuncGraphPtr &graph, const AnfNodePtr &ori_input, const std::string &name) {
auto ori_input_type = ori_input->Type();
auto ori_input_element_type = ori_input_type->cast<mindspore::TensorTypePtr>()->element();
auto ori_input_type_id = ori_input_element_type->type_id();
auto ori_input_shp = ori_input->Shape();
auto input_shp = ori_input_shp->cast<abstract::ShapePtr>();
auto input_shape = input_shp->shape();
auto new_tensor = std::make_shared<tensor::Tensor>(ori_input_type_id, input_shape);
ParamInfoPtr new_param_info = std::make_shared<ParamInfo>();
auto new_param_name = name + "_pipe";
new_param_info->set_name(new_param_name);
new_tensor->set_param_info(new_param_info);
auto new_param = graph->AddWeightParameter(new_param_name);
new_param->set_default_param(MakeValue(new_tensor));
auto abs_tensor = new_tensor->ToAbstract();
new_param->set_abstract(abs_tensor);
return new_param->cast<AnfNodePtr>();
}

AnfMap CreateOtherPipeParams(const FuncGraphPtr &graph, const AnfSet &no_ref_params) {
AnfMap no_ref_pipe_param_map;
for (auto &param : no_ref_params) {
auto ori_param = param->cast<ParameterPtr>();
auto ori_name = ori_param->name();
auto new_param = CreateOutputNodeParam(graph, param, ori_name);
no_ref_pipe_param_map[param] = new_param;
}
return no_ref_pipe_param_map;
}
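
// A standalone sketch, with a hypothetical TensorMeta type, of what the two
// helpers above do together: clone only the metadata (dtype, shape) of each
// no-ref input into a fresh "<name>_pipe" weight that can be double-buffered.
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

struct TensorMeta {
  int type_id;                 // element dtype id
  std::vector<int64_t> shape;  // static shape
  std::string name;
};

std::unordered_map<const TensorMeta *, TensorMeta> MakePipeClones(
    const std::vector<TensorMeta> &no_ref_params) {
  std::unordered_map<const TensorMeta *, TensorMeta> pipe_map;
  for (const TensorMeta &p : no_ref_params) {
    pipe_map[&p] = TensorMeta{p.type_id, p.shape, p.name + "_pipe"};
  }
  return pipe_map;
}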

AnfNodePtr CreateAssign(const FuncGraphPtr &graph, const AnfNodePtr &res_param, const AnfNodePtr &src_param,
bool is_dynamic = false) {
auto assign_prim = prim::kPrimAssign;
if (is_dynamic) {
assign_prim = prim::kPrimDynamicAssign;
assign_prim->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
}
std::vector<AnfNodePtr> assign_nodes{NewValueNode(assign_prim), res_param, src_param};
auto assign_status = graph->NewCNode(assign_nodes);
return assign_status;
}

AnfNodePtr FindCNodeOutput(const FuncGraphPtr &graph, const AnfNodePtr &node, int64_t index) {
auto manager = graph->manager();
auto node_users = manager->node_users()[node];
for (auto &node_user : node_users) {
if (IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
auto cnode = node_user.first->cast<CNodePtr>();
auto node_index = cnode->input(2);
if (node_index->isa<ValueNode>()) {
auto value_node = node_index->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto item_idx = GetValue<int64_t>(value_node->value());
if (item_idx == index) {
return node_user.first;
}
}
}
}
MS_LOG(EXCEPTION) << "Can't not find " << node->DebugString() << ", outputs:" << index;
}
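
// A standalone sketch of the lookup above with a hypothetical User record:
// scan a node's users for the TupleGetItem that extracts the requested output
// index (the pass throws instead of returning null when nothing matches).
#include <cstdint>
#include <vector>

struct User {
  bool is_tuple_get_item;  // user is a TupleGetItem CNode
  int64_t item_index;      // constant index operand
};

const User *FindOutputUser(const std::vector<User> &users, int64_t index) {
  for (const User &u : users) {
    if (u.is_tuple_get_item && u.item_index == index) return &u;
  }
  return nullptr;
}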

AnfNodePtrList ReplaceNoRefToParams(const FuncGraphPtr &graph, const AnfMap &no_ref_pipe_param_map,
const AnfNodePtr &cache_idx_param, const AnfNodePtr &cache_idx,
const AnfNodePtr &sparse_gatherv2_indices) {
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto node_users = manager->node_users();
AnfNodePtrList control_depend_list;
// add other no ref pipe param and unique index dense
for (auto &ele : no_ref_pipe_param_map) {
auto user_set = node_users[ele.first];
auto assign_status = CreateAssign(graph, ele.second, ele.first);
for (auto user_node : user_set) {
auto control_depend = CreateControlDepend(graph, user_node.first, assign_status);
control_depend_list.emplace_back(control_depend);
}
if (!manager->Replace(ele.first, ele.second)) {
MS_LOG(EXCEPTION) << "pipe param: " << ele.first->DebugString() << ", replace node failed.";
}
}

// add cache idx param
auto dynamic_assign_status = CreateAssign(graph, cache_idx_param, cache_idx, true);
auto indices_user_set = node_users[sparse_gatherv2_indices];
for (auto &user_node : indices_user_set) {
auto control_depend = CreateControlDepend(graph, user_node.first, dynamic_assign_status);
control_depend_list.emplace_back(control_depend);
}
if (!manager->Replace(sparse_gatherv2_indices, cache_idx_param)) {
MS_LOG(EXCEPTION) << "cache idx param: " << cache_idx_param->DebugString() << ", replace node failed.";
}
return control_depend_list;
}
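
// A toy, self-contained simulation (hypothetical host-side ints, not graph
// nodes) of the one-step pipeline this rewiring creates: the control depends
// make every consumer read the previous step's snapshot from the "_pipe"
// buffer before Assign publishes this step's fresh value for the next step.
#include <cstdio>

int main() {
  int host_value = 0;   // result of the CPU-side ops (e.g. Unique) this step
  int pipe_param = -1;  // the "_pipe" buffer the device ops actually consume
  for (int step = 0; step < 4; ++step) {
    int consumed = pipe_param;  // device trains on last step's staged data
    pipe_param = host_value;    // Assign: publish this step's host result
    host_value = step + 1;      // host computes the next step's value
    std::printf("step %d trains on %d\n", step, consumed);  // step 0 sees fake data (-1)
  }
  return 0;
}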

void CacheEmbeddingForTrain(const FuncGraphPtr &graph, bool is_pipe, const CNodePtrList &cnodes,
const CNodePtr &unique_node, const ParamSet &param_cache_enable_set) {
MS_EXCEPTION_IF_NULL(graph);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
size_t cnodes_size = cnodes.size();
auto cache_host_params_map = AddCacheParameters(graph, param_cache_enable_set);
auto param_set = MapKeysToSet(cache_host_params_map);
ReplaceCacheParams(graph, cache_host_params_map);
graph->set_flag(GRAPH_FLAG_CACHE_ENABLE, true);
MS_LOG(INFO) << "Graph is set cache enable.";

CNodePtrList sparse_gatherv2_with_cache = FindSparseGatherV2WithCache(cnodes, param_set);
auto unique_node_output_0 = CreateTupleGetItem(graph, unique_node, 0);
auto map_cache_idx = CreateMapCacheIdx(graph, unique_node_output_0, cache_host_params_map);

AnfNodePtrList map_cache_idx_node_outputs;
CreateTupleGetItems(graph, map_cache_idx, &map_cache_idx_node_outputs);

auto node_pair_list = CreateEmbSwapUpdate(graph, cache_host_params_map, map_cache_idx_node_outputs);
AnfNodePtrList invalid_nodes;
auto cache_idx = map_cache_idx_node_outputs[0];
if (!is_pipe) {
if (!manager->Replace(sparse_gatherv2_with_cache[0]->input(2), cache_idx)) {
MS_LOG(EXCEPTION) << "MapCacheIdx output[0] replace node failed";
}
for (auto &ele : node_pair_list) {
std::transform(sparse_gatherv2_with_cache.begin(), sparse_gatherv2_with_cache.end(),
std::back_inserter(invalid_nodes), [&graph, &ele](const AnfNodePtr &sparse_gatherv2) {
return CreateControlDepend(graph, ele.first, sparse_gatherv2);
});
invalid_nodes.emplace_back(ele.second);
}
} else {
auto cache_idx_param = CreateOutputNodeParam(graph, unique_node->input(1), std::string("cache_idx"));
auto unique_index_reverse = FindCNodeOutput(graph, unique_node, 1);
auto unique_index_param = CreateOutputNodeParam(graph, unique_index_reverse, std::string("index_dense"));
auto no_ref_params = FindNoRefParams(graph);
RemoveOriginParamFromSet(unique_node, &no_ref_params);
auto no_ref_param_map = CreateOtherPipeParams(graph, no_ref_params);
no_ref_param_map[unique_index_reverse] = unique_index_param;
auto control_depend_list = ReplaceNoRefToParams(graph, no_ref_param_map, cache_idx_param, cache_idx,
sparse_gatherv2_with_cache[0]->input(2));
std::copy(control_depend_list.begin(), control_depend_list.end(), std::back_inserter(invalid_nodes));
std::transform(node_pair_list.begin(), node_pair_list.end(), std::back_inserter(invalid_nodes),
[](const std::pair<AnfNodePtr, AnfNodePtr> &pair) { return pair.second; });
}
AnfNodePtr last_node = cnodes[cnodes_size - 1];
CNodePtr return_node;
if (last_node->isa<CNode>()) {
return_node = last_node->cast<CNodePtr>();
}
MS_EXCEPTION_IF_NULL(return_node);
if (!IsPrimitiveCNode(return_node, prim::kPrimReturn)) {
MS_LOG(EXCEPTION) << "The last cnode after sorting, not return cnode.";
}
if (return_node->inputs().size() < 2) {
MS_LOG(EXCEPTION) << "Number of return node inputs should be great than or equal to 2.";
}

auto depend_node = CreateDepend(graph, invalid_nodes, return_node->input(1));
if (!manager->Replace(return_node->input(1), depend_node)) {
MS_LOG(EXCEPTION) << "Depend replace node failed";
}
}
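
// A minimal sketch, assuming a hypothetical Expr node type and assuming
// CreateDepend bundles the nodes into a MakeTuple (the usual pattern), of the
// shape the final rewrite produces: the graph output becomes
// Depend(output, MakeTuple(invalid_nodes...)), which pins the cache-update and
// control-depend nodes into the graph even though nothing reads their values.
#include <memory>
#include <utility>
#include <vector>

struct Expr {
  const char *op = "";
  std::vector<std::shared_ptr<Expr>> inputs;
};

std::shared_ptr<Expr> MakeDependOnSideEffects(std::shared_ptr<Expr> value,
                                              std::vector<std::shared_ptr<Expr>> side_effects) {
  auto tuple = std::make_shared<Expr>();
  tuple->op = "MakeTuple";
  tuple->inputs = std::move(side_effects);  // the otherwise-dead nodes
  auto depend = std::make_shared<Expr>();
  depend->op = "Depend";
  depend->inputs.push_back(std::move(value));  // forwarded real output
  depend->inputs.push_back(tuple);             // kept alive, value ignored
  return depend;
}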

void CacheEmbeddingForEval(const FuncGraphPtr &graph, const CNodePtrList &cnodes, const CNodePtr &unique_node,
const ParamSet &param_cache_enable_set) {
MS_EXCEPTION_IF_NULL(graph);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
graph->set_flag(GRAPH_FLAG_CACHE_ENABLE, true);
MS_LOG(INFO) << "Graph is set cache enable.";
// replace GatherV2 to EmbeddingLookupCPU
auto indices = unique_node->input(1);
auto sparse_gatherv2_with_cache = FindSparseGatherV2WithCache(cnodes, param_cache_enable_set);
for (auto &ele : sparse_gatherv2_with_cache) {
auto anf_ele = ele->cast<AnfNodePtr>();
auto gatherv2 = FindGatherV2FromSparseGatherV2(graph, anf_ele);
auto param = ele->input(1)->cast<ParameterPtr>();
auto embedding_lookup = CreateEmbeddingLookup(graph, param, indices);
if (!manager->Replace(gatherv2, embedding_lookup)) {
MS_LOG(EXCEPTION) << "Depend replace node failed";
}
}
}

void AddCacheEmbedding(const FuncGraphPtr &graph, bool is_pipe) {
MS_EXCEPTION_IF_NULL(graph);
std::list<CNodePtr> orders = graph->GetOrderedCnodes();
CNodePtrList cnodes(orders.begin(), orders.end());
bool training = graph->has_flag("training");
auto param_cache_enable_set = FindParamCacheEnable(graph);
if (param_cache_enable_set.empty()) {
@@ -451,6 +669,12 @@ void AddCacheEmbedding(const FuncGraphPtr &graph) {
if (!CheckHostCacheParamSize(param_cache_enable_set)) {
return;
}
auto unique_cache_enable = FindUniqueCacheEnable(cnodes);
if (unique_cache_enable.empty()) {
MS_LOG(WARNING) << "Parameters have cache enable, but cannot find a Unique op with cache enable.";
return;
}
auto unique_node = unique_cache_enable[0];
if (training) {
// If training, create cache parameters corresponding to the host params whose cache_enable is true.
// Replace the host params. Create the hashmap, then insert a MapCacheIdx op after the Unique that has the 'cache_enable' attr.
@@ -460,75 +684,14 @@ void AddCacheEmbedding(const FuncGraphPtr &graph) {
// flush missed values to the cache params and write the old values back to the host params.
// If pipe is not used in training, EmbeddingLookup and CacheSwapTable must execute before SparseGatherV2, so add
// ControlDepend between them, and add Depend for the UpdateCache and ControlDepend ops to attach those nodes to the graph.
auto unique_cache_enable = FindUniqueCacheEnable(cnodes);
if (unique_cache_enable.empty()) {
MS_LOG(WARNING) << "Parameters have cache enable, but not find Unique op cache enable.";
return;
}
auto cache_host_params_map = AddCacheParameters(graph, param_cache_enable_set);
auto param_set = MapKeysToSet(cache_host_params_map);
ReplaceCacheParams(graph, cache_host_params_map);
graph->set_flag(GRAPH_FLAG_CACHE_ENABLE, true);
auto unique_node = unique_cache_enable[0];

CNodePtrList sparse_gatherv2_with_cache = FindSparseGatherV2WithCache(cnodes, param_set);
auto unique_node_output_0 = CreateTupleGetItem(graph, unique_node, 0);
auto map_cache_idx = CreateMapCacheIdx(graph, unique_node_output_0, cache_host_params_map);

AnfNodePtrList map_cache_idx_node_outputs;
CreateTupleGetItems(graph, map_cache_idx, &map_cache_idx_node_outputs);

if (!manager->Replace(sparse_gatherv2_with_cache[0]->input(2), map_cache_idx_node_outputs[0])) {
MS_LOG(EXCEPTION) << "MapCacheIdx output[0] replace node failed";
}

auto node_pair_list = CreateEmbSwapUpdate(graph, cache_host_params_map, map_cache_idx_node_outputs);

AnfNodePtr last_node = cnodes[cnodes_size - 1];
CNodePtr return_node;
if (last_node->isa<CNode>()) {
return_node = last_node->cast<CNodePtr>();
}
MS_EXCEPTION_IF_NULL(return_node);
if (!IsPrimitiveCNode(return_node, prim::kPrimReturn)) {
MS_LOG(EXCEPTION) << "The last cnode after sorting, not return cnode.";
}
if (return_node->inputs().size() < 2) {
MS_LOG(EXCEPTION) << "Number of return node inputs should be great than or equal to 2.";
}
AnfNodePtrList invalid_nodes;
for (auto &ele : node_pair_list) {
std::transform(sparse_gatherv2_with_cache.begin(), sparse_gatherv2_with_cache.end(),
std::back_inserter(invalid_nodes), [&graph, &ele](const AnfNodePtr &sparse_gatherv2) {
return CreateControlDepend(graph, ele.first, sparse_gatherv2);
});
invalid_nodes.emplace_back(ele.second);
}
auto depend_node = CreateDepend(graph, invalid_nodes, return_node->input(1));
if (!manager->Replace(return_node->input(1), depend_node)) {
MS_LOG(EXCEPTION) << "Depend replace node failed";
}
// If pipe is used in training, create parameters for the no-ref params (such as labels), MapCacheIdx output[0] and
// Unique output[1]. Each step then trains on the data staged in the previous step, which hides the latency of Unique
// and the other CPU kernels; the first step therefore runs on fake data.
CacheEmbeddingForTrain(graph, is_pipe, cnodes, unique_node, param_cache_enable_set);
} else {
// If eval, Use EmbeddingLookup(CPU) op to replace GatherV2.
// The network is the same as Host-Device mode.
auto unique_cache_enable = FindUniqueCacheEnable(cnodes);
if (unique_cache_enable.empty()) {
MS_LOG(WARNING) << "Parameters have cache enable, but not find Unique op cache enable.";
return;
}
graph->set_flag(GRAPH_FLAG_CACHE_ENABLE, true);
// replace GatherV2 to EmbeddingLookupCPU
auto indices = unique_cache_enable[0]->input(1);
auto sparse_gatherv2_with_cache = FindSparseGatherV2WithCache(cnodes, param_cache_enable_set);
for (auto &ele : sparse_gatherv2_with_cache) {
auto anf_ele = ele->cast<AnfNodePtr>();
auto gatherv2 = FindGatherV2FromSparseGatherV2(graph, anf_ele);
auto param = ele->input(1)->cast<ParameterPtr>();
auto embedding_lookup = CreateEmbeddingLookup(graph, param, indices);
if (!manager->Replace(gatherv2, embedding_lookup)) {
MS_LOG(EXCEPTION) << "Depend replace node failed";
}
}
CacheEmbeddingForEval(graph, cnodes, unique_node, param_cache_enable_set);
}
}
} // namespace parallel


mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.h (+1, -1)

@@ -22,7 +22,7 @@
namespace mindspore {
namespace parallel {
// Automatically add control depend based on effect order and side-effect analysis.
void AddCacheEmbedding(const FuncGraphPtr &graph);
void AddCacheEmbedding(const FuncGraphPtr &graph, bool is_pipe = false);
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_CACHE_EMBEDDING_CACHE_EMBEDDING_H_

mindspore/nn/layer/embedding.py (+8, -9)

@@ -21,7 +21,7 @@ from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.communication.management import get_group_size
from mindspore.context import ParallelMode
from mindspore.context import ParallelMode, get_context
from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch
from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _is_role_worker, _get_ps_context
from mindspore._checkparam import Rel
@@ -278,7 +278,7 @@ class EmbeddingLookup(Cell):
raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get "
+ str(slice_mode))
if self.cache_enable and not enable_ps:
if is_auto_parallel:
if parallel_mode != ParallelMode.STAND_ALONE:
raise ValueError("parallel mode haven't supported cache enable yet.")
self._set_cache_enable()
self.embedding_table.unique = self.forward_unique
@@ -288,15 +288,14 @@ class EmbeddingLookup(Cell):
self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)

def _set_cache_enable(self):
"""EmbeddingLookup cache check for not ps env."""
"""EmbeddingLookup cache check for not ps env, which is only support 'ascend'."""
if self.target != 'DEVICE':
logger.warning("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target, "
"so it will be ignored.")
return
raise ValueError("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target.")
if not self.sparse:
logger.warning("The configuration of 'vocab_cache_size' is valid only 'sparse' is true, "
"so it will be ignored.")
return
raise ValueError("The configuration of 'vocab_cache_size' is valid only 'sparse' is true.")
if get_context("device_target") != 'Ascend':
raise ValueError("The configuration of 'vocab_cache_size' is valid only in 'ascend'.")

logger.info("EmbeddingLookup cache enable takes effect.")
self.forward_unique = True
self.unique = P.Unique().add_prim_attr('primitive_target', 'CPU')
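
# A hedged usage sketch of the constraints this commit enforces: the device
# cache only takes effect with target='DEVICE', sparse=True, device_target
# 'Ascend', and stand-alone (non-parallel) mode. Names follow the public
# nn.EmbeddingLookup API; the vocab sizes below are made up.
from mindspore import context, nn

context.set_context(device_target="Ascend")  # otherwise _set_cache_enable raises
embedding = nn.EmbeddingLookup(
    vocab_size=1_000_000,      # full table kept on host
    embedding_size=80,
    target="DEVICE",           # required, else ValueError after this commit
    sparse=True,               # required, else ValueError after this commit
    vocab_cache_size=100_000)  # hot rows cached on device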

