Browse Source

modify arm cpu op: embedding_lookup and elu

tags/v0.7.0-beta
tao_yunhao 5 years ago
parent
commit
3eee337f70
7 changed files with 64 additions and 30 deletions
  1. +13
    -2
      mindspore/lite/src/runtime/kernel/arm/fp32/elu.cc
  2. +32
    -14
      mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.cc
  3. +8
    -1
      mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.h
  4. +0
    -1
      mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/elu.cc
  5. +2
    -2
      mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/elu.h
  6. +0
    -1
      mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.cc
  7. +9
    -9
      mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h

+ 13
- 2
mindspore/lite/src/runtime/kernel/arm/fp32/elu.cc View File

@@ -28,12 +28,18 @@ namespace mindspore::kernel {
// Binds the generic op parameter to its ELU-specific view and records the
// kernel's thread count in it. ReSize() is invoked only when InferShapeDone()
// reports true; otherwise RET_OK is returned without calling ReSize().
int EluCPUKernel::Init() {
  elu_parameter_ = reinterpret_cast<EluParameter *>(opParameter);
  elu_parameter_->thread_num_ = thread_count_;

  if (InferShapeDone()) {
    return ReSize();
  }
  return RET_OK;
}

// Stores the element count of the first input tensor in the ELU parameter.
// Always returns RET_OK.
int EluCPUKernel::ReSize() {
  auto *first_input = inputs_.front();
  elu_parameter_->in_size_ = first_input->ElementsNum();
  return RET_OK;
}

int EluCPUKernel::ReSize() { return RET_OK; }

// Runs Elu() for the given task id over the kernel's current input/output
// buffers and returns its status code.
//
// The previous body called Elu() and then fell off the end of a non-void
// function, which is undefined behavior; the status is now returned explicitly.
int EluCPUKernel::DoExcute(int task_id) {
  return Elu(input_addr, output_addr, elu_parameter_, task_id);
}

int EluRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
@@ -47,6 +53,11 @@ int EluRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
}

int EluCPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
input_addr = reinterpret_cast<float *>(inputs_.front()->Data());
output_addr = reinterpret_cast<float *>(outputs_.front()->Data());



+ 32
- 14
mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.cc View File

@@ -26,12 +26,16 @@ using mindspore::schema::PrimitiveType_EmbeddingLookup;

