Merge pull request !4033 from 陶云浩/litetags/v0.7.0-beta
| @@ -727,8 +727,7 @@ table AddN { | |||
| table EmbeddingLookup { | |||
| ids: [int]; | |||
| maxNorm: float; | |||
| maxNorm: float = 0.0; | |||
| } | |||
| table EmbeddingLookupSparse { | |||
| @@ -216,6 +216,8 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) { | |||
| return new lite::MatMul(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_QuantDTypeCast: | |||
| return new lite::QuantDTypeCast(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_EmbeddingLookup: | |||
| return new lite::EmbeddingLookup(const_cast<schema::Primitive *>(srcPrim)); | |||
| default: | |||
| break; | |||
| } | |||
| @@ -0,0 +1,63 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/ops.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/ir/tensor.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore::lite { | |||
| int EmbeddingLookup::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) { | |||
| MS_ASSERT(this->primitive != nullptr); | |||
| if (inputs_.size() < kDoubleNum) { | |||
| MS_LOG(ERROR) << "Embedding Lookup should have at least two inputs"; | |||
| return RET_INPUT_TENSOR_ERROR; | |||
| } | |||
| if (outputs_.size() != kSingleNum) { | |||
| MS_LOG(ERROR) << "Embedding Lookup should have one outputs"; | |||
| return RET_INPUT_TENSOR_ERROR; | |||
| } | |||
| auto params_ = inputs_.front(); | |||
| MS_ASSERT(params_ != nullptr); | |||
| auto ids = inputs_.back(); | |||
| MS_ASSERT(ids != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| auto embedding_shape = params_->shape(); | |||
| embedding_shape.erase(embedding_shape.begin()); | |||
| std::vector<int> output_shape(ids->shape()); | |||
| for (size_t i = 0; i < embedding_shape.size(); ++i) { | |||
| output_shape.push_back(embedding_shape.at(i)); | |||
| } | |||
| for (int i = 1; i < inputs_.size() - 1; ++i) { | |||
| auto embedding_shape_t = inputs_.at(i)->shape(); | |||
| embedding_shape_t.erase(embedding_shape_t.begin()); | |||
| if (embedding_shape_t != embedding_shape) { | |||
| MS_LOG(ERROR) << "The embedded layers should have the same shape"; | |||
| return RET_INPUT_TENSOR_ERROR; | |||
| } | |||
| } | |||
| output->set_shape(output_shape); | |||
| output->set_data_type(params_->data_type()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace mindspore::lite | |||
| @@ -141,6 +141,8 @@ Primitive *Primitive::CreatePrimitive(schema::Primitive *primitive) { | |||
| return new lite::QuantDTypeCast(const_cast<schema::Primitive *>(primitive)); | |||
| case schema::PrimitiveType_MatMul: | |||
| return new lite::MatMul(const_cast<schema::Primitive *>(primitive)); | |||
| case schema::PrimitiveType_EmbeddingLookup: | |||
| return new lite::EmbeddingLookup(const_cast<schema::Primitive *>(primitive)); | |||
| default: | |||
| break; | |||
| } | |||
| @@ -778,6 +778,13 @@ class Lstm : public Primitive { | |||
| const schema::Lstm *GetAttribute() const { return this->primitive->value_as_Lstm(); } | |||
| int InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) override; | |||
| }; | |||
| class EmbeddingLookup : public Primitive { | |||
| public: | |||
| explicit EmbeddingLookup(schema::Primitive *primitive) : Primitive(primitive) {} | |||
| const schema::EmbeddingLookup *GetAttribute() const { return this->primitive->value_as_EmbeddingLookup(); } | |||
| int InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_SRC_OPS_OPS_H_ | |||
| @@ -69,6 +69,7 @@ | |||
| #include "src/runtime/kernel/arm/nnacl/fp32/space_to_batch.h" | |||
| #include "src/runtime/kernel/arm/nnacl/int8/quant_dtype_cast.h" | |||
| #include "src/runtime/kernel/arm/nnacl/fp32/lstm.h" | |||
| #include "src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h" | |||
| namespace mindspore::kernel { | |||
| OpParameter *PopulateBatchNorm(const lite::Primitive *primitive) { | |||
| @@ -1209,6 +1210,23 @@ OpParameter *PopulateLstmParameter(const lite::Primitive *primitive) { | |||
| return reinterpret_cast<OpParameter *>(lstm_param); | |||
| } | |||
| OpParameter *PopulateEmbeddingLookupParameter(const lite::Primitive *primitive) { | |||
| EmbeddingLookupParameter *embedding_lookup_parameter = new (std::nothrow) EmbeddingLookupParameter(); | |||
| if (embedding_lookup_parameter == nullptr) { | |||
| MS_LOG(ERROR) << "new EmbeddingLookupParameter failed"; | |||
| return nullptr; | |||
| } | |||
| embedding_lookup_parameter->op_parameter_.type_ = primitive->Type(); | |||
| auto param = primitive->Value()->value_as_EmbeddingLookup(); | |||
| embedding_lookup_parameter->max_norm_ = param->maxNorm(); | |||
| if (embedding_lookup_parameter->max_norm_ < 0) { | |||
| MS_LOG(ERROR) << "Embedding lookup max norm should be positive number, got " | |||
| << embedding_lookup_parameter->max_norm_; | |||
| return nullptr; | |||
| } | |||
| return reinterpret_cast<OpParameter *>(embedding_lookup_parameter); | |||
| } | |||
| PopulateParameterRegistry::PopulateParameterRegistry() { | |||
| populate_parameter_funcs_[schema::PrimitiveType_SoftMax] = PopulateSoftmaxParameter; | |||
| populate_parameter_funcs_[schema::PrimitiveType_Activation] = PopulateActivationParameter; | |||
| @@ -1286,6 +1304,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() { | |||
| populate_parameter_funcs_[schema::PrimitiveType_PriorBox] = PopulatePriorBoxParameter; | |||
| populate_parameter_funcs_[schema::PrimitiveType_QuantDTypeCast] = PopulateQuantDTypeCastParameter; | |||
| populate_parameter_funcs_[schema::PrimitiveType_Lstm] = PopulateLstmParameter; | |||
| populate_parameter_funcs_[schema::PrimitiveType_EmbeddingLookup] = PopulateEmbeddingLookupParameter; | |||
| } | |||
| PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() { | |||
| @@ -137,12 +137,12 @@ class ArithmeticCPUKernel : public LiteKernel { | |||
| arithmetic_broadcast_run_ = BroadcastNotEqual; | |||
| break; | |||
| case PrimitiveType_Less: | |||
| arithmetic_run_ = ElementEqual; | |||
| arithmetic_broadcast_run_ = BroadcastEqual; | |||
| arithmetic_run_ = ElementLess; | |||
| arithmetic_broadcast_run_ = BroadcastLess; | |||
| break; | |||
| case PrimitiveType_LessEqual: | |||
| arithmetic_run_ = ElementNotEqual; | |||
| arithmetic_broadcast_run_ = BroadcastNotEqual; | |||
| arithmetic_run_ = ElementLessEqual; | |||
| arithmetic_broadcast_run_ = BroadcastLessEqual; | |||
| break; | |||
| case PrimitiveType_Greater: | |||
| arithmetic_run_ = ElementGreater; | |||
| @@ -0,0 +1,130 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/arm/fp32/embedding_lookup.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_EmbeddingLookup; | |||
| namespace mindspore::kernel { | |||
| int EmbeddingLookupCPUKernel::Init() { | |||
| embedding_lookup_parameter_ = reinterpret_cast<EmbeddingLookupParameter *>(opParameter); | |||
| embedding_lookup_parameter_->thread_num = thread_count_; | |||
| embedding_lookup_parameter_->ids_size_ = inputs_.back()->ElementsNum(); | |||
| embedding_lookup_parameter_->layer_size_ = 1; | |||
| auto in_shape = inputs_.front()->shape(); | |||
| for (int i = 1; i < in_shape.size(); ++i) { | |||
| embedding_lookup_parameter_->layer_size_ *= in_shape[i]; | |||
| } | |||
| embedding_lookup_parameter_->layer_num_ = 0; | |||
| for (int i = 0; i < inputs_.size() - 1; ++i) { | |||
| embedding_lookup_parameter_->layer_num_ += inputs_[i]->shape()[0]; | |||
| } | |||
| input_addr_ = reinterpret_cast<float *>( | |||
| std::malloc(sizeof(float) * embedding_lookup_parameter_->layer_size_ * embedding_lookup_parameter_->layer_num_)); | |||
| if (input_addr_ == nullptr) { | |||
| MS_LOG(ERROR) << "Create memory failed"; | |||
| return mindspore::lite::RET_MEMORY_FAILED; | |||
| } | |||
| embedding_lookup_parameter_->is_regulated_ = | |||
| reinterpret_cast<bool *>(std::malloc(sizeof(bool) * embedding_lookup_parameter_->layer_num_)); | |||
| if (embedding_lookup_parameter_->is_regulated_ == nullptr) { | |||
| MS_LOG(ERROR) << "Create memory failed"; | |||
| return mindspore::lite::RET_MEMORY_FAILED; | |||
| } | |||
| for (int i = 0; i < embedding_lookup_parameter_->layer_num_; ++i) { | |||
| embedding_lookup_parameter_->is_regulated_[i] = embedding_lookup_parameter_->max_norm_ == 0; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int EmbeddingLookupCPUKernel::ReSize() { return RET_OK; } | |||
| int EmbeddingLookupCPUKernel::DoExcute(int task_id) { | |||
| int error_code = EmbeddingLookup(input_addr_, ids_addr_, output_addr_, embedding_lookup_parameter_, task_id); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "embedding lookup error error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int EmbeddingLookupRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { | |||
| auto EmbeddingLookupData = reinterpret_cast<EmbeddingLookupCPUKernel *>(cdata); | |||
| auto ret = EmbeddingLookupData->DoExcute(task_id); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "EmbeddingLookupRun error task_id[" << task_id << "] error_code[" << ret << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int EmbeddingLookupCPUKernel::Run() { | |||
| int dest_loc = 0; | |||
| for (int i = 0; i < inputs_.size() - 1; i++) { | |||
| auto input_t = reinterpret_cast<float *>(inputs_.at(i)->Data()); | |||
| memcpy(input_addr_ + dest_loc, input_t, sizeof(float) * inputs_.at(i)->ElementsNum()); | |||
| dest_loc += inputs_.at(i)->ElementsNum(); | |||
| } | |||
| output_addr_ = reinterpret_cast<float *>(outputs_.front()->Data()); | |||
| ids_addr_ = reinterpret_cast<int *>(inputs_.back()->Data()); | |||
| auto ret = LiteBackendParallelLaunch(EmbeddingLookupRun, this, embedding_lookup_parameter_->thread_num); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "EmbeddingLookup error: error_code[" << ret << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| kernel::LiteKernel *CpuEmbeddingLookupFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||
| OpParameter *parameter, const lite::Context *ctx, | |||
| const KernelKey &desc) { | |||
| if (parameter == nullptr || ctx == nullptr) { | |||
| MS_LOG(ERROR) << "parameter or ctx is nullptr"; | |||
| return nullptr; | |||
| } | |||
| MS_ASSERT(desc.type == PrimitiveType_EmbeddingLookup); | |||
| auto *kernel = new (std::nothrow) EmbeddingLookupCPUKernel(parameter, inputs, outputs, ctx); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "Create Kernel failed, name: " << parameter->name_; | |||
| return nullptr; | |||
| } | |||
| auto ret = kernel->Init(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init Kernel failed, name: " << parameter->name_ | |||
| << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_)); | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_EmbeddingLookup, CpuEmbeddingLookupFp32KernelCreator) | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_EMBEDDING_LOOKUP_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_EMBEDDING_LOOKUP_H_ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h" | |||
| namespace mindspore::kernel { | |||
| class EmbeddingLookupCPUKernel : public LiteKernel { | |||
| public: | |||
| explicit EmbeddingLookupCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx) | |||
| : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) {} | |||
| ~EmbeddingLookupCPUKernel() override{}; | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| int DoExcute(int task_id); | |||
| protected: | |||
| int thread_count_; | |||
| const lite::Context *ctx_; | |||
| EmbeddingLookupParameter *embedding_lookup_parameter_; | |||
| private: | |||
| float *input_addr_; | |||
| float *output_addr_; | |||
| int *ids_addr_; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_EMBEDDING_LOOKUP_H_ | |||
| @@ -81,10 +81,10 @@ int ArithmeticInt8CPUKernel::Init() { | |||
| arithmetic_run_ = ElementNotEqual; | |||
| break; | |||
| case PrimitiveType_Less: | |||
| arithmetic_run_ = ElementEqual; | |||
| arithmetic_run_ = ElementLess; | |||
| break; | |||
| case PrimitiveType_LessEqual: | |||
| arithmetic_run_ = ElementNotEqual; | |||
| arithmetic_run_ = ElementLessEqual; | |||
| break; | |||
| case PrimitiveType_Greater: | |||
| arithmetic_run_ = ElementGreater; | |||
| @@ -0,0 +1,60 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
#include "src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h"
#include <math.h>
#include <string.h>
#include "include/errorcode.h"
#include "src/runtime/kernel/arm/nnacl/errorcode.h"
#include "mindspore/core/utils/log_adapter.h"
// Clips `data` (length `size`) in place so that its L2 norm does not exceed `max_norm`.
//
// The norm is sqrt(sum(data[i]^2)). When it is larger than `max_norm`, every
// element is scaled by max_norm / norm; otherwise the data is left untouched.
// A zero vector is never modified, so no division by zero can occur.
void l2_regulate(float *data, int size, float max_norm) {
  float square_sum = 0;
  for (int i = 0; i < size; ++i) {
    square_sum += data[i] * data[i];
  }
  const float norm = sqrtf(square_sum);
  if (norm > 0 && norm > max_norm) {
    // Hoisted out of the loop: the factor is the same for every element.
    const float scale = max_norm / norm;
    for (int i = 0; i < size; ++i) {
      data[i] *= scale;
    }
  }
}
| int CopyData(float *input_data, int *ids, float *output_data, int num, EmbeddingLookupParameter *parameter) { | |||
| if (ids[num] >= parameter->layer_num_ || ids[num] < 0) { | |||
| MS_LOG(ERROR) << "Embedding lookup index out of range"; | |||
| return NNACL_ERRCODE_INDEX_OUT_OF_RANGE; | |||
| } | |||
| float *out_data = output_data + num * parameter->layer_size_; | |||
| float *in_data = input_data + ids[num] * parameter->layer_size_; | |||
| if (!parameter->is_regulated_[ids[num]]) { | |||
| l2_regulate(in_data, parameter->layer_size_, parameter->max_norm_); | |||
| parameter->is_regulated_[ids[num]] = true; | |||
| } | |||
| memcpy(out_data, in_data, sizeof(float) * parameter->layer_size_); | |||
| return NNACL_OK; | |||
| } | |||
| int EmbeddingLookup(float *input_data, int *ids, float *output_data, EmbeddingLookupParameter *parameter, int task_id) { | |||
| for (size_t i = task_id; i < parameter->ids_size_; i += parameter->thread_num) { | |||
| int ret = CopyData(input_data, ids, output_data, i, parameter); | |||
| if (ret != NNACL_OK) { | |||
| return ret; | |||
| } | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| @@ -0,0 +1,34 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_EMBEDDING_LOOKUP_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_EMBEDDING_LOOKUP_H_ | |||
| #include "src/runtime/kernel/arm/nnacl/op_base.h" | |||
| struct EmbeddingLookupParameter { | |||
| OpParameter op_parameter_; | |||
| bool *is_regulated_; | |||
| float max_norm_; | |||
| int ids_size_; | |||
| int layer_size_; | |||
| int layer_num_; | |||
| int thread_num; | |||
| }; | |||
| int EmbeddingLookup(float *input_data, int *ids, float *output_data, EmbeddingLookupParameter *parameter, int task_id); | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_EMBEDDING_LOOKUP_H_ | |||
| @@ -0,0 +1,85 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <iostream> | |||
| #include "src/runtime/kernel/arm/fp32/embedding_lookup.h" | |||
| #include "src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h" | |||
| #include "src/common/file_utils.h" | |||
| #include "common/common_test.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| using mindspore::lite::tensor::Tensor; | |||
| class TestEmbeddingLookupFp32 : public mindspore::Common { | |||
| public: | |||
| TestEmbeddingLookupFp32() {} | |||
| }; | |||
| void ElTestInit(std::vector<Tensor *> *inputs_, std::vector<Tensor *> *outputs_, | |||
| EmbeddingLookupParameter *embedding_lookup_param) { | |||
| Tensor *in_t_first = new Tensor(kNumberTypeFloat32, {6, 2}, schema::Format_NHWC, static_cast<schema::NodeType>(1)); | |||
| in_t_first->MallocData(); | |||
| float in_first[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; | |||
| memcpy(in_t_first->Data(), in_first, sizeof(float) * in_t_first->ElementsNum()); | |||
| inputs_->push_back(in_t_first); | |||
| Tensor *in_t_second = new Tensor(kNumberTypeFloat32, {4, 2}, schema::Format_NHWC, static_cast<schema::NodeType>(1)); | |||
| in_t_second->MallocData(); | |||
| float in_second[] = {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8}; | |||
| memcpy(in_t_second->Data(), in_second, sizeof(float) * in_t_second->ElementsNum()); | |||
| inputs_->push_back(in_t_second); | |||
| Tensor *ids_t = new Tensor(kNumberTypeFloat32, {2, 3}, schema::Format_NHWC, static_cast<schema::NodeType>(1)); | |||
| ids_t->MallocData(); | |||
| int ids[] = {1, 9, 2, 4, 6, 7}; | |||
| memcpy(ids_t->Data(), ids, sizeof(int) * ids_t->ElementsNum()); | |||
| inputs_->push_back(ids_t); | |||
| Tensor *outputs_t = new Tensor(kNumberTypeInt32, {2, 3, 2}, schema::Format_NHWC, static_cast<schema::NodeType>(1)); | |||
| outputs_t->MallocData(); | |||
| outputs_->push_back(outputs_t); | |||
| embedding_lookup_param->max_norm_ = 1; | |||
| } | |||
| TEST_F(TestEmbeddingLookupFp32, ElTest) { | |||
| std::vector<Tensor *> inputs_; | |||
| std::vector<Tensor *> outputs_; | |||
| auto embedding_lookup_param_ = new EmbeddingLookupParameter(); | |||
| ElTestInit(&inputs_, &outputs_, embedding_lookup_param_); | |||
| lite::Context *ctx = new lite::Context; | |||
| ctx->thread_num_ = 2; | |||
| kernel::EmbeddingLookupCPUKernel *el = new kernel::EmbeddingLookupCPUKernel( | |||
| reinterpret_cast<OpParameter *>(embedding_lookup_param_), inputs_, outputs_, ctx); | |||
| el->Init(); | |||
| el->Run(); | |||
| std::cout << "output shape:" << std::endl; | |||
| for (int i = 0; i < outputs_.front()->shape().size(); ++i) { | |||
| std::cout << outputs_.front()->shape()[i] << ' '; | |||
| } | |||
| std::cout << std::endl; | |||
| float *out = reinterpret_cast<float *>(outputs_.front()->Data()); | |||
| for (int i = 0; i < outputs_.front()->ElementsNum(); ++i) { | |||
| std::cout << out[i] << ' '; | |||
| } | |||
| std::cout << std::endl; | |||
| } | |||
| } // namespace mindspore | |||