diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc index f9a130b97d..8dfe8be5ef 100644 --- a/mindspore/lite/src/populate_parameter.cc +++ b/mindspore/lite/src/populate_parameter.cc @@ -32,7 +32,7 @@ #include "src/runtime/kernel/arm/opclib/conv_parameter.h" #include "src/runtime/kernel/arm/opclib/fp32/pooling.h" #include "src/runtime/kernel/arm/opclib/matmul.h" -#include "src/runtime/kernel/arm/opclib/fp32/softmax.h" +#include "src/runtime/kernel/arm/opclib/softmax_parameter.h" #include "src/runtime/kernel/arm/opclib/tile.h" #include "src/runtime/kernel/arm/opclib/fp32/topk.h" #include "src/runtime/kernel/arm/opclib/fp32/reduce.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/base/softmax_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/softmax_base.cc new file mode 100644 index 0000000000..abe6136ab8 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/softmax_base.cc @@ -0,0 +1,103 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/base/softmax_base.h" +#include +#include "src/runtime/kernel/arm/int8/softmax_int8.h" +#include "src/runtime/kernel/arm/fp32/softmax.h" +#include "src/runtime/kernel/arm/opclib/fp32/softmax.h" +#include "schema/model_generated.h" +#include "src/kernel_factory.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::lite::RET_NULL_PTR; +using mindspore::schema::PrimitiveType_SoftMax; + +namespace mindspore::kernel { + +int SoftmaxBaseCPUKernel::Init() { + if (softmax_param_ == nullptr) { + MS_LOG(ERROR) << "SoftmaxParameter nullptr"; + return RET_NULL_PTR; + } + + auto input_tensor = inputs_.front(); + auto in_shape = input_tensor->shape(); + auto in_dims = in_shape.size(); + int ele_size = 1; + softmax_param_->n_dim_ = in_dims; + for (size_t i = 0; i < in_dims; i++) { + softmax_param_->input_shape_[i] = in_shape[i]; + ele_size *= in_shape[i]; + } + softmax_param_->element_size_ = ele_size; + return RET_OK; +} + +kernel::LiteKernel *CpuSoftmaxInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_SoftMax); + auto *kernel = new (std::nothrow) SoftmaxInt8CPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new SoftmaxCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuSoftmaxFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_SoftMax); + auto *kernel = new (std::nothrow) SoftmaxCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new SoftmaxCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_SoftMax, CpuSoftmaxInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SoftMax, CpuSoftmaxFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/softmax_base.h b/mindspore/lite/src/runtime/kernel/arm/base/softmax_base.h new file mode 100644 index 0000000000..d2ac003ccf --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/softmax_base.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SOFTMAX_BASE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SOFTMAX_BASE_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/softmax_parameter.h" + +namespace mindspore::kernel { +class SoftmaxBaseCPUKernel : public LiteKernel { + public: + SoftmaxBaseCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) { + opParameter->thread_num_ = ctx->thread_num_; + softmax_param_ = reinterpret_cast(opParameter); + } + ~SoftmaxBaseCPUKernel() = default; + + int Init() override; + int ReSize() override { return 0; } + int Run() override { return 0; } + + protected: + int thread_count_; + const lite::Context *ctx_; + SoftmaxParameter *softmax_param_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SOFTMAX_BASE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.cc index 960ae8edec..bb1286eebd 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.cc @@ -29,13 +29,7 @@ using mindspore::schema::PrimitiveType_LocalResponseNormalization; namespace mindspore::kernel { -int LocalResponseNormCPUKernel::Init() { - depth_radius_ = (reinterpret_cast(opParameter))->depth_radius_; - bias_ = (reinterpret_cast(opParameter))->bias_; - alpha_ = (reinterpret_cast(opParameter))->alpha_; - beta_ = (reinterpret_cast(opParameter))->beta_; - return RET_OK; -} +int LocalResponseNormCPUKernel::Init() { return RET_OK; } int LocalResponseNormCPUKernel::ReSize() { return RET_OK; } @@ -60,7 +54,8 @@ int LocalResponseNormCPUKernel::DoLocalResponseNorm(int task_id) { input_ptr += stride * task_id * channel; output_ptr += stride * task_id * channel; - auto error_code = LocalResponseNorm(input_ptr, count, channel, output_ptr, depth_radius_, bias_, alpha_, beta_); + auto error_code = LocalResponseNorm(input_ptr, count, channel, output_ptr, + reinterpret_cast(opParameter)); if (error_code != RET_OK) { MS_LOG(ERROR) << "DoLocalResponseNorm error task_id[" << task_id << "] error_code[" << error_code << "]"; return RET_ERROR; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.h b/mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.h index d5839d1037..d85216970d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.h @@ -36,10 +36,6 @@ class LocalResponseNormCPUKernel : public LiteKernel { private: int thread_count_; - int depth_radius_; - float bias_; - float alpha_; - float beta_; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.cc index f59ab4c4b5..2eb8129d6d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.cc @@ -30,21 +30,12 @@ using mindspore::schema::PrimitiveType_SoftMax; namespace mindspore::kernel { int SoftmaxCPUKernel::Init() { - auto input_tensor = inputs_.front(); - auto in_shape = input_tensor->shape(); - auto in_dims = in_shape.size(); - int ele_size = 1; - (reinterpret_cast(opParameter))->n_dim_ = in_dims; - for (size_t i = 0; i < in_dims; i++) { - (reinterpret_cast(opParameter))->input_shape_[i] = in_shape[i]; - ele_size *= in_shape[i]; - } - (reinterpret_cast(opParameter))->element_size_ = ele_size; + SoftmaxBaseCPUKernel::Init(); // malloc tmp buffer - auto axis = reinterpret_cast(opParameter)->axis_; - sum_data = reinterpret_cast(malloc(in_shape[axis] * sizeof(float))); - memset(sum_data, 0, in_shape[axis] * sizeof(float)); + auto axis = softmax_param_->axis_; + sum_data = reinterpret_cast(malloc(softmax_param_->input_shape_[axis] * sizeof(float))); + memset(sum_data, 0, softmax_param_->input_shape_[axis] * sizeof(float)); return RET_OK; } @@ -53,31 +44,8 @@ int SoftmaxCPUKernel::ReSize() { return RET_OK; } int SoftmaxCPUKernel::Run() { auto input_ptr = reinterpret_cast(inputs_.at(kInputIndex)->Data()); auto output_ptr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); - Softmax(input_ptr, output_ptr, sum_data, reinterpret_cast(opParameter)); + Softmax(input_ptr, output_ptr, sum_data, softmax_param_); return RET_OK; } -kernel::LiteKernel *CpuSoftmaxFp32KernelCreator(const std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { - MS_ASSERT(opParameter != nullptr); - MS_ASSERT(desc.type == schema::PrimitiveType_SoftMax); - auto *kernel = new (std::nothrow) SoftmaxCPUKernel(opParameter, inputs, outputs); - if (kernel == nullptr) { - MS_LOG(ERROR) << "new SoftmaxCPUKernel fail!"; - return nullptr; - } - auto ret = kernel->Init(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " - << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); - delete kernel; - return nullptr; - } - return kernel; -} - -REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SoftMax, CpuSoftmaxFp32KernelCreator) } // namespace mindspore::kernel - diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.h b/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.h index 61c1951af3..0e9c0a7ebf 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.h @@ -19,14 +19,14 @@ #include #include "src/lite_kernel.h" - +#include "src/runtime/kernel/arm/base/softmax_base.h" namespace mindspore::kernel { -class SoftmaxCPUKernel : public LiteKernel { +class SoftmaxCPUKernel : public SoftmaxBaseCPUKernel { public: SoftmaxCPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : LiteKernel(parameter, inputs, outputs) {} + const std::vector &outputs, const lite::Context *ctx) + : SoftmaxBaseCPUKernel(parameter, inputs, outputs, ctx) {} ~SoftmaxCPUKernel() override = default; int Init() override; @@ -39,4 +39,3 @@ class SoftmaxCPUKernel : public LiteKernel { } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SOFTMAX_H_ - diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.cc new file mode 100644 index 0000000000..a78a9bc500 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.cc @@ -0,0 +1,111 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/softmax_int8.h" +#include "src/runtime/kernel/arm/opclib/int8/softmax_int8.h" +#include "schema/model_generated.h" +#include "src/runtime/runtime_api.h" +#include "include/errorcode.h" + +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { + +int SoftmaxInt8CPUKernel::Init() { + SoftmaxBaseCPUKernel::Init(); + + auto *input_tensor = inputs_.at(kInputIndex); + MS_ASSERT(input_tensor); + + auto in_quant_args = input_tensor->GetQuantParams(); + quant_params_.in_quant_args_.scale_ = in_quant_args.front().scale; + quant_params_.in_quant_args_.zp_ = in_quant_args.front().zeroPoint; + + auto *out_tensor = outputs_.at(kOutputIndex); + MS_ASSERT(out_tensor); + + auto out_quant_args = out_tensor->GetQuantParams(); + quant_params_.out_quant_arg_.scale_ = out_quant_args.front().scale; + quant_params_.out_quant_arg_.zp_ = out_quant_args.front().zeroPoint; + + int inner_size = 1; + for (int i = softmax_param_->axis_ + 1; i < softmax_param_->n_dim_; i++) { + inner_size *= softmax_param_->input_shape_[i]; + } + + exp_data_ = reinterpret_cast(malloc(softmax_param_->element_size_ * sizeof(float))); + sum_data_ = reinterpret_cast(malloc(inner_size * sizeof(float))); + return RET_OK; +} + +int SoftmaxInt8CPUKernel::ReSize() { return RET_OK; } + +int SoftmaxInt8CPUKernel::DoSoftmax(int task_id) { + MS_ASSERT(inputs_.size() == 1); + MS_ASSERT(outputs_.size() == 1); + + auto input_ptr = reinterpret_cast(inputs_.at(0)->Data()); + auto output_ptr = reinterpret_cast(outputs_.at(0)->Data()); + + int outter_size = 1, inner_size = 1; + for (int i = 0; i < softmax_param_->axis_; i++) { + outter_size *= softmax_param_->input_shape_[i]; + } + for (int i = softmax_param_->axis_; i < softmax_param_->n_dim_; i++) { + inner_size *= softmax_param_->input_shape_[i]; + } + + int stride = UP_DIV(outter_size, thread_count_); + int count = MSMIN(stride, outter_size - stride * task_id); + + input_ptr += stride * task_id * inner_size; + output_ptr += stride * task_id * inner_size; + exp_data_ += stride * task_id * inner_size; + + auto error_code = Softmax(input_ptr, output_ptr, count, exp_data_, sum_data_, quant_params_, softmax_param_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "DoSoftmax error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int SoftmaxRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto softmax_kernel = reinterpret_cast(cdata); + auto error_code = softmax_kernel->DoSoftmax(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "SoftmaxRun error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int SoftmaxInt8CPUKernel::Run() { + auto input_ptr = reinterpret_cast(inputs_.at(0)->Data()); + int ele_size = softmax_param_->element_size_; + for (int i = 0; i < ele_size; i++) { + float input_scaled = ((input_ptr[i] - quant_params_.in_quant_args_.zp_) * quant_params_.in_quant_args_.scale_); + exp_data_[i] = exp(input_scaled); + } + int error_code = LiteBackendParallelLaunch(SoftmaxRun, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Softmax function error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.h new file mode 100644 index 0000000000..d29da4094a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_SOFTMAX_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_SOFTMAX_INT8_H_ + +#include +#include "src/runtime/kernel/arm/base/softmax_base.h" + +namespace mindspore::kernel { +class SoftmaxInt8CPUKernel : public SoftmaxBaseCPUKernel { + public: + SoftmaxInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : SoftmaxBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~SoftmaxInt8CPUKernel() = default; + + int Init() override; + int ReSize() override; + int Run() override; + int DoSoftmax(int task_id); + + private: + float *sum_data_; + float *exp_data_; + SoftmaxQuantArg quant_params_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_SOFTMAX_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/local_response_norm.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/local_response_norm.cc index 460f634b81..3beb4c44d7 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/local_response_norm.cc +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/local_response_norm.cc @@ -16,11 +16,16 @@ #include "src/runtime/kernel/arm/opclib/fp32/local_response_norm.h" -int LocalResponseNorm(float *input_ptr, int out_size, int channel, float *output_ptr, int depth_radius, float bias, - float alpha, float beta) { +int LocalResponseNorm(float *input_ptr, int out_size, int channel, float *output_ptr, + LocalResponseNormParameter *param) { int i, j, k; int left, right; + float depth_radius = param->depth_radius_; + float bias = param->bias_; + float alpha = param->alpha_; + float beta = param->beta_; + for (i = 0; i < out_size; i++) { float *in_data = input_ptr + i * channel; float *out_data = output_ptr + i * channel; @@ -39,4 +44,3 @@ int LocalResponseNorm(float *input_ptr, int out_size, int channel, float *output } return 0; } - diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/local_response_norm.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/local_response_norm.h index f3a989343d..b2631427cf 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/local_response_norm.h +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/local_response_norm.h @@ -27,8 +27,7 @@ struct LocalResponseNormParameter { float beta_; }; -int LocalResponseNorm(float *input_ptr, int out_size, int channel, float *output_ptr, int depth_radius, float bias, - float alpha, float beta); +int LocalResponseNorm(float *input_ptr, int out_size, int channel, float *output_ptr, + LocalResponseNormParameter *param); #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_LOCAL_RESPONSE_NORM_H_ - diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/softmax.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/softmax.h index 0787e1682a..590b020495 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/softmax.h +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/softmax.h @@ -18,17 +18,8 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_SOFTMAX_H_ #include "src/runtime/kernel/arm/opclib/op_base.h" - -struct SoftmaxParameter { - OpParameter op_parameter_; - int32_t axis_; - int element_size_; - int n_dim_; - int input_shape_[4]; -}; +#include "src/runtime/kernel/arm/opclib/softmax_parameter.h" void Softmax(const float *input_ptr, float *output_ptr, float *sum_data, SoftmaxParameter *parameter); - #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_SOFTMAX_H_ - diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/int8/softmax_int8.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/softmax_int8.cc new file mode 100644 index 0000000000..0b94d1c1b0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/softmax_int8.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/int8/softmax_int8.h" +#include + +int Softmax(const int8_t *input_ptr, int8_t *output_ptr, int count, float *exp_data, float *sum_data, + SoftmaxQuantArg quant_param, SoftmaxParameter *parameter) { + int32_t axis = parameter->axis_; + int n_dim = parameter->n_dim_; + int *input_shape = parameter->input_shape_; + int axis_shape_size = input_shape[axis]; + + double output_scale = quant_param.out_quant_arg_.scale_; + int32_t output_zp = quant_param.out_quant_arg_.zp_; + + int inner_size = 1; + for (int i = axis + 1; i < n_dim; i++) { + inner_size *= input_shape[i]; + } + + for (int o = 0; o < count; o++) { + int outter_offset = o * axis_shape_size * inner_size; + for (int i = 0; i < inner_size; i++) { + float sum = 0; + for (int j = 0; j < axis_shape_size; j++) { + int axis_offset = outter_offset + i + j * inner_size; + sum += exp_data[axis_offset]; + } + sum_data[i] = sum; + } + for (int j = 0; j < axis_shape_size; j++) { + int axis_offset = outter_offset + j * inner_size; + for (int i = 0; i < inner_size; i++) { + int inner_offset = axis_offset + i; + float real_output = exp_data[inner_offset] / sum_data[i]; + int32_t output_scaled = round(real_output / output_scale) + output_zp; + output_ptr[inner_offset] = MSMAX(CHAR_MIN, MSMIN(CHAR_MAX, output_scaled)); + } + } + } + return 0; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/int8/softmax_int8.h b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/softmax_int8.h new file mode 100644 index 0000000000..cf5e03564d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/softmax_int8.h @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_SOFTMAX_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_SOFTMAX_INT8_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" +#include "src/runtime/kernel/arm/opclib/softmax_parameter.h" + +int Softmax(const int8_t *input_ptr, int8_t *output_ptr, int count, float *exp_data, float *sum_data, + SoftmaxQuantArg quant_param, SoftmaxParameter *parameter); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_SOFTMAX_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/quantization/quantize.h b/mindspore/lite/src/runtime/kernel/arm/opclib/quantization/quantize.h index fb53914cd9..8c6065cf57 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/quantization/quantize.h +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/quantization/quantize.h @@ -101,6 +101,11 @@ struct SplitQuantArg { int output_activation_max_; }; +struct SoftmaxQuantArg { + QuantArg in_quant_args_; + QuantArg out_quant_arg_; +}; + void QuantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int *shift); inline void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier, diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/softmax_parameter.h b/mindspore/lite/src/runtime/kernel/arm/opclib/softmax_parameter.h new file mode 100644 index 0000000000..23c3424f94 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/softmax_parameter.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_SOFTMAX_PARAMETER_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_SOFTMAX_PARAMETER_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct SoftmaxParameter { + OpParameter op_parameter_; + int32_t axis_; + int element_size_; + int n_dim_; + int input_shape_[4]; +}; + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_SOFTMAX_PARAMETER_H_ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/softmax_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/softmax_int8_tests.cc new file mode 100644 index 0000000000..3cab67205a --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/softmax_int8_tests.cc @@ -0,0 +1,92 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.h" +#include "mindspore/lite/src/runtime/kernel/arm/opclib/softmax_parameter.h" +#include "mindspore/lite/src/kernel_registry.h" + +namespace mindspore { + +class TestSoftmaxInt8 : public mindspore::Common { + public: + TestSoftmaxInt8() {} +}; + +TEST_F(TestSoftmaxInt8, SoftmaxInt8) { + std::vector inputs_tensor; + std::vector outputs_tensor; + + SoftmaxParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_SoftMax; + op_param.axis_ = 2; + op_param.element_size_ = 24; + op_param.input_shape_[0] = 1; + op_param.input_shape_[1] = 2; + op_param.input_shape_[2] = 3; + op_param.input_shape_[3] = 4; + + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 0.0352941; + input_quant_arg.zeroPoint = -128; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 0.00392157; + output_quant_arg.zeroPoint = -128; + + std::vector input = {-71, -43, -15, 14, -43, -15, 14, 42, 70, 99, 99, 127, + -100, -71, -43, -15, 14, 42, 70, 99, 42, 70, 99, 127}; + std::vector in_shape = {1, 2, 3, 4}; + + lite::tensor::Tensor input0_tensor; + TypeId tid_int8 = kNumberTypeInt8; + inputs_tensor.push_back(&input0_tensor); + input0_tensor.SetData(input.data()); + input0_tensor.set_shape(in_shape); + input0_tensor.AddQuantParam(input_quant_arg); + input0_tensor.set_data_type(tid_int8); + + std::vector output(24); + std::vector output_shape = {1, 2, 3, 4}; + + lite::tensor::Tensor output0_tensor; + outputs_tensor.push_back(&output0_tensor); + output0_tensor.SetData(output.data()); + output0_tensor.AddQuantParam(output_quant_arg); + output0_tensor.set_data_type(tid_int8); + + auto ctx = std::make_shared(); + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_SoftMax}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx.get(), desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor.shape(); + kernel->Run(); + + std::vector except_result = {-126, -126, -124, -124, -123, -124, -116, -116, 121, 121, 111, 111, + -127, -127, -127, -127, -59, -59, -61, -59, 57, 57, 59, 57}; + + CompareOutputData(output.data(), except_result.data(), input.size(), 0.000001); + + input0_tensor.SetData(nullptr); + output0_tensor.SetData(nullptr); +} + +} // namespace mindspore