|
|
|
@@ -26,12 +26,16 @@ using mindspore::schema::PrimitiveType_EmbeddingLookup; |
|
|
|
|
|
|
|
namespace mindspore::kernel { |
|
|
|
int EmbeddingLookupCPUKernel::Init() { |
|
|
|
if (context_->infer_shape_interrupt_ && !context_->running_) { |
|
|
|
SetNeedReInit(); |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
embedding_lookup_parameter_ = reinterpret_cast<EmbeddingLookupParameter *>(opParameter); |
|
|
|
embedding_lookup_parameter_->thread_num = thread_count_; |
|
|
|
|
|
|
|
if (!InferShapeDone()) { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
return ReSize(); |
|
|
|
} |
|
|
|
|
|
|
|
int EmbeddingLookupCPUKernel::ReSize() { |
|
|
|
embedding_lookup_parameter_->ids_size_ = inputs_.back()->ElementsNum(); |
|
|
|
|
|
|
|
embedding_lookup_parameter_->layer_size_ = 1; |
|
|
|
@@ -45,18 +49,34 @@ int EmbeddingLookupCPUKernel::Init() { |
|
|
|
embedding_lookup_parameter_->layer_num_ += inputs_[i]->shape()[0]; |
|
|
|
} |
|
|
|
|
|
|
|
input_addr_ = reinterpret_cast<float *>( |
|
|
|
std::malloc(sizeof(float) * embedding_lookup_parameter_->layer_size_ * embedding_lookup_parameter_->layer_num_)); |
|
|
|
if (input_addr_ != nullptr) { |
|
|
|
free(input_addr_); |
|
|
|
} |
|
|
|
if (context_ != nullptr && context_->allocator != nullptr) { |
|
|
|
input_addr_ = reinterpret_cast<float *>(context_->allocator->Malloc( |
|
|
|
sizeof(float) * embedding_lookup_parameter_->layer_size_ * embedding_lookup_parameter_->layer_num_)); |
|
|
|
} else { |
|
|
|
input_addr_ = reinterpret_cast<float *>( |
|
|
|
malloc(sizeof(float) * embedding_lookup_parameter_->layer_size_ * embedding_lookup_parameter_->layer_num_)); |
|
|
|
} |
|
|
|
if (input_addr_ == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Create memory failed"; |
|
|
|
return mindspore::lite::RET_MEMORY_FAILED; |
|
|
|
MS_LOG(ERROR) << "Malloc buffer failed"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
|
|
|
|
embedding_lookup_parameter_->is_regulated_ = |
|
|
|
reinterpret_cast<bool *>(std::malloc(sizeof(bool) * embedding_lookup_parameter_->layer_num_)); |
|
|
|
if (embedding_lookup_parameter_->is_regulated_ != nullptr) { |
|
|
|
free(embedding_lookup_parameter_->is_regulated_); |
|
|
|
} |
|
|
|
if (context_ != nullptr && context_->allocator != nullptr) { |
|
|
|
embedding_lookup_parameter_->is_regulated_ = |
|
|
|
reinterpret_cast<bool *>(context_->allocator->Malloc(sizeof(bool) * embedding_lookup_parameter_->layer_num_)); |
|
|
|
} else { |
|
|
|
embedding_lookup_parameter_->is_regulated_ = |
|
|
|
reinterpret_cast<bool *>(malloc(sizeof(bool) * embedding_lookup_parameter_->layer_num_)); |
|
|
|
} |
|
|
|
if (embedding_lookup_parameter_->is_regulated_ == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Create memory failed"; |
|
|
|
return mindspore::lite::RET_MEMORY_FAILED; |
|
|
|
MS_LOG(ERROR) << "Malloc buffer failed"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
|
|
|
|
for (int i = 0; i < embedding_lookup_parameter_->layer_num_; ++i) { |
|
|
|
@@ -66,8 +86,6 @@ int EmbeddingLookupCPUKernel::Init() { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
int EmbeddingLookupCPUKernel::ReSize() { return RET_OK; } |
|
|
|
|
|
|
|
int EmbeddingLookupCPUKernel::DoExcute(int task_id) { |
|
|
|
int error_code = EmbeddingLookup(input_addr_, ids_addr_, output_addr_, embedding_lookup_parameter_, task_id); |
|
|
|
if (error_code != RET_OK) { |
|
|
|
|