
quant_dtype_cast support int8 to int8

tags/v1.1.0
cjh9368 committed 5 years ago · commit f26b027973
2 changed files with 44 additions and 21 deletions:
  1. mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc  (+38, -18)
  2. mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.h  (+6, -3)
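
Conceptually, an int8 -> int8 "cast" between two different sets of quantization parameters is a requantization: each value is dequantized with the input tensor's scale/zero point and re-quantized with the output tensor's. A minimal scalar sketch of that math (a hypothetical helper, not part of this patch; the kernel performs the same steps as two buffer-level passes):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Requantize one int8 value from (in_scale, in_zp) to (out_scale, out_zp).
    inline int8_t RequantizeInt8(int8_t v, float in_scale, int32_t in_zp, float out_scale, int32_t out_zp) {
      float real = in_scale * (static_cast<float>(v) - in_zp);           // dequantize with input params
      int q = static_cast<int>(std::round(real / out_scale)) + out_zp;   // quantize with output params
      return static_cast<int8_t>(std::clamp(q, -128, 127));              // clamp to the int8 range
    }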

mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc  (+38, -18)

@@ -46,30 +46,33 @@ int QuantDTypeCastCPUKernel::Init() {
MS_LOG(ERROR) << "param data type and tensor data type do not match.";
return RET_ERROR;
}
inverse_ = false;
} else if (param->srcT == kNumberTypeInt8 && param->dstT == kNumberTypeFloat32) {
if (in_tensor->data_type() != kNumberTypeInt8 || out_tensor->data_type() != kNumberTypeFloat32) {
MS_LOG(ERROR) << "param data type and tensor data type do not match.";
return RET_ERROR;
}
inverse_ = true;
} else if (param->srcT == kNumberTypeUInt8 && param->dstT == kNumberTypeInt8) {
if (in_tensor->data_type() != kNumberTypeUInt8 || out_tensor->data_type() != kNumberTypeInt8) {
MS_LOG(ERROR) << "param data type and tensor data type do not match.";
return RET_ERROR;
}
inverse_ = false;
} else if (param->srcT == kNumberTypeInt8 && param->dstT == kNumberTypeInt8) {
if (in_tensor->data_type() != kNumberTypeInt8 || out_tensor->data_type() != kNumberTypeInt8) {
MS_LOG(ERROR) << "param data type and tensor data type do not match.";
return RET_ERROR;
}
} else if (param->srcT == kNumberTypeInt8 && param->dstT == kNumberTypeUInt8) {
if (in_tensor->data_type() != kNumberTypeInt8 || out_tensor->data_type() != kNumberTypeUInt8) {
MS_LOG(ERROR) << "param data type and tensor data type do not match.";
return RET_ERROR;
}
inverse_ = true;
} else {
MS_LOG(ERROR) << "param data type not supported:"
<< " src: " << param->srcT << " dst: " << param->dstT;
return RET_PARAM_INVALID;
}
src_dtype = param->srcT;
dst_dtype = param->dstT;

if (!InferShapeDone()) {
return RET_OK;
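
With this change, Init() no longer sets inverse_; it only validates the (srcT, dstT) pair and records it in src_dtype/dst_dtype, and the per-pair behavior is selected later in QuantDTypeCast(). The accepted pairs, written out as a small table (a sketch assuming the kNumberType IDs used in the hunk; the kernel itself keeps the explicit if/else chain):

    #include <utility>

    // (srcT, dstT) pairs the kernel accepts after this commit.
    static const std::pair<int, int> kSupportedCasts[] = {
        {kNumberTypeFloat32, kNumberTypeInt8},  // quantize
        {kNumberTypeInt8, kNumberTypeFloat32},  // dequantize
        {kNumberTypeUInt8, kNumberTypeInt8},    // uint8 -> int8
        {kNumberTypeInt8, kNumberTypeUInt8},    // int8 -> uint8
        {kNumberTypeInt8, kNumberTypeInt8},     // int8 -> int8 requantize (new in this commit)
    };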
@@ -97,20 +100,25 @@ int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) {
   }
   auto quant_arg = !out_tensors_.front()->GetQuantParams().empty() ? out_tensors_.front()->GetQuantParams().front()
                                                                    : in_tensors_.front()->GetQuantParams().front();
-  int ret;
-  if (uint8_ptr_ == nullptr) {
-    if (inverse_) {
-      ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale,
+  int ret = RET_OK;
+  if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeFloat32) {
+    ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale,
                                  quant_arg.zeroPoint, num_unit_thread);
-    } else {
-      ret = DoQuantizeToInt8FromFp32(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
-                                     quant_arg.zeroPoint, num_unit_thread);
-    }
-  } else {
-    if (inverse_) {
-      ret = DoDequantizeInt8ToUInt8(int8_ptr_ + thread_offset, uint8_ptr_ + thread_offset, num_unit_thread);
-    } else {
-      ret = DoQuantizeToInt8FromUint8(uint8_ptr_ + thread_offset, int8_ptr_ + thread_offset, num_unit_thread);
+  } else if (src_dtype == TypeId::kNumberTypeFloat32 && dst_dtype == TypeId::kNumberTypeInt8) {
+    ret = DoQuantizeToInt8FromFp32(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
+                                   quant_arg.zeroPoint, num_unit_thread);
+  } else if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeUInt8) {
+    ret = DoDequantizeInt8ToUInt8(int8_ptr_ + thread_offset, uint8_ptr_ + thread_offset, num_unit_thread);
+  } else if (src_dtype == TypeId::kNumberTypeUInt8 && dst_dtype == TypeId::kNumberTypeInt8) {
+    ret = DoQuantizeToInt8FromUint8(uint8_ptr_ + thread_offset, int8_ptr_ + thread_offset, num_unit_thread);
+  } else if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeInt8) {
+    auto input_quant_arg = in_tensors_.front()->GetQuantParams().front();
+    ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, num_unit_thread,
+                                 input_quant_arg.scale, input_quant_arg.zeroPoint);
+    if (ret) {
+      auto output_quant_arg = out_tensors_.front()->GetQuantParams().front();
+      ret = DoQuantizeToInt8FromFp32(float32_ptr_ + thread_offset, int8_out_ptr_ + thread_offset,
+                                     output_quant_arg.scale, output_quant_arg.zeroPoint, num_unit_thread);
+    }
     }
   }
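In the new int8 -> int8 branch each task works on its own slice: it dequantizes num_unit_thread elements starting at thread_offset from int8_ptr_ into the shared float32_ptr_ scratch buffer using the input tensor's quantization parameters, then quantizes that slice from the scratch buffer into int8_out_ptr_ with the output tensor's parameters. A loop-level sketch of one task's slice, reusing the hypothetical RequantizeInt8 helper above (the kernel instead calls the nnacl DoDequantizeInt8ToFp32 / DoQuantizeToInt8FromFp32 routines shown in the hunk):

    // Hypothetical single-pass equivalent of the dequantize + quantize pair for one task.
    void RequantizeSlice(const int8_t *in, int8_t *out, int thread_offset, int num_unit_thread,
                         float in_scale, int32_t in_zp, float out_scale, int32_t out_zp) {
      for (int i = thread_offset; i < thread_offset + num_unit_thread; ++i) {
        out[i] = RequantizeInt8(in[i], in_scale, in_zp, out_scale, out_zp);
      }
    }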

@@ -154,14 +162,26 @@ int QuantDTypeCastCPUKernel::Run() {
              out_tensors_[0]->data_type() == TypeId::kNumberTypeInt8) {
     uint8_ptr_ = reinterpret_cast<uint8_t *>(in_tensors_[0]->data_c());
     int8_ptr_ = reinterpret_cast<int8_t *>(out_tensors_[0]->data_c());
+  } else if (in_tensors_[0]->data_type() == TypeId::kNumberTypeInt8 &&
+             out_tensors_[0]->data_type() == TypeId::kNumberTypeInt8) {
+    int8_ptr_ = reinterpret_cast<int8_t *>(in_tensors_[0]->data_c());
+    int8_out_ptr_ = reinterpret_cast<int8_t *>(out_tensors_[0]->data_c());
+    float32_ptr_ = new float[in_tensors_[0]->ElementsNum()];
   }

   auto ret = ParallelLaunch(this->context_->thread_pool_, QuantDTypeCastRun, this, thread_n_num_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Scale error error_code[" << ret << "]";
+    if (in_tensors_[0]->data_type() == TypeId::kNumberTypeInt8 &&
+        out_tensors_[0]->data_type() == TypeId::kNumberTypeInt8) {
+      delete (float32_ptr_);
+    }
     return RET_ERROR;
   }

+  if (in_tensors_[0]->data_type() == TypeId::kNumberTypeInt8 &&
+      out_tensors_[0]->data_type() == TypeId::kNumberTypeInt8) {
+    delete (float32_ptr_);
+  }
   return RET_OK;
 }
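
Run() allocates the float scratch buffer only for the int8 -> int8 path and frees it on both the error return and the success return. A scoped alternative, sketched under the assumption that nothing outside Run() keeps the pointer alive, ties the buffer's lifetime to the function and avoids repeating the delete (sketch only; the patch manages the raw pointer manually):

    #include <memory>

    // ...inside QuantDTypeCastCPUKernel::Run():
    std::unique_ptr<float[]> scratch;
    if (in_tensors_[0]->data_type() == TypeId::kNumberTypeInt8 &&
        out_tensors_[0]->data_type() == TypeId::kNumberTypeInt8) {
      scratch = std::make_unique<float[]>(in_tensors_[0]->ElementsNum());
      float32_ptr_ = scratch.get();
    }
    // ... ParallelLaunch(...) ...
    // scratch releases the buffer automatically on every return path.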



mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.h  (+6, -3)

@@ -39,10 +39,13 @@ class QuantDTypeCastCPUKernel : public LiteKernel {
   int thread_n_num_;
   int thread_n_stride_;
   int num_unit_;
-  int8_t *int8_ptr_;
+  int8_t *int8_ptr_ = nullptr;
+  int8_t *int8_out_ptr_ = nullptr;
   uint8_t *uint8_ptr_ = nullptr;
-  float *float32_ptr_;
-  bool inverse_;
+  float *float32_ptr_ = nullptr;
+
+  int32_t src_dtype;
+  int32_t dst_dtype;
 };
 }  // namespace mindspore::kernel


