
!8467 [MSLITE] support layernorm int8

From: @ling_qiao_min
Reviewed-by: @zhang_xue_tong,@hangangqiang
Signed-off-by: @zhang_xue_tong
tags/v1.1.0
mindspore-ci-bot committed 5 years ago
parent commit 3518c7f259
6 changed files with 72 additions and 66 deletions

  1. +17 -23  mindspore/lite/nnacl/int8/layer_norm_int8.c
  2. +2  -2   mindspore/lite/nnacl/int8/layer_norm_int8.h
  3. +4  -10  mindspore/lite/nnacl/layer_norm_parameter.h
  4. +44 -27  mindspore/lite/src/runtime/kernel/arm/int8/layer_norm_int8.cc
  5. +4  -4   mindspore/lite/src/runtime/kernel/arm/int8/layer_norm_int8.h
  6. +1  -0   mindspore/lite/tools/common/node_util.cc

+17 -23  mindspore/lite/nnacl/int8/layer_norm_int8.c

@@ -21,8 +21,8 @@
  * quant : (x-mean) / sqrt(sum(x * x) - mean * mean) * gamma + beta
  *
  * */
-int LayerNormInt8(const int8_t *src_data, const int8_t *gamma_data, const int32_t *beta_data, int8_t *dst_data,
-                  bool affine, int outer_size, int inner_size, LayerNormQuantArg *quant_) {
+int LayerNormInt8(const int8_t *src_data, const float *gamma_data, const float *beta_data, int8_t *dst_data,
+                  bool affine, int outer_size, int inner_size, LayerNormQuantArg *quant, float epsilon) {
   if (src_data == NULL || dst_data == NULL) {
     return NNACL_NULL_PTR;
   }
@@ -34,30 +34,24 @@ int LayerNormInt8(const int8_t *src_data, const int8_t *gamma_data, const int32_
   for (int out_index = 0; out_index < outer_size; out_index++) {
     const int8_t *src = src_data + out_index * inner_size;
     int8_t *dst = dst_data + out_index * inner_size;
-    int32_t mean = 0;
-    int32_t square_mean = 0;
-    for (int in_index = 0; in_index < inner_size; in_index++) {
-      int32_t tmp_src = src[in_index] - quant_->in_quant_arg_.zp_;
-      mean += tmp_src;
-      square_mean += tmp_src * tmp_src;
+    float mean = 0.0f;
+    float square_mean = 0.0f;
+    for (int i = 0; i < inner_size; i++) {
+      float float_src = (src[i] - quant->in_zp_) * quant->in_scale_;
+      mean += float_src;
+      square_mean += float_src * float_src;
     }
-    mean = round(mean / inner_size);
-    square_mean = round(square_mean / inner_size);
-
-    int32_t variance_value = square_mean - mean * mean;
-
-    int32_t multiplier;
-    int32_t shift;
-    GetSqrtQuantMultiplierExp(variance_value, -1, &multiplier, &shift);
-
-    for (int in_index = 0; in_index < inner_size; in_index++) {
-      int32_t in = src[in_index] - quant_->in_quant_arg_.zp_ - mean;
-      int32_t tmp = RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(in * (1 << 7), multiplier), -shift);
+    mean /= (float)inner_size;
+    square_mean /= (float)inner_size;
+    const float deno = 1 / sqrtf(square_mean - mean * mean + epsilon);
+    for (int i = 0; i < inner_size; i++) {
+      float fp32_src = (src[i] - quant->in_zp_) * quant->in_scale_;
+      float fp32_dst = (fp32_src - mean) * deno;
       if (affine) {
-        tmp = tmp * (gamma_data[in_index] - quant_->gamma_quant_arg_.zp_) + beta_data[in_index];
+        fp32_dst = fp32_dst * gamma_data[i] + beta_data[i];
       }
-      int32_t out = MultiplyByQuantizedMultiplier(tmp, quant_->multiplier_, quant_->shift_left_, quant_->shift_right_);
-      dst[in_index] = (int8_t)MSMIN(quant_->output_activation_max_, MSMAX(quant_->output_activation_max_, out));
+      int32_t int32_dst = (int32_t)round(fp32_dst * 1.0 / quant->out_scale_ + quant->out_zp_);
+      dst[i] = (int8_t)MSMAX(MSMIN(int32_dst, 127), -128);
     }
   }
   return NNACL_OK;
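The rewritten kernel above replaces the old fixed-point pipeline with a dequantize-normalize-requantize pattern: int8 inputs are mapped to float using the input zero-point and scale, normalized in float with an epsilon guard, then mapped back with the output quant parameters and clamped to [-128, 127]. Below is a minimal self-contained sketch of that pattern for a single inner vector; the struct, the helper name, and the demo quant parameters in main are illustrative stand-ins, not nnacl API.

#include <math.h>
#include <stdint.h>
#include <stdio.h>

/* Stand-in for the simplified LayerNormQuantArg above. */
typedef struct {
  int32_t in_zp_, out_zp_;
  double in_scale_, out_scale_;
} QuantArgs;

/* Dequantize -> normalize in float -> requantize, for one inner vector. */
static void LayerNormInt8Sketch(const int8_t *src, int8_t *dst, int inner_size,
                                const QuantArgs *q, float epsilon) {
  float mean = 0.0f, square_mean = 0.0f;
  for (int i = 0; i < inner_size; i++) {
    float x = (src[i] - q->in_zp_) * (float)q->in_scale_;
    mean += x;
    square_mean += x * x;
  }
  mean /= (float)inner_size;
  square_mean /= (float)inner_size;
  /* variance = E[x^2] - (E[x])^2, guarded by epsilon as in the kernel */
  const float deno = 1.0f / sqrtf(square_mean - mean * mean + epsilon);
  for (int i = 0; i < inner_size; i++) {
    float x = (src[i] - q->in_zp_) * (float)q->in_scale_;
    float y = (x - mean) * deno;
    int32_t out = (int32_t)roundf(y / (float)q->out_scale_) + q->out_zp_;
    dst[i] = (int8_t)(out > 127 ? 127 : (out < -128 ? -128 : out));
  }
}

int main(void) {
  QuantArgs q = {0, 0, 0.05, 0.05};  /* hypothetical demo parameters */
  int8_t src[4] = {-40, -10, 10, 40};
  int8_t dst[4];
  LayerNormInt8Sketch(src, dst, 4, &q, 1e-5f);
  for (int i = 0; i < 4; i++) printf("%d ", dst[i]);  /* -27 -7 7 27 */
  printf("\n");
  return 0;
}

Compiled with -lm, this prints the normalized values requantized at the output scale; the affine gamma/beta step is omitted for brevity.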


+2 -2  mindspore/lite/nnacl/int8/layer_norm_int8.h

@@ -24,8 +24,8 @@
 extern "C" {
 #endif
 
-int LayerNormInt8(const int8_t *src_data, const int8_t *gamma_data, const int32_t *beta_data, int8_t *dst_data,
-                  bool affine, int outer_size, int inner_size, LayerNormQuantArg *quant_);
+int LayerNormInt8(const int8_t *src_data, const float *gamma_data, const float *beta_data, int8_t *dst_data,
+                  bool affine, int outer_size, int inner_size, LayerNormQuantArg *quant, float epsilon);
 
 #ifdef __cplusplus
 }


+4 -10  mindspore/lite/nnacl/layer_norm_parameter.h

@@ -30,16 +30,10 @@ typedef struct LayerNormParameter {
 } LayerNormParameter;
 
 typedef struct LayerNormQuantArg {
-  QuantArg in_quant_arg_;
-  QuantArg out_quant_arg_;
-  QuantArg gamma_quant_arg_;
-
-  int32_t multiplier_;
-  int32_t shift_left_;
-  int32_t shift_right_;
-
-  int output_activation_min_;
-  int output_activation_max_;
+  int32_t in_zp_;
+  int32_t out_zp_;
+  double in_scale_;
+  double out_scale_;
 } LayerNormQuantArg;
 
 #endif  // MINDSPORE_LITE_NNACL_LAYER_NORM_PARAMETER_H_
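The struct shrinks to four fields because, under the per-tensor affine scheme used here, a scale and zero-point fully describe the mapping between integer and real values: real = scale * (q - zp). A short hedged sketch of both directions of that mapping (helper names are illustrative, not nnacl functions):

#include <math.h>
#include <stdint.h>

/* real = scale * (q - zp): the only relation the four fields above encode. */
static inline float DequantizeInt8(int8_t q, int32_t zp, double scale) {
  return (float)((q - zp) * scale);
}

/* Inverse mapping with round-to-nearest and saturation to the int8 range. */
static inline int8_t QuantizeToInt8(float real, int32_t zp, double scale) {
  int32_t q = (int32_t)round(real / scale) + zp;
  if (q > 127) q = 127;
  if (q < -128) q = -128;
  return (int8_t)q;
}

The fixed-point multiplier/shift fields of the old struct are no longer needed once the kernel does its arithmetic in float.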

+44 -27  mindspore/lite/src/runtime/kernel/arm/int8/layer_norm_int8.cc

@@ -22,35 +22,56 @@ using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_LayerNorm;
 
 namespace mindspore::kernel {
-void LayerNormInt8CPUKernel::SetQuantArgs() {
+LayerNormInt8CPUKernel::~LayerNormInt8CPUKernel() {
+  if (param_->elementwise_affine_ && gamma_ptr_ != nullptr) {
+    free(gamma_ptr_);
+    gamma_ptr_ = nullptr;
+  }
+  if (param_->elementwise_affine_ && beta_ptr_ != nullptr) {
+    free(beta_ptr_);
+    beta_ptr_ = nullptr;
+  }
+  return;
+}
+
+int LayerNormInt8CPUKernel::SetQuantArgs() {
   lite::Tensor *input = in_tensors_.at(0);
   lite::Tensor *output = out_tensors_.at(0);
 
-  quant_param_.in_quant_arg_.zp_ = input->GetQuantParams().front().zeroPoint;
-  quant_param_.in_quant_arg_.scale_ = input->GetQuantParams().front().scale;
-  quant_param_.out_quant_arg_.zp_ = output->GetQuantParams().front().zeroPoint;
-  quant_param_.out_quant_arg_.scale_ = output->GetQuantParams().front().scale;
-
-  quant_param_.output_activation_min_ = std::numeric_limits<int8_t>::min();
-  quant_param_.output_activation_max_ = std::numeric_limits<int8_t>::max();
+  quant_param_.in_zp_ = input->GetQuantParams().front().zeroPoint;
+  quant_param_.in_scale_ = input->GetQuantParams().front().scale;
+  quant_param_.out_zp_ = output->GetQuantParams().front().zeroPoint;
+  quant_param_.out_scale_ = output->GetQuantParams().front().scale;
 
   if (param_->elementwise_affine_) {
-    lite::Tensor *gamma_tensor = out_tensors_.at(1);
-    quant_param_.gamma_quant_arg_.zp_ = gamma_tensor->GetQuantParams().front().zeroPoint;
-    quant_param_.gamma_quant_arg_.scale_ = gamma_tensor->GetQuantParams().front().scale;
-  }
+    lite::Tensor *gamma_tensor = in_tensors_.at(1);
+    lite::Tensor *beta_tensor = in_tensors_.at(2);
+
+    double gamma_scale = gamma_tensor->GetQuantParams().front().scale;
+    int gamma_zp = gamma_tensor->GetQuantParams().front().zeroPoint;
+    gamma_ptr_ = reinterpret_cast<float *>(malloc(gamma_tensor->ElementsNum() * sizeof(float)));
+    if (gamma_ptr_ == nullptr) {
+      MS_LOG(ERROR) << "malloc gamma_ptr_ failed";
+      return RET_ERROR;
+    }
+    int8_t *src_gamma = reinterpret_cast<int8_t *>(gamma_tensor->data_c());
+    for (int i = 0; i < gamma_tensor->ElementsNum(); i++) {
+      gamma_ptr_[i] = (src_gamma[i] - gamma_zp) * gamma_scale;
+    }
 
-  double in_scale;
-  if (param_->elementwise_affine_) {
-    in_scale = static_cast<double>(quant_param_.in_quant_arg_.scale_ * quant_param_.gamma_quant_arg_.scale_);
-  } else {
-    in_scale = static_cast<double>(quant_param_.in_quant_arg_.scale_);
+    beta_ptr_ = reinterpret_cast<float *>(malloc(beta_tensor->ElementsNum() * sizeof(float)));
+    if (beta_ptr_ == nullptr) {
+      MS_LOG(ERROR) << "malloc beta_ptr_ failed";
+      free(gamma_ptr_);
+      gamma_ptr_ = nullptr;
+      return RET_ERROR;
+    }
+    int32_t *src_beta = reinterpret_cast<int32_t *>(beta_tensor->data_c());
+    for (int i = 0; i < beta_tensor->ElementsNum(); i++) {
+      beta_ptr_[i] = src_beta[i] * quant_param_.in_scale_ * gamma_scale;
+    }
   }
-  double real_multiplier = in_scale / static_cast<double>(quant_param_.out_quant_arg_.scale_);
-
-  QuantizeRoundParameter(real_multiplier, &quant_param_.multiplier_, &quant_param_.shift_left_,
-                         &quant_param_.shift_right_);
-  return;
+  return RET_OK;
 }

int LayerNormInt8CPUKernel::Init() {
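Note that SetQuantArgs now converts the quantized affine parameters to float once at init time rather than per inference: gamma (int8) is dequantized with its own scale and zero-point, while the int32 beta carries, per the loop above, the combined scale in_scale_ * gamma_scale, so multiplying by both recovers its float value. A hedged standalone sketch of that unpacking step (the helper name and signature are illustrative):

#include <stdint.h>
#include <stdlib.h>

/* Unpack quantized gamma (int8) and beta (int32) into float once, so the
 * per-inference loop only touches float data. beta is assumed quantized
 * with scale in_scale * gamma_scale, matching the kernel code above. */
static int UnpackAffineParams(const int8_t *src_gamma, const int32_t *src_beta,
                              int n, int32_t gamma_zp, double gamma_scale,
                              double in_scale, float **gamma_out, float **beta_out) {
  float *gamma = (float *)malloc(n * sizeof(float));
  float *beta = (float *)malloc(n * sizeof(float));
  if (gamma == NULL || beta == NULL) {
    free(gamma);  /* free(NULL) is a no-op, so this is safe either way */
    free(beta);
    return -1;
  }
  for (int i = 0; i < n; i++) {
    gamma[i] = (float)((src_gamma[i] - gamma_zp) * gamma_scale);
    beta[i] = (float)(src_beta[i] * in_scale * gamma_scale);
  }
  *gamma_out = gamma;
  *beta_out = beta;
  return 0;
}

Because the kernel now owns these float buffers, the new destructor frees them, which is why ~LayerNormInt8CPUKernel() gains a body in this change.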
@@ -96,17 +117,13 @@ int LayerNormInt8CPUKernel::DoExecute(int task_id) {
   int8_t *thread_dst = dst_ptr_ + task_id * param_->thread_outsize_ * inner_size_;
 
   LayerNormInt8(thread_src, gamma_ptr_, beta_ptr_, thread_dst, param_->elementwise_affine_, current_out_size,
-                inner_size_, &quant_param_);
+                inner_size_, &quant_param_, param_->epsilon_);
   return RET_OK;
 }
 
 int LayerNormInt8CPUKernel::Run() {
   src_ptr_ = reinterpret_cast<int8_t *>(in_tensors_.at(0)->MutableData());
   dst_ptr_ = reinterpret_cast<int8_t *>(out_tensors_.at(0)->MutableData());
-  if (param_->elementwise_affine_) {
-    gamma_ptr_ = reinterpret_cast<int8_t *>(in_tensors_.at(1)->MutableData());
-    beta_ptr_ = reinterpret_cast<int32_t *>(in_tensors_.at(2)->MutableData());
-  }
 
   auto ret = ParallelLaunch(this->context_->thread_pool_, LayerNormInt8Run, this, op_parameter_->thread_num_);
   if (ret != RET_OK) {
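DoExecute above slices the work per task: each task handles a block of thread_outsize_ outer rows, offsetting src and dst by task_id * thread_outsize_ * inner_size_. A minimal sketch of that partitioning arithmetic, assuming thread_outsize_ is the ceiling division of outer_size by the thread count (an assumption for illustration; the actual value is computed elsewhere in the kernel):

#include <stdio.h>

int main(void) {
  int outer_size = 10, thread_num = 4;
  /* ceiling division, in the style of nnacl's UP_DIV macro */
  int thread_outsize = (outer_size + thread_num - 1) / thread_num;
  for (int task_id = 0; task_id < thread_num; task_id++) {
    int start = task_id * thread_outsize;
    if (start >= outer_size) break;  /* no rows left for this task */
    int count = outer_size - start < thread_outsize ? outer_size - start : thread_outsize;
    printf("task %d: outer rows [%d, %d)\n", task_id, start, start + count);
  }
  return 0;
}

With 10 rows over 4 threads this prints blocks of 3, 3, 3, and 1, mirroring how current_out_size shrinks for the final task.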


+4 -4  mindspore/lite/src/runtime/kernel/arm/int8/layer_norm_int8.h

@@ -31,7 +31,7 @@ class LayerNormInt8CPUKernel : public LiteKernel {
       : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
     param_ = reinterpret_cast<LayerNormParameter *>(parameter);
   }
-  ~LayerNormInt8CPUKernel() override{};
+  ~LayerNormInt8CPUKernel() override;
 
   int Init() override;
   int ReSize() override;
@@ -41,7 +41,7 @@ class LayerNormInt8CPUKernel : public LiteKernel {
   int DoExecute(int task_id);
 
  private:
-  void SetQuantArgs();
+  int SetQuantArgs();
 
  private:
   LayerNormParameter *param_ = nullptr;
@@ -50,8 +50,8 @@ class LayerNormInt8CPUKernel : public LiteKernel {
   int inner_size_ = 0;
   int8_t *src_ptr_ = nullptr;
   int8_t *dst_ptr_ = nullptr;
-  int8_t *gamma_ptr_ = nullptr;
-  int32_t *beta_ptr_ = nullptr;
+  float *gamma_ptr_ = nullptr;
+  float *beta_ptr_ = nullptr;
 };
 }  // namespace mindspore::kernel



+1 -0  mindspore/lite/tools/common/node_util.cc

@@ -138,6 +138,7 @@ static const std::vector<schema::PrimitiveType> int8OpList = {schema::PrimitiveT
   schema::PrimitiveType_Crop,
   schema::PrimitiveType_PriorBox,
   schema::PrimitiveType_QuantDTypeCast,
+  schema::PrimitiveType_LayerNorm,
   schema::PrimitiveType_L2Norm};
 
 static const std::vector<schema::PrimitiveType> needInsertOpList = {

