Browse Source

add boolTofloat func for cast operator

tags/v1.1.0
xuanyue 5 years ago
parent
commit
dec4785b65
4 changed files with 10 additions and 3 deletions
  1. +6
    -0
      mindspore/lite/nnacl/fp32/cast.c
  2. +1
    -0
      mindspore/lite/nnacl/fp32/cast.h
  3. +3
    -0
      mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc
  4. +0
    -3
      mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc

+ 6
- 0
mindspore/lite/nnacl/fp32/cast.c View File

@@ -17,6 +17,12 @@
#include "nnacl/fp32/cast.h"
#include "nnacl/fp32/common_func.h"

void BoolToFloat32(const bool *input, float *output, int number) {
for (int i = 0; i < number; ++i) {
output[i] = (float)input[i];
}
}

void Uint8ToFloat32(const uint8_t *input, float *output, int number) {
for (int i = 0; i < number; ++i) {
output[i] = (float)input[i];


+ 1
- 0
mindspore/lite/nnacl/fp32/cast.h View File

@@ -31,6 +31,7 @@ typedef struct CastParameter {
#ifdef __cplusplus
extern "C" {
#endif
void BoolToFloat32(const bool *input, float *output, int number);
void Uint8ToFloat32(const uint8_t *input, float *output, int number);
void Uint8ToInt8(const uint8_t *input, int8_t *output, int number);
void Int8ToUint8(const int8_t *input, uint8_t *output, int number);


+ 3
- 0
mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc View File

@@ -82,6 +82,9 @@ int CastCPUKernel::DoCast(int thread_id) {
}
} else {
switch (input_data_type) {
case kNumberTypeBool:
BoolToFloat32(reinterpret_cast<bool *>(input->MutableData()) + offset,
reinterpret_cast<float *>(output_data) + offset, data_num);
case kNumberTypeUInt8:
Uint8ToFloat32(reinterpret_cast<uint8_t *>(input->MutableData()) + offset,
reinterpret_cast<float *>(output_data) + offset, data_num);


+ 0
- 3
mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc View File

@@ -46,9 +46,6 @@ STATUS TfliteCastParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
return RET_NULL_PTR;
}
attr->srcT = GetTfliteDataType(in_tensor->type);
if (attr->srcT == TypeId::kNumberTypeBool) {
attr->srcT = TypeId::kNumberTypeUInt8;
}
const auto &out_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->outputs[0]];
if (out_tensor == nullptr) {
MS_LOG(ERROR) << "tensor is null";


Loading…
Cancel
Save