
support fp16 model

tags/v0.7.0-beta
cjh9368 · 5 years ago · commit f3b794bc26
15 changed files with 273 additions and 21 deletions
  1. mindspore/lite/schema/model.fbs (+1 -0)
  2. mindspore/lite/schema/ops.fbs (+5 -0)
  3. mindspore/lite/src/ops/cast.cc (+1 -1)
  4. mindspore/lite/src/ops/ops.h (+1 -1)
  5. mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc (+6 -0)
  6. mindspore/lite/src/runtime/kernel/arm/nnacl/common_func.c (+1 -0)
  7. mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/cast.c (+7 -0)
  8. mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/cast.h (+1 -0)
  9. mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.c (+121 -0)
  10. mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.h (+3 -0)
  11. mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h (+1 -15)
  12. mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc (+68 -0)
  13. mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h (+39 -0)
  14. mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc (+6 -4)
  15. mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h (+12 -0)

mindspore/lite/schema/model.fbs (+1 -0)

@@ -90,6 +90,7 @@ union PrimitiveType {
     Rsqrt,
     ExpandDims,
     Tile,
+    Fp16Cast,
     Cast,
     Shape,
     Nchw2Nhwc,


mindspore/lite/schema/ops.fbs (+5 -0)

@@ -576,6 +576,11 @@ table Cast {
     dstT: int;
 }
 
+table Fp16Cast {
+    srcT: int;
+    dstT: int;
+}
+
 table QuantDTypeCast {
     srcT: int;
     dstT: int;
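Note: Fp16Cast mirrors Cast field-for-field; the separate union tag exists only so fp16-input casts can get their own kernel registration (see the REG_KERNEL change below). With flatbuffers' standard C++ object-API codegen the new table would be populated roughly as follows — a sketch, where Fp16CastT is the generated native type and prim stands for whatever schema::PrimitiveT is being filled in:

    // hedged sketch of the object API flatc generates for table Fp16Cast
    auto attr = std::make_unique<schema::Fp16CastT>();
    attr->srcT = kNumberTypeFloat16;   // source element type
    attr->dstT = kNumberTypeFloat32;   // destination element type
    prim->value.type = schema::PrimitiveType_Fp16Cast;
    prim->value.value = attr.release();  // union payload, now owned by prim

In practice the new dequantize parser below stores a schema::CastT under this tag instead, which works because the two tables are layout-identical.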


mindspore/lite/src/ops/cast.cc (+1 -1)

@@ -46,7 +46,7 @@ int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
   }
   output->SetFormat(input->GetFormat());
   output->set_shape(input->shape());
-  output->set_data_type(input->data_type());
+  output->set_data_type(TypeId::kNumberTypeFloat32);
   return RET_OK;
 }
 } // namespace mindspore::lite

mindspore/lite/src/ops/ops.h (+1 -1)

@@ -37,7 +37,7 @@ constexpr uint32_t kNHWC_w_index = 2;
 constexpr uint32_t kNHWC_c_index = 3;
 constexpr uint32_t kDimension_4d = 4;
 
-const std::set<int> kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt32, kNumberTypeFloat32};
+const std::set<int> kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt32, kNumberTypeFloat32, kNumberTypeFloat16};
 
 class Primitive {
  public:


mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc (+6 -0)

@@ -27,6 +27,7 @@ using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_Cast;
+using mindspore::schema::PrimitiveType_Fp16Cast;
 
 namespace mindspore::kernel {
 namespace {
@@ -87,6 +88,10 @@ int CastCPUKernel::DoCast(int thread_id) {
       Int32ToFloat32(reinterpret_cast<int32_t *>(input->Data()) + offset,
                      reinterpret_cast<float *>(output_data) + offset, data_num);
       break;
+    case kNumberTypeFloat16:
+      Fp16ToFloat32(reinterpret_cast<int16_t *>(input->Data()) + offset,
+                    reinterpret_cast<float *>(output_data) + offset, data_num);
+      break;
     default:
       MS_LOG(ERROR) << "Unsupport input data type " << input_data_type;
       return RET_ERROR;
@@ -139,4 +144,5 @@ kernel::LiteKernel *CpuCastFp32KernelCreator(const std::vector<lite::tensor::Ten
 }
 
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Cast, CpuCastFp32KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Fp16Cast, CpuCastFp32KernelCreator)
 } // namespace mindspore::kernel
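Note: fp16 tensors travel as raw IEEE-754 bit patterns rather than a native half type, so the new switch case reinterprets the input buffer as int16_t and widens element-wise. Registering the same CpuCastFp32KernelCreator under (kNumberTypeFloat16, PrimitiveType_Fp16Cast) routes fp16-input casts to this fp32 kernel instead of requiring a dedicated fp16 implementation. The per-thread work of the new case amounts to (a sketch; offset and data_num are the slice owned by thread_id, as in the surrounding DoCast code):

    auto *in_fp16 = reinterpret_cast<int16_t *>(input->Data());    // raw fp16 bit patterns
    auto *out_fp32 = reinterpret_cast<float *>(output_data);
    Fp16ToFloat32(in_fp16 + offset, out_fp32 + offset, data_num);  // widen this thread's slice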

mindspore/lite/src/runtime/kernel/arm/nnacl/common_func.c (+1 -0)

@@ -265,3 +265,4 @@ void SimplePostFuncInt8(const int *in, int8_t *out, int oc, int plane, int plane
     }
   }
 }
+

mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/cast.c (+7 -0)

@@ -15,6 +15,7 @@
  */
 
 #include "nnacl/fp32/cast.h"
+#include "nnacl/fp32/common_func.h"
 
 void Uint8ToFloat32(const uint8_t *input, float *output, int number) {
   for (int i = 0; i < number; ++i) {
@@ -40,6 +41,12 @@ void Int32ToFloat32(const int32_t *input, float *output, int number) {
   }
 }
 
+void Fp16ToFloat32(const int16_t *input, float *output, int number) {
+  for (int i = 0; i < number; ++i) {
+    output[i] = ShortToFloat32(input[i]);
+  }
+}
+
 void Float32ToInt32(const float *input, int32_t *output, int number) {
   for (int i = 0; i < number; ++i) {
     output[i] = (int32_t)input[i];
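Because Fp16ToFloat32 takes its input as raw bit patterns, it can be exercised with hand-encoded half-precision constants. A self-contained sketch (assuming the nnacl sources above are compiled in; the prototype matches the declaration added to cast.h below, which sits inside an extern "C" block):

    #include <cstdint>
    #include <cstdio>

    extern "C" void Fp16ToFloat32(const int16_t *input, float *output, int number);

    int main() {
      // IEEE-754 half bit patterns: 1.0f = 0x3C00, 2.0f = 0x4000, -0.5f = 0xB800
      int16_t fp16_bits[] = {0x3C00, 0x4000, static_cast<int16_t>(0xB800)};
      float out[3];
      Fp16ToFloat32(fp16_bits, out, 3);
      for (float v : out) printf("%g\n", v);  // expect: 1 2 -0.5
      return 0;
    }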


mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/cast.h (+1 -0)

@@ -35,6 +35,7 @@ void Uint8ToFloat32(const uint8_t *input, float *output, int number);
 void Uint8ToInt8(const uint8_t *input, int8_t *output, int number);
 void Int8ToUint8(const int8_t *input, uint8_t *output, int number);
 void Int32ToFloat32(const int32_t *input, float *output, int number);
+void Fp16ToFloat32(const int16_t *input, float *output, int number);
 void Float32ToInt32(const float *input, int32_t *output, int number);
 #ifdef __cplusplus
 }


mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.c (+121 -0)

@@ -113,3 +113,124 @@ void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bias_ptr,
   PostConvFuncComm(c8_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, stride, is_relu, is_relu6, C8NUM);
   return;
 }
+
+static const unsigned int FP32_BIT_SIZE = 32;
+static const unsigned int FP32_EXPONENT_BIAS = 127;
+static const unsigned int FP32_SIGNIFICAND = 23;
+
+static const unsigned int FP32_EXPONENT_MAX = 255;
+
+static const unsigned int FP16_BIT_SIZE = 16;
+static const unsigned int FP16_EXPONENT_BIAS = 15;
+static const unsigned int FP16_SIGNIFICAND = 10;
+
+static const int FP16_EXPONENT_MAX = 30;
+static const int FP16_EXPONENT_MIN = -10;
+
+float ShortToFloat32(int16_t srcValue) {
+  uint16_t expHalf16 = srcValue & 0x7C00;
+  int exp1 = (int)(expHalf16);
+  uint16_t mantissa16 = srcValue & 0x03FF;
+  int mantissa1 = (int)(mantissa16);
+  int sign = (int)(srcValue & 0x8000);
+  sign = sign << FP16_BIT_SIZE;
+
+  // nan or inf
+  if (expHalf16 == 0x7C00) {
+    // nan
+    if (mantissa16 > 0) {
+      int res = (0x7FC00000 | sign);
+      int *iRes = &res;
+      float fres = *((float *)iRes);
+      return fres;
+    }
+    // inf
+    int res = (0x7F800000 | sign);
+    int *iRes = &res;
+    float fres = *((float *)iRes);
+    return fres;
+  }
+  if (expHalf16 != 0) {
+    exp1 += ((FP32_EXPONENT_BIAS - FP16_EXPONENT_BIAS) << FP16_SIGNIFICAND);  // exponent converted to float32 bias
+    int res = (exp1 | mantissa1);
+    res = res << (FP32_SIGNIFICAND - FP16_SIGNIFICAND);
+    res = (res | sign);
+    int *iRes = &res;
+    float fres = *((float *)iRes);
+    return fres;
+  }
+
+  int xmm1 = exp1 > (1 << FP16_SIGNIFICAND) ? exp1 : (1 << FP16_SIGNIFICAND);
+  xmm1 = (xmm1 << (FP32_SIGNIFICAND - FP16_SIGNIFICAND));
+  xmm1 += ((FP32_EXPONENT_BIAS - FP16_EXPONENT_BIAS - FP16_SIGNIFICAND)
+           << FP32_SIGNIFICAND);  // add the bias difference to xmm1
+  xmm1 = xmm1 | sign;             // combine with the sign mask
+
+  float res = (float)(mantissa1);  // convert mantissa to float
+  int *ixmm1 = NULL;
+  ixmm1 = &xmm1;
+  res *= *((float *)ixmm1);
+
+  return res;
+}
+
+// __gnu_f2h_ieee
+int16_t Float32ToShort(float srcValue) {
+  float *psrcValue = NULL;
+  psrcValue = &srcValue;
+  unsigned int srcValueBit = *((unsigned int *)psrcValue);
+  int sign = srcValueBit >> (FP32_BIT_SIZE - 1);
+  int mantissa = srcValueBit & 0x007FFFFF;
+  // exponent
+  int exp = ((srcValueBit & 0x7F800000) >> FP32_SIGNIFICAND) + FP16_EXPONENT_BIAS - FP32_EXPONENT_BIAS;
+  int16_t res;
+  if (exp > 0 && exp < FP16_EXPONENT_MAX) {
+    // use rte rounding mode: round the significand, combine sign, exponent and significand into a short
+    res = (sign << (FP16_BIT_SIZE - 1)) | (exp << FP16_SIGNIFICAND) |
+          ((mantissa + 0x00001000) >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND));
+  } else if (srcValueBit == 0) {
+    res = 0;
+  } else {
+    if (exp <= 0) {
+      if (exp < FP16_EXPONENT_MIN) {
+        // value is less than the min half float point
+        res = 0;
+      } else {
+        // normalized single, magnitude is less than the min normal half float point
+        mantissa = (mantissa | 0x00800000) >> (1 - exp);
+        // round to nearest
+        if ((mantissa & 0x00001000) > 0) {
+          mantissa = mantissa + 0x00002000;
+        }
+        // combine sign & mantissa (exp is zero to get a denormalized number)
+        res = (sign << FP16_EXPONENT_BIAS) | (mantissa >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND));
+      }
+    } else if (exp == (FP32_EXPONENT_MAX - FP32_EXPONENT_BIAS + FP16_EXPONENT_BIAS)) {
+      if (mantissa == 0) {
+        // input float is infinity, return infinity half
+        res = (sign << FP16_EXPONENT_BIAS) | 0x7C00;
+      } else {
+        // input float is NaN, return half NaN
+        res = (sign << FP16_EXPONENT_BIAS) | 0x7C00 | (mantissa >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND));
+      }
+    } else {
+      // exp > 0, normalized single, round to nearest
+      if ((mantissa & 0x00001000) > 0) {
+        mantissa = mantissa + 0x00002000;
+        if ((mantissa & 0x00800000) > 0) {
+          mantissa = 0;
+          exp = exp + 1;
+        }
+      }
+      if (exp > FP16_EXPONENT_MAX) {
+        // exponent overflow - return infinity half
+        res = (sign << FP16_EXPONENT_BIAS) | 0x7C00;
+      } else {
+        // combine sign, exp and mantissa into a normalized half
+        res = (sign << FP16_EXPONENT_BIAS) | (exp << FP16_SIGNIFICAND) |
+              (mantissa >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND));
+      }
+    }
+  }
+  return res;
+}
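These two routines can be sanity-checked by round-tripping: any value exactly representable in half precision should survive Float32ToShort followed by ShortToFloat32 unchanged. A hedged sketch (prototypes as declared in common_func.h in the next hunk, assuming the header's usual extern "C" guards):

    #include <cstdint>
    #include <cstdio>

    extern "C" int16_t Float32ToShort(float srcValue);
    extern "C" float ShortToFloat32(int16_t srcValue);

    int main() {
      // 65504 is the largest normal fp16 value; 6.1035156e-05f (2^-14) is the smallest
      const float samples[] = {0.0f, 1.0f, -2.5f, 65504.0f, 6.1035156e-05f};
      for (float v : samples) {
        int16_t h = Float32ToShort(v);
        float back = ShortToFloat32(h);
        printf("%g -> 0x%04X -> %g\n", v, (unsigned)(uint16_t)h, back);  // expect back == v
      }
      return 0;
    }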

mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.h (+3 -0)

@@ -37,6 +37,9 @@ void MatrixSub(const float *a_ptr, const float *b_ptr, float *dst, size_t a_stri
                size_t row, size_t col);
 void MatrixMultiAdd(float *c11, float *c12, float *c21, float *c22, float *x_ptr, size_t row, size_t col,
                     size_t c_stride, size_t x_stride);
+int16_t Float32ToShort(float srcValue);
+
+float ShortToFloat32(int16_t srcValue);
 
 #ifdef ENABLE_ARM
 void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,


mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h (+1 -15)

@@ -14,12 +14,11 @@
  * limitations under the License.
  */
 
-#ifndef LITE_TFLITE_CAST_PARSER_H
+#ifndef LITE_TFLITE_CAST_PARSER_
 #define LITE_TFLITE_CAST_PARSER_H
 
 #include <memory>
 #include <vector>
-#include <map>
 #include "tools/converter/parser/tflite/tflite_node_parser.h"
 #include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
 
@@ -35,19 +34,6 @@ class TfliteCastParser : public TfliteNodeParser {
                const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
                TensorCache *tensor_cache,
                bool quantized_model) override;
-
- private:
-  std::map<int, TypeId> dtype_map = {
-      {tflite::TensorType_FLOAT64, TypeId::kNumberTypeFloat64},
-      {tflite::TensorType_FLOAT32, TypeId::kNumberTypeFloat32},
-      {tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16},
-      {tflite::TensorType_INT64, TypeId::kNumberTypeInt64},
-      {tflite::TensorType_INT32, TypeId::kNumberTypeInt32},
-      {tflite::TensorType_INT16, TypeId::kNumberTypeInt16},
-      {tflite::TensorType_INT8, TypeId::kNumberTypeInt8},
-      {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8},
-      {tflite::TensorType_BOOL, TypeId::kNumberTypeBool},
-  };
 };
 } // namespace lite
 } // namespace mindspore


mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc (+68 -0)

@@ -0,0 +1,68 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "tools/converter/parser/tflite/tflite_dequantize_parser.h"
+#include <vector>
+#include <memory>
+#include "tools/common/node_util.h"
+
+namespace mindspore {
+namespace lite {
+STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
+                                     const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
+                                     const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
+                                     const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
+                                     schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
+  if (op == nullptr) {
+    MS_LOG(ERROR) << "op is null";
+    return RET_NULL_PTR;
+  }
+  op->primitive = std::make_unique<schema::PrimitiveT>();
+  if (op->primitive == nullptr) {
+    MS_LOG(ERROR) << "op->primitive is null";
+    return RET_NULL_PTR;
+  }
+
+  MS_LOG(DEBUG) << "parse TfliteDequantizeNParser";
+  std::unique_ptr<schema::CastT> attr(new schema::CastT);
+
+  // get the dequantize input tensor
+  const auto &in_tensor = tfliteTensors[tfliteOp->inputs[0]];
+  if (in_tensor == nullptr) {
+    MS_LOG(ERROR) << "weight_tensor is null";
+    return RET_NULL_PTR;
+  }
+  attr->srcT = dtype_map[in_tensor->type];
+
+  const auto &out_tensor = tfliteTensors[tfliteOp->outputs[0]];
+  if (out_tensor == nullptr) {
+    MS_LOG(ERROR) << "tensor is null";
+    return RET_NULL_PTR;
+  }
+  attr->dstT = dtype_map[out_tensor->type];
+  std::vector<tflite::TensorT *> weight_tensors{in_tensor.get()};
+  if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
+    MS_LOG(ERROR) << "parse weight failed";
+    return RET_ERROR;
+  }
+
+  op->primitive->value.type = schema::PrimitiveType_Fp16Cast;
+  op->primitive->value.value = attr.release();
+  return 0;
+}
+
+TfliteNodeRegister g_tfliteDequantizeParser("DEQUANTIZE", new TfliteDequantizeParser());
+} // namespace lite
+} // namespace mindspore
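Note: in TFLite's float16 weight scheme, weights are stored as fp16 constants fed through a DEQUANTIZE op that yields fp32. This parser folds that pattern into a single Fp16Cast node and registers the fp16 constant as a weight via ParseTensor(..., TF_CONST, true), leaving the actual widening to the cast kernel above. Schematically:

    TFLite graph:    [fp16 const] --DEQUANTIZE--> [fp32 tensor]
    converted to:    [fp16 const] --Fp16Cast(srcT=kNumberTypeFloat16, dstT=kNumberTypeFloat32)--> [fp32 tensor]

The attribute object is a schema::CastT stored under the PrimitiveType_Fp16Cast tag, which relies on Fp16Cast and Cast being layout-identical (see the ops.fbs hunk above); the final return 0 is equivalent to RET_OK.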

mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h (+39 -0)

@@ -0,0 +1,39 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef LITE_TFLITE_DEQUANTIZE_PARSER_H
+#define LITE_TFLITE_DEQUANTIZE_PARSER_H
+
+#include <vector>
+#include <memory>
+#include "tools/converter/parser/tflite/tflite_node_parser.h"
+#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
+
+namespace mindspore {
+namespace lite {
+class TfliteDequantizeParser : public TfliteNodeParser {
+ public:
+  TfliteDequantizeParser() : TfliteNodeParser("Dequantize") {}
+
+  STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
+               const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
+               const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
+               const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
+               TensorCache *tensor_cache, bool quantizedModel) override;
+};
+} // namespace lite
+} // namespace mindspore
+
+#endif // LITE_TFLITE_DEQUANTIZE_PARSER_H

mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc (+6 -4)

@@ -208,10 +208,12 @@ STATUS TfliteModelParser::ParseOp(const std::unique_ptr<tflite::ModelT> &tflite_model,
       return RET_ERROR;
     }
 
-    status = ParseTfliteQuantParams(tflite_subgraph, tflite_op, op.get(), tensorCache);
-    if (status != RET_OK) {
-      MS_LOG(ERROR) << "parse op " << opType.c_str() << " quant parameters failed";
-      return RET_ERROR;
+    if (quantType != schema::QuantType_QUANT_NONE) {
+      status = ParseTfliteQuantParams(tflite_subgraph, tflite_op, op.get(), tensorCache);
+      if (status != RET_OK) {
+        MS_LOG(ERROR) << "parse op " << opType.c_str() << " quant parameters failed";
+        return RET_ERROR;
+      }
     }
 
     subGraph->nodes.emplace_back(std::move(op));

mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h (+12 -0)

@@ -19,6 +19,7 @@
 
 #include <string>
 #include <vector>
+#include <map>
 #include <memory>
 #include "utils/log_adapter.h"
 #include "schema/inner/model_generated.h"
@@ -126,6 +127,17 @@ class TfliteNodeParser {
 
  protected:
   const std::string &name;
+  std::map<int, TypeId> dtype_map = {
+      {tflite::TensorType_FLOAT64, TypeId::kNumberTypeFloat64},
+      {tflite::TensorType_FLOAT32, TypeId::kNumberTypeFloat32},
+      {tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16},
+      {tflite::TensorType_INT64, TypeId::kNumberTypeInt64},
+      {tflite::TensorType_INT32, TypeId::kNumberTypeInt32},
+      {tflite::TensorType_INT16, TypeId::kNumberTypeInt16},
+      {tflite::TensorType_INT8, TypeId::kNumberTypeInt8},
+      {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8},
+      {tflite::TensorType_BOOL, TypeId::kNumberTypeBool},
+  };
 };
 } // namespace lite
 } // namespace mindspore
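Hoisting dtype_map out of TfliteCastParser into the TfliteNodeParser base class makes the TFLite-to-TypeId mapping available to every parser; the new dequantize parser depends on it:

    // inside any TfliteNodeParser subclass (cf. tflite_dequantize_parser.cc above)
    attr->srcT = dtype_map[in_tensor->type];   // tflite::TensorType -> TypeId
    attr->dstT = dtype_map[out_tensor->type];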

