From: @fangzehua Reviewed-by: @stsuteng Signed-off-by: @stsutengtags/v1.1.0
| @@ -48,32 +48,54 @@ void LookUpTableTask(const float *input_addr, const T *indices_addr, float *outp | |||
| void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| CheckParam(kernel_node); | |||
| node_ = kernel_node; | |||
| std::vector<size_t> input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| if (input_shape.empty()) { | |||
| MS_LOG(EXCEPTION) << "param must be at least 1D"; | |||
| } | |||
| first_dim_size_ = input_shape[0]; | |||
| outer_dim_size_ = 1; | |||
| for (size_t i = 1; i < input_shape.size(); ++i) { | |||
| outer_dim_size_ *= input_shape[i]; | |||
| } | |||
| indices_lens_ = 1; | |||
| std::vector<size_t> indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||
| for (const auto &shape : indices_shape) { | |||
| indices_lens_ *= shape; | |||
| } | |||
| indices_data_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 1); | |||
| if (AnfAlgo::HasNodeAttr(kAttrOffset, kernel_node)) { | |||
| offset_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrOffset); | |||
| } | |||
| indices_data_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 1); | |||
| } | |||
| template <typename T> | |||
| void EmbeddingLookUpCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> &outputs) const { | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| if (node_ != nullptr) { | |||
| std::vector<size_t> input_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0); | |||
| if (input_shape.empty()) { | |||
| MS_LOG(EXCEPTION) << "param must be at least 1D"; | |||
| } | |||
| first_dim_size_ = input_shape[0]; | |||
| outer_dim_size_ = 1; | |||
| for (size_t i = 1; i < input_shape.size(); ++i) { | |||
| outer_dim_size_ *= input_shape[i]; | |||
| } | |||
| indices_lens_ = 1; | |||
| std::vector<size_t> indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 1); | |||
| for (const auto &shape : indices_shape) { | |||
| indices_lens_ *= shape; | |||
| } | |||
| } | |||
| auto input_addr = reinterpret_cast<float *>(inputs[0]->addr); | |||
| auto indices_addr = reinterpret_cast<T *>(inputs[1]->addr); | |||
| auto output_addr = reinterpret_cast<float *>(outputs[0]->addr); | |||
| const size_t thread_num = 16; | |||
| std::thread threads[16]; | |||
| const size_t kMaxThreadNum = 16; | |||
| size_t thread_num = indices_lens_ / 10000 + 1; | |||
| thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num; | |||
| std::thread threads[kMaxThreadNum]; | |||
| size_t task_proc_lens = (indices_lens_ + thread_num - 1) / thread_num; | |||
| size_t i; | |||
| size_t task_offset = 0; | |||
| @@ -32,8 +32,7 @@ class EmbeddingLookUpCPUKernel : public CPUKernel { | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| template <typename T> | |||
| void LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> &outputs) const; | |||
| void LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs); | |||
| protected: | |||
| void CheckParam(const CNodePtr &kernel_node); | |||
| @@ -42,6 +41,7 @@ class EmbeddingLookUpCPUKernel : public CPUKernel { | |||
| size_t first_dim_size_{1}; | |||
| size_t outer_dim_size_{1}; | |||
| TypeId indices_data_type_{kNumberTypeInt32}; | |||
| CNodePtr node_ = nullptr; | |||
| }; | |||
| MS_REG_CPU_KERNEL( | |||
| @@ -233,6 +233,8 @@ AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr & | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| @@ -453,6 +453,41 @@ AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr | |||
| return std::make_shared<AbstractTensor>(params->element(), std::make_shared<Shape>(out_shape)); | |||
| } | |||
| AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| auto params = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| auto params_shp = params->shape(); | |||
| MS_EXCEPTION_IF_NULL(params); | |||
| MS_EXCEPTION_IF_NULL(params_shp); | |||
| auto params_shape = params_shp->shape(); | |||
| auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||
| auto indices_shp = indices->shape(); | |||
| MS_EXCEPTION_IF_NULL(indices); | |||
| MS_EXCEPTION_IF_NULL(indices_shp); | |||
| auto indices_shape = indices_shp->shape(); | |||
| auto indices_max_shape = indices_shp->max_shape(); | |||
| ShapeVector shape; | |||
| ShapeVector max_shape; | |||
| shape.insert(shape.end(), indices_shape.begin(), indices_shape.end()); | |||
| shape.insert(shape.end(), params_shape.begin() + 1, params_shape.end()); | |||
| if (!indices_max_shape.empty()) { | |||
| max_shape.insert(max_shape.end(), indices_max_shape.begin(), indices_max_shape.end()); | |||
| max_shape.insert(max_shape.end(), params_shape.begin() + 1, params_shape.end()); | |||
| } else { | |||
| max_shape = shape; | |||
| } | |||
| ShapeVector min_shape; | |||
| for (size_t i = 0; i < max_shape.size(); ++i) { | |||
| min_shape.emplace_back(1); | |||
| } | |||
| AbstractTensorPtr ret = | |||
| std::make_shared<AbstractTensor>(params->element(), std::make_shared<Shape>(shape, min_shape, max_shape)); | |||
| return ret; | |||
| } | |||
| AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string &op_name = primitive->name(); | |||
| @@ -54,6 +54,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}}, | |||
| {prim::kPrimGatherV2, {InferImplGatherV2, true}}, | |||
| {prim::kPrimSparseGatherV2, {InferImplGatherV2, true}}, | |||
| {prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}}, | |||
| {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}}, | |||
| {prim::kPrimScatterAdd, {InferImplScatterAdd, true}}, | |||
| {prim::kPrimScatterUpdate, {InferImplScatterUpdate, true}}, | |||
| @@ -4192,12 +4192,20 @@ class EmbeddingLookup(PrimitiveWithInfer): | |||
| validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name) | |||
| validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name) | |||
| params_shp = params['shape'] | |||
| if len(params_shp) != 2: | |||
| raise ValueError("The dimension of 'params' in EmbeddingLookup must be 2, but got %d." % len(params_shp)) | |||
| out_shape = indices['shape'] + params_shp[1:] | |||
| if 'max_shape' in indices: | |||
| out_max_shape = indices['max_shape'] + params_shp[1:] | |||
| else: | |||
| out_max_shape = out_shape | |||
| if 'min_shape' in indices: | |||
| out_min_shape = indices['min_shape'] + params_shp[1:] | |||
| else: | |||
| out_min_shape = out_shape | |||
| out = {'shape': out_shape, | |||
| 'dtype': params['dtype'], | |||
| 'value': None} | |||
| 'value': None, | |||
| 'max_shape': out_max_shape, | |||
| 'min_shape': out_min_shape} | |||
| return out | |||