Browse Source

!8572 add embeddinglookup dynamic

From: @fangzehua
Reviewed-by: @stsuteng
Signed-off-by: @stsuteng
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
3a41d747ca
6 changed files with 77 additions and 9 deletions
  1. +26
    -4
      mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc
  2. +2
    -2
      mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h
  3. +2
    -0
      mindspore/core/abstract/infer_functions.h
  4. +35
    -0
      mindspore/core/abstract/prim_arrays.cc
  5. +1
    -0
      mindspore/core/abstract/primitive_infer_map.cc
  6. +11
    -3
      mindspore/ops/operations/array_ops.py

+ 26
- 4
mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc View File

@@ -48,32 +48,54 @@ void LookUpTableTask(const float *input_addr, const T *indices_addr, float *outp

// Initializes the CPU EmbeddingLookup kernel from the graph node: caches the
// node (so dynamic shapes can be re-queried at launch time), precomputes the
// row count / row width of the params tensor and the flattened indices length,
// and reads the optional 'offset' attribute.
// Fix: the original assigned indices_data_type_ twice (before and after the
// offset check); the redundant second assignment is removed.
void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  CheckParam(kernel_node);
  // Keep the node so LaunchKernel can refresh shapes for dynamic-shape inputs.
  node_ = kernel_node;
  std::vector<size_t> input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
  if (input_shape.empty()) {
    MS_LOG(EXCEPTION) << "param must be at least 1D";
  }
  // first_dim_size_: number of embedding rows; outer_dim_size_: elements per row.
  first_dim_size_ = input_shape[0];
  outer_dim_size_ = 1;
  for (size_t i = 1; i < input_shape.size(); ++i) {
    outer_dim_size_ *= input_shape[i];
  }
  // Total number of lookups = product of all indices dimensions.
  indices_lens_ = 1;
  std::vector<size_t> indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
  for (const auto &shape : indices_shape) {
    indices_lens_ *= shape;
  }
  // Cached once; selects the templated LaunchKernel instantiation.
  indices_data_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
  if (AnfAlgo::HasNodeAttr(kAttrOffset, kernel_node)) {
    offset_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrOffset);
  }
}

template <typename T>
void EmbeddingLookUpCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) const {
const std::vector<kernel::AddressPtr> &outputs) {
if (node_ != nullptr) {
std::vector<size_t> input_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0);
if (input_shape.empty()) {
MS_LOG(EXCEPTION) << "param must be at least 1D";
}
first_dim_size_ = input_shape[0];
outer_dim_size_ = 1;
for (size_t i = 1; i < input_shape.size(); ++i) {
outer_dim_size_ *= input_shape[i];
}

indices_lens_ = 1;
std::vector<size_t> indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 1);
for (const auto &shape : indices_shape) {
indices_lens_ *= shape;
}
}
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
auto indices_addr = reinterpret_cast<T *>(inputs[1]->addr);
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
const size_t thread_num = 16;
std::thread threads[16];
const size_t kMaxThreadNum = 16;
size_t thread_num = indices_lens_ / 10000 + 1;
thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
std::thread threads[kMaxThreadNum];
size_t task_proc_lens = (indices_lens_ + thread_num - 1) / thread_num;
size_t i;
size_t task_offset = 0;


+ 2
- 2
mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h View File

@@ -32,8 +32,7 @@ class EmbeddingLookUpCPUKernel : public CPUKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
template <typename T>
void LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) const;
void LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);

protected:
void CheckParam(const CNodePtr &kernel_node);
@@ -42,6 +41,7 @@ class EmbeddingLookUpCPUKernel : public CPUKernel {
size_t first_dim_size_{1};
size_t outer_dim_size_{1};
TypeId indices_data_type_{kNumberTypeInt32};
CNodePtr node_ = nullptr;
};

MS_REG_CPU_KERNEL(


+ 2
- 0
mindspore/core/abstract/infer_functions.h View File

@@ -233,6 +233,8 @@ AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,


+ 35
- 0
mindspore/core/abstract/prim_arrays.cc View File

@@ -453,6 +453,41 @@ AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr
return std::make_shared<AbstractTensor>(params->element(), std::make_shared<Shape>(out_shape));
}

// Shape/type inference for EmbeddingLookup(params, indices).
// Output shape is indices_shape ++ params_shape[1:]; for dynamic indices the
// max shape propagates the indices' max_shape and the min shape is all-ones.
// Fix: the original dereferenced params/indices (calling ->shape()) BEFORE the
// corresponding MS_EXCEPTION_IF_NULL checks, making those checks dead code on
// the path that matters; the checks now precede the dereference.
AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                         const AbstractBasePtrList &args_spec_list) {
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 2);
  auto params = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  MS_EXCEPTION_IF_NULL(params);
  auto params_shp = params->shape();
  MS_EXCEPTION_IF_NULL(params_shp);
  auto params_shape = params_shp->shape();
  auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
  MS_EXCEPTION_IF_NULL(indices);
  auto indices_shp = indices->shape();
  MS_EXCEPTION_IF_NULL(indices_shp);
  auto indices_shape = indices_shp->shape();
  auto indices_max_shape = indices_shp->max_shape();
  // Result shape: every indices dim, then the params row shape (drop axis 0).
  ShapeVector shape;
  shape.insert(shape.end(), indices_shape.begin(), indices_shape.end());
  shape.insert(shape.end(), params_shape.begin() + 1, params_shape.end());
  ShapeVector max_shape;
  if (!indices_max_shape.empty()) {
    // Dynamic indices: upper bound comes from their max_shape; the params
    // tail is static and carries over unchanged.
    max_shape.insert(max_shape.end(), indices_max_shape.begin(), indices_max_shape.end());
    max_shape.insert(max_shape.end(), params_shape.begin() + 1, params_shape.end());
  } else {
    max_shape = shape;
  }
  // Lower bound: each dimension may shrink to 1 under dynamic shapes.
  ShapeVector min_shape(max_shape.size(), 1);
  return std::make_shared<AbstractTensor>(params->element(),
                                          std::make_shared<Shape>(shape, min_shape, max_shape));
}

AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string &op_name = primitive->name();


+ 1
- 0
mindspore/core/abstract/primitive_infer_map.cc View File

@@ -54,6 +54,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
{prim::kPrimGatherV2, {InferImplGatherV2, true}},
{prim::kPrimSparseGatherV2, {InferImplGatherV2, true}},
{prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}},
{prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}},
{prim::kPrimScatterAdd, {InferImplScatterAdd, true}},
{prim::kPrimScatterUpdate, {InferImplScatterUpdate, true}},


+ 11
- 3
mindspore/ops/operations/array_ops.py View File

@@ -4192,12 +4192,20 @@ class EmbeddingLookup(PrimitiveWithInfer):
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
params_shp = params['shape']
if len(params_shp) != 2:
raise ValueError("The dimension of 'params' in EmbeddingLookup must be 2, but got %d." % len(params_shp))
out_shape = indices['shape'] + params_shp[1:]
if 'max_shape' in indices:
out_max_shape = indices['max_shape'] + params_shp[1:]
else:
out_max_shape = out_shape
if 'min_shape' in indices:
out_min_shape = indices['min_shape'] + params_shp[1:]
else:
out_min_shape = out_shape
out = {'shape': out_shape,
'dtype': params['dtype'],
'value': None}
'value': None,
'max_shape': out_max_shape,
'min_shape': out_min_shape}
return out




Loading…
Cancel
Save