namespace mindspore::kernel {
// Initializes the embedding-lookup kernel: binds the op parameter, records the
// thread count, and calls ReSize() when shape inference has completed.
int EmbeddingLookupCPUKernel::Init() {
// When the context reports an interrupted shape inference and is not running,
// call SetNeedReInit() and return RET_OK without touching the parameter.
if (context_->infer_shape_interrupt_ && !context_->running_) {
SetNeedReInit();
return RET_OK;
}
// View the generic op parameter as the embedding-lookup parameter struct.
embedding_lookup_parameter_ = reinterpret_cast<EmbeddingLookupParameter *>(opParameter);
embedding_lookup_parameter_->thread_num = thread_count_;

// ReSize() is only reached when InferShapeDone() reports true.
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}

int EmbeddingLookupCPUKernel::ReSize() {
embedding_lookup_parameter_->ids_size_ = inputs_.back()->ElementsNum();

embedding_lookup_parameter_->layer_size_ = 1;
@@ -45,18 +49,34 @@ int EmbeddingLookupCPUKernel::Init() {
embedding_lookup_parameter_->layer_num_ += inputs_[i]->shape()[0];
}

input_addr_ = reinterpret_cast<float *>(
std::malloc(sizeof(float) * embedding_lookup_parameter_->layer_size_ * embedding_lookup_parameter_->layer_num_));
if (input_addr_ != nullptr) {
free(input_addr_);
}
if (context_ != nullptr && context_->allocator != nullptr) {
input_addr_ = reinterpret_cast<float *>(context_->allocator->Malloc(
sizeof(float) * embedding_lookup_parameter_->layer_size_ * embedding_lookup_parameter_->layer_num_));
} else {
input_addr_ = reinterpret_cast<float *>(
malloc(sizeof(float) * embedding_lookup_parameter_->layer_size_ * embedding_lookup_parameter_->layer_num_));
}
if (input_addr_ == nullptr) {
MS_LOG(ERROR) << "Create memory failed";
return mindspore::lite::RET_MEMORY_FAILED;
MS_LOG(ERROR) << "Malloc buffer failed";
return RET_ERROR;
}

embedding_lookup_parameter_->is_regulated_ =
reinterpret_cast<bool *>(std::malloc(sizeof(bool) * embedding_lookup_parameter_->layer_num_));
if (embedding_lookup_parameter_->is_regulated_ != nullptr) {
free(embedding_lookup_parameter_->is_regulated_);
}
if (context_ != nullptr && context_->allocator != nullptr) {
embedding_lookup_parameter_->is_regulated_ =
reinterpret_cast<bool *>(context_->allocator->Malloc(sizeof(bool) * embedding_lookup_parameter_->layer_num_));
} else {
embedding_lookup_parameter_->is_regulated_ =
reinterpret_cast<bool *>(malloc(sizeof(bool) * embedding_lookup_parameter_->layer_num_));
}
if (embedding_lookup_parameter_->is_regulated_ == nullptr) {
MS_LOG(ERROR) << "Create memory failed";
return mindspore::lite::RET_MEMORY_FAILED;
MS_LOG(ERROR) << "Malloc buffer failed";
return RET_ERROR;
}

for (int i = 0; i < embedding_lookup_parameter_->layer_num_; ++i) {
@@ -66,8 +86,6 @@ int EmbeddingLookupCPUKernel::Init() {
return RET_OK;
}

int EmbeddingLookupCPUKernel::ReSize() { return RET_OK; }

int EmbeddingLookupCPUKernel::DoExcute(int task_id) {
int error_code = EmbeddingLookup(input_addr_, ids_addr_, output_addr_, embedding_lookup_parameter_, task_id);
if (error_code != RET_OK) {


+ 8
- 1
mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.h View File

@@ -28,7 +28,14 @@ class EmbeddingLookupCPUKernel : public LiteKernel {
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const lite::Primitive *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) {}
~EmbeddingLookupCPUKernel() override{};
// Releases the scratch buffers held by this kernel: input_addr_ and the
// is_regulated_ array stored in the op parameter.
//
// embedding_lookup_parameter_ is only assigned in Init(), so it is checked
// before being dereferenced (assumes the member is null-initialized when
// Init() has not run — TODO confirm in the member declaration).
// NOTE(review): ReSize() may obtain these buffers from
// context_->allocator->Malloc(), yet they are released with free() here;
// confirm that the allocator's memory is compatible with free().
~EmbeddingLookupCPUKernel() override {
  if (input_addr_ != nullptr) {
    free(input_addr_);
    input_addr_ = nullptr;
  }
  if (embedding_lookup_parameter_ != nullptr && embedding_lookup_parameter_->is_regulated_ != nullptr) {
    free(embedding_lookup_parameter_->is_regulated_);
    // Clear the pointer so the parameter struct does not keep a dangling address.
    embedding_lookup_parameter_->is_regulated_ = nullptr;
  }
}

int Init() override;
int ReSize() override;


+ 0
- 1
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/elu.cc View File

@@ -15,7 +15,6 @@
*/

#include "src/runtime/kernel/arm/nnacl/fp32/elu.h"
#include <string.h>
#include "include/errorcode.h"
#include "src/runtime/kernel/arm/nnacl/errorcode.h"
#include "mindspore/core/utils/log_adapter.h"


+ 2
- 2
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/elu.h View File

@@ -19,12 +19,12 @@

#include "src/runtime/kernel/arm/nnacl/op_base.h"

// Parameter block consumed by Elu(); the ELU CPU kernel obtains it by casting
// its generic op parameter pointer to EluParameter *.
struct EluParameter {
typedef struct {
// Common op parameter header; first member of the struct.
OpParameter op_parameter_;
// presumably the ELU alpha coefficient — confirm against the Elu() implementation
float alpha_;
// Set from the kernel's thread_count_ in EluCPUKernel::Init().
int thread_num_;
// Element count of the first input tensor, set in EluCPUKernel::ReSize().
int in_size_;
};
} EluParameter;

int Elu(float *input_data, float *output_data, EluParameter *parameter, int task_id);



+ 0
- 1
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.cc View File

@@ -15,7 +15,6 @@
*/

#include "src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h"
#include <string.h>
#include "include/errorcode.h"
#include "src/runtime/kernel/arm/nnacl/errorcode.h"
#include "mindspore/core/utils/log_adapter.h"


+ 9
- 9
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h View File

@@ -19,15 +19,15 @@

#include "src/runtime/kernel/arm/nnacl/op_base.h"

struct EmbeddingLookupParameter {
OpParameter op_parameter_;
bool *is_regulated_;
float max_norm_;
int ids_size_;
int layer_size_;
int layer_num_;
int thread_num;
};
// Parameter block consumed by EmbeddingLookup(); the embedding-lookup CPU
// kernel obtains it by casting its generic op parameter pointer.
typedef struct {
// Common op parameter header; first member of the struct.
OpParameter op_parameter_;
// Array of layer_num_ bools allocated by the kernel's ReSize().
bool *is_regulated_;
// presumably the norm limit applied to looked-up rows — confirm against EmbeddingLookup()
float max_norm_;
// Element count of the kernel's last input tensor (the ids tensor).
int ids_size_;
// Initialized to 1 in the kernel's ReSize(); the rest of its computation is not visible here.
int layer_size_;
// Accumulated from shape()[0] of the kernel's input tensors in ReSize().
int layer_num_;
// Set from the kernel's thread_count_ in Init().
int thread_num;
} EmbeddingLookupParameter;

int EmbeddingLookup(float *input_data, int *ids, float *output_data, EmbeddingLookupParameter *parameter, int task_id);



Loading…
Cancel
Save