Browse Source

add embedding infer

tags/v1.2.0-rc1
fangzehua 5 years ago
parent
commit
dadbd54f0e
4 changed files with 13 additions and 8 deletions
  1. +3
    -3
      mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc
  2. +2
    -0
      mindspore/ccsrc/utils/utils.h
  3. +7
    -4
      mindspore/core/abstract/prim_arrays.cc
  4. +1
    -1
      mindspore/core/abstract/primitive_infer_map.cc

+ 3
- 3
mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc View File

@@ -362,7 +362,7 @@ void CreateTupleGetItems(const FuncGraphPtr &func_graph, const AnfNodePtr &input

AnfNodePtr CreateEmbeddingLookup(const FuncGraphPtr &graph, AnfNodePtr params, AnfNodePtr indices) {
MS_EXCEPTION_IF_NULL(graph);
PrimitivePtr emb_lookup_primitive = prim::kPrimEmbeddingLookup;
PrimitivePtr emb_lookup_primitive = std::make_shared<Primitive>(kEmbeddingLookupOpName);
emb_lookup_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
emb_lookup_primitive->set_attr(kAttrOffset, MakeValue<int64_t>(0));
std::vector<AnfNodePtr> emb_lookup_nodes{NewValueNode(emb_lookup_primitive), params, indices};
@@ -373,7 +373,7 @@ AnfNodePtr CreateEmbeddingLookup(const FuncGraphPtr &graph, AnfNodePtr params, A
AnfNodePtr CreateCacheSwapTable(const FuncGraphPtr &graph, ParameterPtr cache_table, AnfNodePtr swap_cache_idx,
AnfNodePtr miss_value) {
MS_EXCEPTION_IF_NULL(graph);
PrimitivePtr cache_swap_table_primitive = prim::kPrimCacheSwapTable;
PrimitivePtr cache_swap_table_primitive = std::make_shared<Primitive>(kCacheSwapTableOpName);
std::vector<AnfNodePtr> cache_swap_table_nodes{NewValueNode(cache_swap_table_primitive), cache_table, swap_cache_idx,
miss_value};
auto cache_swap_table = graph->NewCNode(cache_swap_table_nodes);
@@ -383,7 +383,7 @@ AnfNodePtr CreateCacheSwapTable(const FuncGraphPtr &graph, ParameterPtr cache_ta
AnfNodePtr CreateUpdateCache(const FuncGraphPtr &graph, ParameterPtr params, AnfNodePtr old_emb_idx,
AnfNodePtr old_value) {
MS_EXCEPTION_IF_NULL(graph);
PrimitivePtr update_cache_primitive = prim::kPrimUpdateCache;
PrimitivePtr update_cache_primitive = std::make_shared<Primitive>(kUpdateCacheOpName);
update_cache_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));

auto params_ori_shp = params->Shape();


+ 2
- 0
mindspore/ccsrc/utils/utils.h View File

@@ -210,6 +210,8 @@ constexpr auto kTensorScatterUpdateOpName = "TensorScatterUpdate";
constexpr auto kScatterNdUpdateOpName = "ScatterNdUpdate";
constexpr auto kPushOpName = "Push";
constexpr auto kPullOpName = "Pull";
constexpr auto kUpdateCacheOpName = "UpdateCache";
constexpr auto kCacheSwapTableOpName = "CacheSwapTable";
constexpr auto kEmbeddingLookupOpName = "EmbeddingLookup";
constexpr auto kEmbeddingLookupProxyOpName = "EmbeddingLookupProxy";
constexpr auto kGatherV2OpName = "Gather";


+ 7
- 4
mindspore/core/abstract/prim_arrays.cc View File

@@ -661,7 +661,6 @@ AbstractBasePtr InferImplDynamicAssign(const AnalysisEnginePtr &, const Primitiv
AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
auto params = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto params_shp = params->shape();
MS_EXCEPTION_IF_NULL(params);
@@ -673,8 +672,10 @@ AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const Primit
MS_EXCEPTION_IF_NULL(indices_shp);
auto indices_shape = indices_shp->shape();
auto indices_max_shape = indices_shp->max_shape();
auto indices_min_shape = indices_shp->min_shape();
ShapeVector shape;
ShapeVector max_shape;
ShapeVector min_shape;
shape.insert(shape.end(), indices_shape.begin(), indices_shape.end());
shape.insert(shape.end(), params_shape.begin() + 1, params_shape.end());
if (!indices_max_shape.empty()) {
@@ -683,9 +684,11 @@ AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const Primit
} else {
max_shape = shape;
}
ShapeVector min_shape;
for (size_t i = 0; i < max_shape.size(); ++i) {
min_shape.emplace_back(1);
if (!indices_min_shape.empty()) {
min_shape.insert(min_shape.end(), indices_min_shape.begin(), indices_min_shape.end());
min_shape.insert(min_shape.end(), params_shape.begin() + 1, params_shape.end());
} else {
min_shape = shape;
}

AbstractTensorPtr ret =


+ 1
- 1
mindspore/core/abstract/primitive_infer_map.cc View File

@@ -78,6 +78,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimUnique, {InferImplUnique, true}},
{prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
{prim::kPrimGather, {InferImplGatherV2, true}},
{prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}},
{prim::kPrimSparseGatherV2, {InferImplGatherV2, true}},
{prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, true}},
{prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, true}},
@@ -199,7 +200,6 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
{prim::kPrimLess, {InferImplLess, true}},
{prim::kPrimStack, {InferImplStack, true}},
{prim::kPrimPad, {InferImplPad, true}},
{prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}},
{prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}},
{prim::kPrimDiv, {InferImplDiv, true}},
{prim::kPrimRealDiv, {InferImplRealDiv, true}},


Loading…
Cancel
Save