| @@ -32,7 +32,7 @@ | |||||
| #include "src/runtime/kernel/arm/opclib/conv_parameter.h" | #include "src/runtime/kernel/arm/opclib/conv_parameter.h" | ||||
| #include "src/runtime/kernel/arm/opclib/fp32/pooling.h" | #include "src/runtime/kernel/arm/opclib/fp32/pooling.h" | ||||
| #include "src/runtime/kernel/arm/opclib/matmul.h" | #include "src/runtime/kernel/arm/opclib/matmul.h" | ||||
| #include "src/runtime/kernel/arm/opclib/fp32/softmax.h" | |||||
| #include "src/runtime/kernel/arm/opclib/softmax_parameter.h" | |||||
| #include "src/runtime/kernel/arm/opclib/tile.h" | #include "src/runtime/kernel/arm/opclib/tile.h" | ||||
| #include "src/runtime/kernel/arm/opclib/topk.h" | #include "src/runtime/kernel/arm/opclib/topk.h" | ||||
| #include "src/runtime/kernel/arm/opclib/fp32/reduce.h" | #include "src/runtime/kernel/arm/opclib/fp32/reduce.h" | ||||
| @@ -0,0 +1,103 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/base/softmax_base.h" | |||||
| #include <vector> | |||||
| #include "src/runtime/kernel/arm/int8/softmax_int8.h" | |||||
| #include "src/runtime/kernel/arm/fp32/softmax.h" | |||||
| #include "src/runtime/kernel/arm/opclib/fp32/softmax.h" | |||||
| #include "schema/model_generated.h" | |||||
| #include "src/kernel_factory.h" | |||||
| #include "include/errorcode.h" | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| using mindspore::lite::RET_NULL_PTR; | |||||
| using mindspore::schema::PrimitiveType_SoftMax; | |||||
| namespace mindspore::kernel { | |||||
| int SoftmaxBaseCPUKernel::Init() { | |||||
| if (softmax_param_ == nullptr) { | |||||
| MS_LOG(ERROR) << "SoftmaxParameter nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto input_tensor = inputs_.front(); | |||||
| auto in_shape = input_tensor->shape(); | |||||
| auto in_dims = in_shape.size(); | |||||
| int ele_size = 1; | |||||
| softmax_param_->n_dim_ = in_dims; | |||||
| for (size_t i = 0; i < in_dims; i++) { | |||||
| softmax_param_->input_shape_[i] = in_shape[i]; | |||||
| ele_size *= in_shape[i]; | |||||
| } | |||||
| softmax_param_->element_size_ = ele_size; | |||||
| return RET_OK; | |||||
| } | |||||
| kernel::LiteKernel *CpuSoftmaxInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||||
| OpParameter *opParameter, const lite::Context *ctx, | |||||
| const kernel::KernelKey &desc) { | |||||
| if (opParameter == nullptr) { | |||||
| MS_LOG(ERROR) << "Input opParameter is nullptr!"; | |||||
| return nullptr; | |||||
| } | |||||
| MS_ASSERT(desc.type == schema::PrimitiveType_SoftMax); | |||||
| auto *kernel = new (std::nothrow) SoftmaxInt8CPUKernel(opParameter, inputs, outputs, ctx); | |||||
| if (kernel == nullptr) { | |||||
| MS_LOG(ERROR) << "new SoftmaxCPUKernel fail!"; | |||||
| return nullptr; | |||||
| } | |||||
| auto ret = kernel->Init(); | |||||
| if (ret != RET_OK) { | |||||
| delete kernel; | |||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||||
| return nullptr; | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
| kernel::LiteKernel *CpuSoftmaxFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||||
| OpParameter *opParameter, const lite::Context *ctx, | |||||
| const kernel::KernelKey &desc) { | |||||
| if (opParameter == nullptr) { | |||||
| MS_LOG(ERROR) << "Input opParameter is nullptr!"; | |||||
| return nullptr; | |||||
| } | |||||
| MS_ASSERT(desc.type == schema::PrimitiveType_SoftMax); | |||||
| auto *kernel = new (std::nothrow) SoftmaxCPUKernel(opParameter, inputs, outputs, ctx); | |||||
| if (kernel == nullptr) { | |||||
| MS_LOG(ERROR) << "new SoftmaxCPUKernel fail!"; | |||||
| return nullptr; | |||||
| } | |||||
| auto ret = kernel->Init(); | |||||
| if (ret != RET_OK) { | |||||
| delete kernel; | |||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||||
| return nullptr; | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
// Associates the two creator functions above with the SoftMax primitive on
// kCPU, one for int8 data and one for float32 data.
// NOTE(review): presumably REG_KERNEL registers the creator with the kernel
// registry at static-initialization time — confirm against its definition.
| REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_SoftMax, CpuSoftmaxInt8KernelCreator) | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SoftMax, CpuSoftmaxFp32KernelCreator) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,46 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SOFTMAX_BASE_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SOFTMAX_BASE_H_ | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| #include "src/runtime/kernel/arm/opclib/softmax_parameter.h" | |||||
| namespace mindspore::kernel { | |||||
| class SoftmaxBaseCPUKernel : public LiteKernel { | |||||
| public: | |||||
| SoftmaxBaseCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx) | |||||
| : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) { | |||||
| opParameter->thread_num_ = ctx->thread_num_; | |||||
| softmax_param_ = reinterpret_cast<SoftmaxParameter *>(opParameter); | |||||
| } | |||||
| ~SoftmaxBaseCPUKernel() = default; | |||||
| int Init() override; | |||||
| int ReSize() override { return 0; } | |||||
| int Run() override { return 0; } | |||||
| protected: | |||||
| int thread_count_; | |||||
| const lite::Context *ctx_; | |||||
| SoftmaxParameter *softmax_param_; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SOFTMAX_BASE_H_ | |||||
| @@ -29,13 +29,7 @@ using mindspore::schema::PrimitiveType_LocalResponseNormalization; | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| int LocalResponseNormCPUKernel::Init() { | |||||
| depth_radius_ = (reinterpret_cast<LocalResponseNormParameter *>(opParameter))->depth_radius_; | |||||
| bias_ = (reinterpret_cast<LocalResponseNormParameter *>(opParameter))->bias_; | |||||
| alpha_ = (reinterpret_cast<LocalResponseNormParameter *>(opParameter))->alpha_; | |||||
| beta_ = (reinterpret_cast<LocalResponseNormParameter *>(opParameter))->beta_; | |||||
| return RET_OK; | |||||
| } | |||||
| int LocalResponseNormCPUKernel::Init() { return RET_OK; } | |||||
| int LocalResponseNormCPUKernel::ReSize() { return RET_OK; } | int LocalResponseNormCPUKernel::ReSize() { return RET_OK; } | ||||
| @@ -60,7 +54,8 @@ int LocalResponseNormCPUKernel::DoLocalResponseNorm(int task_id) { | |||||
| input_ptr += stride * task_id * channel; | input_ptr += stride * task_id * channel; | ||||
| output_ptr += stride * task_id * channel; | output_ptr += stride * task_id * channel; | ||||
| auto error_code = LocalResponseNorm(input_ptr, count, channel, output_ptr, depth_radius_, bias_, alpha_, beta_); | |||||
| auto error_code = LocalResponseNorm(input_ptr, count, channel, output_ptr, | |||||
| reinterpret_cast<LocalResponseNormParameter *>(opParameter)); | |||||
| if (error_code != RET_OK) { | if (error_code != RET_OK) { | ||||
| MS_LOG(ERROR) << "DoLocalResponseNorm error task_id[" << task_id << "] error_code[" << error_code << "]"; | MS_LOG(ERROR) << "DoLocalResponseNorm error task_id[" << task_id << "] error_code[" << error_code << "]"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -36,10 +36,6 @@ class LocalResponseNormCPUKernel : public LiteKernel { | |||||
| private: | private: | ||||
| int thread_count_; | int thread_count_; | ||||
| int depth_radius_; | |||||
| float bias_; | |||||
| float alpha_; | |||||
| float beta_; | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -30,21 +30,12 @@ using mindspore::schema::PrimitiveType_SoftMax; | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| int SoftmaxCPUKernel::Init() { | int SoftmaxCPUKernel::Init() { | ||||
| auto input_tensor = inputs_.front(); | |||||
| auto in_shape = input_tensor->shape(); | |||||
| auto in_dims = in_shape.size(); | |||||
| int ele_size = 1; | |||||
| (reinterpret_cast<SoftmaxParameter *>(opParameter))->n_dim_ = in_dims; | |||||
| for (size_t i = 0; i < in_dims; i++) { | |||||
| (reinterpret_cast<SoftmaxParameter *>(opParameter))->input_shape_[i] = in_shape[i]; | |||||
| ele_size *= in_shape[i]; | |||||
| } | |||||
| (reinterpret_cast<SoftmaxParameter *>(opParameter))->element_size_ = ele_size; | |||||
| SoftmaxBaseCPUKernel::Init(); | |||||
| // malloc tmp buffer | // malloc tmp buffer | ||||
| auto axis = reinterpret_cast<SoftmaxParameter *>(opParameter)->axis_; | |||||
| sum_data = reinterpret_cast<float *>(malloc(in_shape[axis] * sizeof(float))); | |||||
| memset(sum_data, 0, in_shape[axis] * sizeof(float)); | |||||
| auto axis = softmax_param_->axis_; | |||||
| sum_data = reinterpret_cast<float *>(malloc(softmax_param_->input_shape_[axis] * sizeof(float))); | |||||
| memset(sum_data, 0, softmax_param_->input_shape_[axis] * sizeof(float)); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -53,31 +44,8 @@ int SoftmaxCPUKernel::ReSize() { return RET_OK; } | |||||
| int SoftmaxCPUKernel::Run() { | int SoftmaxCPUKernel::Run() { | ||||
| auto input_ptr = reinterpret_cast<float *>(inputs_.at(kInputIndex)->Data()); | auto input_ptr = reinterpret_cast<float *>(inputs_.at(kInputIndex)->Data()); | ||||
| auto output_ptr = reinterpret_cast<float *>(outputs_.at(kOutputIndex)->Data()); | auto output_ptr = reinterpret_cast<float *>(outputs_.at(kOutputIndex)->Data()); | ||||
| Softmax(input_ptr, output_ptr, sum_data, reinterpret_cast<SoftmaxParameter *>(opParameter)); | |||||
| Softmax(input_ptr, output_ptr, sum_data, softmax_param_); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| kernel::LiteKernel *CpuSoftmaxFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||||
| OpParameter *opParameter, const lite::Context *ctx, | |||||
| const kernel::KernelKey &desc) { | |||||
| MS_ASSERT(opParameter != nullptr); | |||||
| MS_ASSERT(desc.type == schema::PrimitiveType_SoftMax); | |||||
| auto *kernel = new (std::nothrow) SoftmaxCPUKernel(opParameter, inputs, outputs); | |||||
| if (kernel == nullptr) { | |||||
| MS_LOG(ERROR) << "new SoftmaxCPUKernel fail!"; | |||||
| return nullptr; | |||||
| } | |||||
| auto ret = kernel->Init(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||||
| delete kernel; | |||||
| return nullptr; | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SoftMax, CpuSoftmaxFp32KernelCreator) | |||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -19,14 +19,14 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "src/lite_kernel.h" | #include "src/lite_kernel.h" | ||||
| #include "src/runtime/kernel/arm/base/softmax_base.h" | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| class SoftmaxCPUKernel : public LiteKernel { | |||||
| class SoftmaxCPUKernel : public SoftmaxBaseCPUKernel { | |||||
| public: | public: | ||||
| SoftmaxCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | SoftmaxCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | ||||
| const std::vector<lite::tensor::Tensor *> &outputs) | |||||
| : LiteKernel(parameter, inputs, outputs) {} | |||||
| const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx) | |||||
| : SoftmaxBaseCPUKernel(parameter, inputs, outputs, ctx) {} | |||||
| ~SoftmaxCPUKernel() override = default; | ~SoftmaxCPUKernel() override = default; | ||||
| int Init() override; | int Init() override; | ||||
| @@ -39,4 +39,3 @@ class SoftmaxCPUKernel : public LiteKernel { | |||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SOFTMAX_H_ | #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SOFTMAX_H_ | ||||
| @@ -0,0 +1,111 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/int8/softmax_int8.h" | |||||
| #include "src/runtime/kernel/arm/opclib/int8/softmax_int8.h" | |||||
| #include "schema/model_generated.h" | |||||
| #include "src/runtime/runtime_api.h" | |||||
| #include "include/errorcode.h" | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| namespace mindspore::kernel { | |||||
| int SoftmaxInt8CPUKernel::Init() { | |||||
| SoftmaxBaseCPUKernel::Init(); | |||||
| auto *input_tensor = inputs_.at(kInputIndex); | |||||
| MS_ASSERT(input_tensor); | |||||
| auto in_quant_args = input_tensor->GetQuantParams(); | |||||
| quant_params_.in_quant_args_.scale_ = in_quant_args.front().scale; | |||||
| quant_params_.in_quant_args_.zp_ = in_quant_args.front().zeroPoint; | |||||
| auto *out_tensor = outputs_.at(kOutputIndex); | |||||
| MS_ASSERT(out_tensor); | |||||
| auto out_quant_args = out_tensor->GetQuantParams(); | |||||
| quant_params_.out_quant_arg_.scale_ = out_quant_args.front().scale; | |||||
| quant_params_.out_quant_arg_.zp_ = out_quant_args.front().zeroPoint; | |||||
| int inner_size = 1; | |||||
| for (int i = softmax_param_->axis_ + 1; i < softmax_param_->n_dim_; i++) { | |||||
| inner_size *= softmax_param_->input_shape_[i]; | |||||
| } | |||||
| exp_data_ = reinterpret_cast<float *>(malloc(softmax_param_->element_size_ * sizeof(float))); | |||||
| sum_data_ = reinterpret_cast<float *>(malloc(inner_size * sizeof(float))); | |||||
| return RET_OK; | |||||
| } | |||||
| int SoftmaxInt8CPUKernel::ReSize() { return RET_OK; } | |||||
| int SoftmaxInt8CPUKernel::DoSoftmax(int task_id) { | |||||
| MS_ASSERT(inputs_.size() == 1); | |||||
| MS_ASSERT(outputs_.size() == 1); | |||||
| auto input_ptr = reinterpret_cast<int8_t *>(inputs_.at(0)->Data()); | |||||
| auto output_ptr = reinterpret_cast<int8_t *>(outputs_.at(0)->Data()); | |||||
| int outter_size = 1, inner_size = 1; | |||||
| for (int i = 0; i < softmax_param_->axis_; i++) { | |||||
| outter_size *= softmax_param_->input_shape_[i]; | |||||
| } | |||||
| for (int i = softmax_param_->axis_; i < softmax_param_->n_dim_; i++) { | |||||
| inner_size *= softmax_param_->input_shape_[i]; | |||||
| } | |||||
| int stride = UP_DIV(outter_size, thread_count_); | |||||
| int count = MSMIN(stride, outter_size - stride * task_id); | |||||
| input_ptr += stride * task_id * inner_size; | |||||
| output_ptr += stride * task_id * inner_size; | |||||
| exp_data_ += stride * task_id * inner_size; | |||||
| auto error_code = Softmax(input_ptr, output_ptr, count, exp_data_, sum_data_, quant_params_, softmax_param_); | |||||
| if (error_code != RET_OK) { | |||||
| MS_LOG(ERROR) << "DoSoftmax error task_id[" << task_id << "] error_code[" << error_code << "]"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int SoftmaxRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { | |||||
| auto softmax_kernel = reinterpret_cast<SoftmaxInt8CPUKernel *>(cdata); | |||||
| auto error_code = softmax_kernel->DoSoftmax(task_id); | |||||
| if (error_code != RET_OK) { | |||||
| MS_LOG(ERROR) << "SoftmaxRun error task_id[" << task_id << "] error_code[" << error_code << "]"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int SoftmaxInt8CPUKernel::Run() { | |||||
| auto input_ptr = reinterpret_cast<int8_t *>(inputs_.at(0)->Data()); | |||||
| int ele_size = softmax_param_->element_size_; | |||||
| for (int i = 0; i < ele_size; i++) { | |||||
| float input_scaled = ((input_ptr[i] - quant_params_.in_quant_args_.zp_) * quant_params_.in_quant_args_.scale_); | |||||
| exp_data_[i] = exp(input_scaled); | |||||
| } | |||||
| int error_code = LiteBackendParallelLaunch(SoftmaxRun, this, thread_count_); | |||||
| if (error_code != RET_OK) { | |||||
| MS_LOG(ERROR) << "Softmax function error error_code[" << error_code << "]"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,43 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_SOFTMAX_INT8_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_SOFTMAX_INT8_H_ | |||||
| #include <vector> | |||||
| #include "src/runtime/kernel/arm/base/softmax_base.h" | |||||
| namespace mindspore::kernel { | |||||
| class SoftmaxInt8CPUKernel : public SoftmaxBaseCPUKernel { | |||||
| public: | |||||
| SoftmaxInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx) | |||||
| : SoftmaxBaseCPUKernel(parameter, inputs, outputs, ctx) {} | |||||
| ~SoftmaxInt8CPUKernel() = default; | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| int DoSoftmax(int task_id); | |||||
| private: | |||||
| float *sum_data_; | |||||
| float *exp_data_; | |||||
| SoftmaxQuantArg quant_params_; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_SOFTMAX_INT8_H_ | |||||
| @@ -16,11 +16,16 @@ | |||||
| #include "src/runtime/kernel/arm/opclib/fp32/local_response_norm.h" | #include "src/runtime/kernel/arm/opclib/fp32/local_response_norm.h" | ||||
| int LocalResponseNorm(float *input_ptr, int out_size, int channel, float *output_ptr, int depth_radius, float bias, | |||||
| float alpha, float beta) { | |||||
| int LocalResponseNorm(float *input_ptr, int out_size, int channel, float *output_ptr, | |||||
| LocalResponseNormParameter *param) { | |||||
| int i, j, k; | int i, j, k; | ||||
| int left, right; | int left, right; | ||||
| float depth_radius = param->depth_radius_; | |||||
| float bias = param->bias_; | |||||
| float alpha = param->alpha_; | |||||
| float beta = param->beta_; | |||||
| for (i = 0; i < out_size; i++) { | for (i = 0; i < out_size; i++) { | ||||
| float *in_data = input_ptr + i * channel; | float *in_data = input_ptr + i * channel; | ||||
| float *out_data = output_ptr + i * channel; | float *out_data = output_ptr + i * channel; | ||||
| @@ -39,4 +44,3 @@ int LocalResponseNorm(float *input_ptr, int out_size, int channel, float *output | |||||
| } | } | ||||
| return 0; | return 0; | ||||
| } | } | ||||
| @@ -27,8 +27,7 @@ struct LocalResponseNormParameter { | |||||
| float beta_; | float beta_; | ||||
| }; | }; | ||||
| int LocalResponseNorm(float *input_ptr, int out_size, int channel, float *output_ptr, int depth_radius, float bias, | |||||
| float alpha, float beta); | |||||
| int LocalResponseNorm(float *input_ptr, int out_size, int channel, float *output_ptr, | |||||
| LocalResponseNormParameter *param); | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_LOCAL_RESPONSE_NORM_H_ | #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_LOCAL_RESPONSE_NORM_H_ | ||||
| @@ -18,17 +18,8 @@ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_SOFTMAX_H_ | #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_SOFTMAX_H_ | ||||
| #include "src/runtime/kernel/arm/opclib/op_base.h" | #include "src/runtime/kernel/arm/opclib/op_base.h" | ||||
| struct SoftmaxParameter { | |||||
| OpParameter op_parameter_; | |||||
| int32_t axis_; | |||||
| int element_size_; | |||||
| int n_dim_; | |||||
| int input_shape_[4]; | |||||
| }; | |||||
| #include "src/runtime/kernel/arm/opclib/softmax_parameter.h" | |||||
| void Softmax(const float *input_ptr, float *output_ptr, float *sum_data, SoftmaxParameter *parameter); | void Softmax(const float *input_ptr, float *output_ptr, float *sum_data, SoftmaxParameter *parameter); | ||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_SOFTMAX_H_ | #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_SOFTMAX_H_ | ||||
| @@ -0,0 +1,56 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/opclib/int8/softmax_int8.h" | |||||
| #include <cmath> | |||||
| int Softmax(const int8_t *input_ptr, int8_t *output_ptr, int count, float *exp_data, float *sum_data, | |||||
| SoftmaxQuantArg quant_param, SoftmaxParameter *parameter) { | |||||
| int32_t axis = parameter->axis_; | |||||
| int n_dim = parameter->n_dim_; | |||||
| int *input_shape = parameter->input_shape_; | |||||
| int axis_shape_size = input_shape[axis]; | |||||
| double output_scale = quant_param.out_quant_arg_.scale_; | |||||
| int32_t output_zp = quant_param.out_quant_arg_.zp_; | |||||
| int inner_size = 1; | |||||
| for (int i = axis + 1; i < n_dim; i++) { | |||||
| inner_size *= input_shape[i]; | |||||
| } | |||||
| for (int o = 0; o < count; o++) { | |||||
| int outter_offset = o * axis_shape_size * inner_size; | |||||
| for (int i = 0; i < inner_size; i++) { | |||||
| float sum = 0; | |||||
| for (int j = 0; j < axis_shape_size; j++) { | |||||
| int axis_offset = outter_offset + i + j * inner_size; | |||||
| sum += exp_data[axis_offset]; | |||||
| } | |||||
| sum_data[i] = sum; | |||||
| } | |||||
| for (int j = 0; j < axis_shape_size; j++) { | |||||
| int axis_offset = outter_offset + j * inner_size; | |||||
| for (int i = 0; i < inner_size; i++) { | |||||
| int inner_offset = axis_offset + i; | |||||
| float real_output = exp_data[inner_offset] / sum_data[i]; | |||||
| int32_t output_scaled = round(real_output / output_scale) + output_zp; | |||||
| output_ptr[inner_offset] = MSMAX(CHAR_MIN, MSMIN(CHAR_MAX, output_scaled)); | |||||
| } | |||||
| } | |||||
| } | |||||
| return 0; | |||||
| } | |||||
| @@ -0,0 +1,26 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_SOFTMAX_INT8_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_SOFTMAX_INT8_H_ | |||||
| #include "src/runtime/kernel/arm/opclib/op_base.h" | |||||
| #include "src/runtime/kernel/arm/opclib/softmax_parameter.h" | |||||
| int Softmax(const int8_t *input_ptr, int8_t *output_ptr, int count, float *exp_data, float *sum_data, | |||||
| SoftmaxQuantArg quant_param, SoftmaxParameter *parameter); | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_SOFTMAX_INT8_H_ | |||||
| @@ -91,6 +91,11 @@ struct ArithSelfQuantArg { | |||||
| int output_activation_max_; | int output_activation_max_; | ||||
| }; | }; | ||||
// Quantization arguments used by the int8 softmax kernel: one QuantArg for the
// input tensor and one for the output tensor.
| struct SoftmaxQuantArg { | |||||
// Quantization argument of the input tensor.
| QuantArg in_quant_args_; | |||||
// Quantization argument of the output tensor.
| QuantArg out_quant_arg_; | |||||
| }; | |||||
| void QuantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int *shift); | void QuantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int *shift); | ||||
| inline void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier, | inline void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier, | ||||
| @@ -0,0 +1,30 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_SOFTMAX_PARAMETER_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_SOFTMAX_PARAMETER_H_ | |||||
| #include "src/runtime/kernel/arm/opclib/op_base.h" | |||||
// Parameter block for the softmax kernels, declared in its own header so that
// several softmax implementations can include it.
| struct SoftmaxParameter { | |||||
// Generic operator parameter header; kept as the first member.
// NOTE(review): presumably first so an OpParameter* can be cast to
// SoftmaxParameter* — confirm against the kernel code.
| OpParameter op_parameter_; | |||||
// Index of the dimension along which softmax is computed.
| int32_t axis_; | |||||
// Total number of elements in the input.
| int element_size_; | |||||
// Number of valid entries in input_shape_.
| int n_dim_; | |||||
// Input dimensions; the array holds at most 4 entries.
| int input_shape_[4]; | |||||
| }; | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_SOFTMAX_PARAMETER_H_ | |||||
| @@ -0,0 +1,92 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <iostream> | |||||
| #include <memory> | |||||
| #include "common/common_test.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/arm/opclib/softmax_parameter.h" | |||||
| #include "mindspore/lite/src/kernel_registry.h" | |||||
| namespace mindspore { | |||||
| class TestSoftmaxInt8 : public mindspore::Common { | |||||
| public: | |||||
| TestSoftmaxInt8() {} | |||||
| }; | |||||
| TEST_F(TestSoftmaxInt8, SoftmaxInt8) { | |||||
| std::vector<lite::tensor::Tensor *> inputs_tensor; | |||||
| std::vector<lite::tensor::Tensor *> outputs_tensor; | |||||
| SoftmaxParameter op_param; | |||||
| op_param.op_parameter_.type_ = schema::PrimitiveType_SoftMax; | |||||
| op_param.axis_ = 2; | |||||
| op_param.element_size_ = 24; | |||||
| op_param.input_shape_[0] = 1; | |||||
| op_param.input_shape_[1] = 2; | |||||
| op_param.input_shape_[2] = 3; | |||||
| op_param.input_shape_[3] = 4; | |||||
| lite::tensor::QuantArg input_quant_arg; | |||||
| input_quant_arg.scale = 0.0352941; | |||||
| input_quant_arg.zeroPoint = -128; | |||||
| lite::tensor::QuantArg output_quant_arg; | |||||
| output_quant_arg.scale = 0.00392157; | |||||
| output_quant_arg.zeroPoint = -128; | |||||
| std::vector<int8_t> input = {-71, -43, -15, 14, -43, -15, 14, 42, 70, 99, 99, 127, | |||||
| -100, -71, -43, -15, 14, 42, 70, 99, 42, 70, 99, 127}; | |||||
| std::vector<int> in_shape = {1, 2, 3, 4}; | |||||
| lite::tensor::Tensor input0_tensor; | |||||
| TypeId tid_int8 = kNumberTypeInt8; | |||||
| inputs_tensor.push_back(&input0_tensor); | |||||
| input0_tensor.SetData(input.data()); | |||||
| input0_tensor.set_shape(in_shape); | |||||
| input0_tensor.AddQuantParam(input_quant_arg); | |||||
| input0_tensor.set_data_type(tid_int8); | |||||
| std::vector<int8_t> output(24); | |||||
| std::vector<int> output_shape = {1, 2, 3, 4}; | |||||
| lite::tensor::Tensor output0_tensor; | |||||
| outputs_tensor.push_back(&output0_tensor); | |||||
| output0_tensor.SetData(output.data()); | |||||
| output0_tensor.AddQuantParam(output_quant_arg); | |||||
| output0_tensor.set_data_type(tid_int8); | |||||
| auto ctx = std::make_shared<lite::Context>(); | |||||
| kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_SoftMax}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| ASSERT_NE(creator, nullptr); | |||||
| kernel::LiteKernel *kernel = | |||||
| creator(inputs_tensor, outputs_tensor, reinterpret_cast<OpParameter *>(&op_param), ctx.get(), desc); | |||||
| ASSERT_NE(kernel, nullptr); | |||||
| auto output_tensor_shape = output0_tensor.shape(); | |||||
| kernel->Run(); | |||||
| std::vector<int8_t> except_result = {-126, -126, -124, -124, -123, -124, -116, -116, 121, 121, 111, 111, | |||||
| -127, -127, -127, -127, -59, -59, -61, -59, 57, 57, 59, 57}; | |||||
| CompareOutputData(output.data(), except_result.data(), input.size(), 0.000001); | |||||
| input0_tensor.SetData(nullptr); | |||||
| output0_tensor.SetData(nullptr); | |||||
| } | |||||
| } // namespace mindspore | |||||