Browse Source

!4334 [MS][LITE][Develop] Cast: support casting data from float to int32

Merge pull request !4334 from chenjianping/lite_dev
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
363fbb7a79
7 changed files with 36 additions and 9 deletions
  1. +1
    -1
      mindspore/lite/src/ops/ops.h
  2. +20
    -5
      mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc
  3. +3
    -2
      mindspore/lite/src/runtime/kernel/arm/fp32/crop.cc
  4. +4
    -0
      mindspore/lite/src/runtime/kernel/arm/fp32/softmax.cc
  5. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/fp32/softmax.h
  6. +6
    -0
      mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/cast.cc
  7. +1
    -0
      mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/cast.h

+ 1
- 1
mindspore/lite/src/ops/ops.h View File

@@ -37,7 +37,7 @@ constexpr uint32_t kNHWC_w_index = 2;
constexpr uint32_t kNHWC_c_index = 3;
constexpr uint32_t kDimension_4d = 4;

const std::set<int> kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt32};
const std::set<int> kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt32, kNumberTypeFloat32};

class Primitive {
public:


+ 20
- 5
mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc View File

@@ -65,17 +65,32 @@ int CastCPUKernel::DoCast(int thread_id) {
}

auto offset = thread_id * stride_;
auto output_data = reinterpret_cast<float *>(out_tensors_.at(0)->Data());
switch (input->data_type()) {
auto output = out_tensors_.at(0);
auto output_data = output->Data();
auto input_data_type = input->data_type();
auto output_data_type = output->data_type();
if (output_data_type != kNumberTypeFloat32) {
if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt32) {
Float32ToInt32(reinterpret_cast<float *>(input->Data()) + offset,
reinterpret_cast<int32_t *>(output_data) + offset, data_num);
} else {
MS_LOG(ERROR) << "Unsupport datatype from " << input_data_type << " to " << output_data_type;
return RET_ERROR;
}
} else {
switch (input_data_type) {
case kNumberTypeUInt8:
Uint8ToFloat32(reinterpret_cast<uint8_t *>(input->Data()) + offset, output_data + offset, data_num);
Uint8ToFloat32(reinterpret_cast<uint8_t *>(input->Data()) + offset,
reinterpret_cast<float *>(output_data) + offset, data_num);
break;
case kNumberTypeInt32:
Int32ToFloat32(reinterpret_cast<int32_t *>(input->Data()) + offset, output_data + offset, data_num);
Int32ToFloat32(reinterpret_cast<int32_t *>(input->Data()) + offset,
reinterpret_cast<float *>(output_data) + offset, data_num);
break;
default:
MS_LOG(ERROR) << "Unsupport input data type " << input->data_type();
MS_LOG(ERROR) << "Unsupport input data type " << input_data_type;
return RET_ERROR;
}
}
return RET_OK;
}


+ 3
- 2
mindspore/lite/src/runtime/kernel/arm/fp32/crop.cc View File

@@ -47,8 +47,9 @@ int CropCPUKernel::CropParallelRun(int thread_id) {
auto output = out_tensors_[0];
float *input_data = reinterpret_cast<float *>(input->Data());
float *output_data = reinterpret_cast<float *>(output->Data());
Crop4D(input_data, output_data, input->shape().data(), output->shape().data(),
reinterpret_cast<CropParameter *>(op_parameter_));
auto param = reinterpret_cast<CropParameter *>(op_parameter_);
param->thread_id_ = thread_id;
Crop4D(input_data, output_data, input->shape().data(), output->shape().data(), param);
return RET_OK;
}



+ 4
- 0
mindspore/lite/src/runtime/kernel/arm/fp32/softmax.cc View File

@@ -65,6 +65,10 @@ int SoftmaxCPUKernel::ReSize() {
free(sum_data_);
}
sum_data_ = reinterpret_cast<float *>(malloc(out_plane_size * in_plane_size * sizeof(float)));
if (sum_data_ == nullptr) {
MS_LOG(ERROR) << "malloc data for softmax fail!";
return RET_ERROR;
}
memset(sum_data_, 0, out_plane_size * in_plane_size * sizeof(float));
return RET_OK;
}


+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/fp32/softmax.h View File

@@ -27,7 +27,7 @@ class SoftmaxCPUKernel : public SoftmaxBaseCPUKernel {
SoftmaxCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const lite::Primitive *primitive)
: SoftmaxBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
: SoftmaxBaseCPUKernel(parameter, inputs, outputs, ctx, primitive), sum_data_(nullptr) {}
~SoftmaxCPUKernel() override {
if (sum_data_ != nullptr) {
free(sum_data_);


+ 6
- 0
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/cast.cc View File

@@ -40,6 +40,12 @@ void Int32ToFloat32(const int32_t *input, float *output, int number) {
}
}

// Element-wise cast of `number` floats to int32, truncating toward zero
// (standard C++ floating-integral conversion). Values outside the int32
// range are undefined behavior per the language rules; callers must ensure
// inputs are representable. `input` and `output` must not overlap.
void Float32ToInt32(const float *input, int32_t *output, int number) {
  for (int i = 0; i < number; ++i) {
    // static_cast instead of a C-style cast: same conversion, greppable intent.
    output[i] = static_cast<int32_t>(input[i]);
  }
}

#ifdef ENABLE_FP16
void Float32ToFloat16(const float *input, float16_t *output, int number) {
for (int i = 0; i < number; ++i) {


+ 1
- 0
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/cast.h View File

@@ -32,6 +32,7 @@ void Uint8ToFloat32(const uint8_t *input, float *output, int number);
void Uint8ToInt8(const uint8_t *input, int8_t *output, int number);
void Int8ToUint8(const int8_t *input, uint8_t *output, int number);
void Int32ToFloat32(const int32_t *input, float *output, int number);
void Float32ToInt32(const float *input, int32_t *output, int number);
#ifdef ENABLE_FP16
void Float32ToFloat16(const float *input, float16_t *output, int number);
void Float16ToFloat32(const float16_t *input, float *output, int number);


Loading…
Cancel
Save