diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc
index 1428082e74..8a0c9ed812 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc
@@ -48,32 +48,54 @@ void LookUpTableTask(const float *input_addr, const T *indices_addr, float *outp
 void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   CheckParam(kernel_node);
+  node_ = kernel_node;
   std::vector<size_t> input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
   if (input_shape.empty()) {
     MS_LOG(EXCEPTION) << "param must be at least 1D";
   }
   first_dim_size_ = input_shape[0];
+  outer_dim_size_ = 1;
   for (size_t i = 1; i < input_shape.size(); ++i) {
     outer_dim_size_ *= input_shape[i];
   }
+  indices_lens_ = 1;
   std::vector<size_t> indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
   for (const auto &shape : indices_shape) {
     indices_lens_ *= shape;
   }
+  indices_data_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
   if (AnfAlgo::HasNodeAttr(kAttrOffset, kernel_node)) {
     offset_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrOffset);
   }
-  indices_data_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
 }
 
 template <typename T>
 void EmbeddingLookUpCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
-                                            const std::vector<kernel::AddressPtr> &outputs) const {
+                                            const std::vector<kernel::AddressPtr> &outputs) {
+  if (node_ != nullptr) {
+    std::vector<size_t> input_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0);
+    if (input_shape.empty()) {
+      MS_LOG(EXCEPTION) << "param must be at least 1D";
+    }
+    first_dim_size_ = input_shape[0];
+    outer_dim_size_ = 1;
+    for (size_t i = 1; i < input_shape.size(); ++i) {
+      outer_dim_size_ *= input_shape[i];
+    }
+
+    indices_lens_ = 1;
+    std::vector<size_t> indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 1);
+    for (const auto &shape : indices_shape) {
+      indices_lens_ *= shape;
+    }
+  }
   auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
   auto indices_addr = reinterpret_cast<T *>(inputs[1]->addr);
   auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
-  const size_t thread_num = 16;
-  std::thread threads[16];
+  const size_t kMaxThreadNum = 16;
+  size_t thread_num = indices_lens_ / 10000 + 1;
+  thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
+  std::thread threads[kMaxThreadNum];
   size_t task_proc_lens = (indices_lens_ + thread_num - 1) / thread_num;
   size_t i;
   size_t task_offset = 0;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h
index b73fbf4829..b1639100da 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h
@@ -32,8 +32,7 @@ class EmbeddingLookUpCPUKernel : public CPUKernel {
   bool Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
               const std::vector<kernel::AddressPtr> &outputs) override;
   template <typename T>
-  void LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
-                    const std::vector<kernel::AddressPtr> &outputs) const;
+  void LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
 
  protected:
   void CheckParam(const CNodePtr &kernel_node);
@@ -42,6 +41,7 @@ class EmbeddingLookUpCPUKernel : public CPUKernel {
   size_t first_dim_size_{1};
   size_t outer_dim_size_{1};
   TypeId indices_data_type_{kNumberTypeInt32};
+  CNodePtr node_ = nullptr;
 };
 
 MS_REG_CPU_KERNEL(
diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h
index 7c9e5550d4..b475959046 100644
--- a/mindspore/core/abstract/infer_functions.h
+++ b/mindspore/core/abstract/infer_functions.h
@@ -233,6 +233,8 @@ AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &
                                  const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                         const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc
index ad8fcdf0d6..35d1848ac3 100644
--- a/mindspore/core/abstract/prim_arrays.cc
+++ b/mindspore/core/abstract/prim_arrays.cc
@@ -453,6 +453,41 @@ AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr
   return std::make_shared<AbstractTensor>(params->element(), std::make_shared<Shape>(out_shape));
 }
 
+AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                         const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 2);
+  auto params = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  auto params_shp = params->shape();
+  MS_EXCEPTION_IF_NULL(params);
+  MS_EXCEPTION_IF_NULL(params_shp);
+  auto params_shape = params_shp->shape();
+  auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  auto indices_shp = indices->shape();
+  MS_EXCEPTION_IF_NULL(indices);
+  MS_EXCEPTION_IF_NULL(indices_shp);
+  auto indices_shape = indices_shp->shape();
+  auto indices_max_shape = indices_shp->max_shape();
+  ShapeVector shape;
+  ShapeVector max_shape;
+  shape.insert(shape.end(), indices_shape.begin(), indices_shape.end());
+  shape.insert(shape.end(), params_shape.begin() + 1, params_shape.end());
+  if (!indices_max_shape.empty()) {
+    max_shape.insert(max_shape.end(), indices_max_shape.begin(), indices_max_shape.end());
+    max_shape.insert(max_shape.end(), params_shape.begin() + 1, params_shape.end());
+  } else {
+    max_shape = shape;
+  }
+  ShapeVector min_shape;
+  for (size_t i = 0; i < max_shape.size(); ++i) {
+    min_shape.emplace_back(1);
+  }
+
+  AbstractTensorPtr ret =
+    std::make_shared<AbstractTensor>(params->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
+  return ret;
+}
+
 AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const AbstractBasePtrList &args_spec_list) {
   const std::string &op_name = primitive->name();
diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc
index 9ffcbca521..b285453d1b 100644
--- a/mindspore/core/abstract/primitive_infer_map.cc
+++ b/mindspore/core/abstract/primitive_infer_map.cc
@@ -54,6 +54,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
     {prim::kPrimGatherV2, {InferImplGatherV2, true}},
     {prim::kPrimSparseGatherV2, {InferImplGatherV2, true}},
+    {prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}},
     {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}},
     {prim::kPrimScatterAdd, {InferImplScatterAdd, true}},
     {prim::kPrimScatterUpdate, {InferImplScatterUpdate, true}},
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index 2373552cc3..8adecb1ef0 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -4192,12 +4192,20 @@ class EmbeddingLookup(PrimitiveWithInfer):
         validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
         validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
         params_shp = params['shape']
-        if len(params_shp) != 2:
-            raise ValueError("The dimension of 'params' in EmbeddingLookup must be 2, but got %d."
-                             % len(params_shp))
         out_shape = indices['shape'] + params_shp[1:]
+        if 'max_shape' in indices:
+            out_max_shape = indices['max_shape'] + params_shp[1:]
+        else:
+            out_max_shape = out_shape
+        if 'min_shape' in indices:
+            out_min_shape = indices['min_shape'] + params_shp[1:]
+        else:
+            out_min_shape = out_shape
         out = {'shape': out_shape,
                'dtype': params['dtype'],
-               'value': None}
+               'value': None,
+               'max_shape': out_max_shape,
+               'min_shape': out_min_shape}
         return out