Browse Source

Do not insert a dtype_cast op for int32 tensors

tags/v1.0.0
cjh9368 5 years ago
parent
commit
c5db8e0a32
6 changed files with 19 additions and 18 deletions
  1. +2
    -2
      mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc
  2. +1
    -1
      mindspore/lite/tools/common/node_util.cc
  3. +1
    -1
      mindspore/lite/tools/common/node_util.h
  4. +12
    -8
      mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc
  5. +1
    -4
      mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc
  6. +2
    -2
      mindspore/lite/tools/converter/quantizer/aware_quantizer.cc

+ 2
- 2
mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc View File

@@ -63,7 +63,7 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T
auto *weight_tensor = inputs.at(kWeightIndex);
// data of second tensor of fc may be nullptr
auto *restore_data = weight_tensor->data_c();
if (!weight_tensor->GetQuantParams().empty()) {
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
@@ -91,7 +91,7 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T
}
return nullptr;
}
if (!weight_tensor->GetQuantParams().empty()) {
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}


+ 1
- 1
mindspore/lite/tools/common/node_util.cc View File

@@ -93,7 +93,7 @@ std::vector<schema::PrimitiveType> GetNhwcAllInputOpList() { return nhwcOpAllInp

std::vector<schema::PrimitiveType> GetUint8NhwcOpList() { return int8NeedNhwcOpList; }

std::vector<schema::PrimitiveType> GetUint8OpList() { return int8OpList; }
std::vector<schema::PrimitiveType> GetInt8OpList() { return int8OpList; }

STATUS NodeUtils::ConvertDims(mindspore::schema::Format src_format, const std::vector<int32_t> &src_dims,
mindspore::schema::Format dst_format, std::vector<int32_t> *dst_dims) {


+ 1
- 1
mindspore/lite/tools/common/node_util.h View File

@@ -42,7 +42,7 @@ std::vector<schema::PrimitiveType> Getfp32FullOpList();

std::vector<schema::PrimitiveType> GetUint8NhwcOpList();

std::vector<schema::PrimitiveType> GetUint8OpList();
std::vector<schema::PrimitiveType> GetInt8OpList();

class NodeUtils {
public:


+ 12
- 8
mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc View File

@@ -51,13 +51,7 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) {

STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
// modify inputTensor first
auto &graphInIdxes = graph->inputIndex;
for (auto graphInIdx : graphInIdxes) {
MS_ASSERT(graph->allTensors.size() > graphInIdx);
auto &graphInTensor = graph->allTensors.at(graphInIdx);
graphInTensor->dataType = TypeId::kNumberTypeInt8;
}

if (this->inputDataDType == TypeId::kNumberTypeInt8) {
return RET_OK;
@@ -70,7 +64,7 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
for (auto graphInIdx : graphInIdxes) {
MS_ASSERT(graphInIdx < graph->allTensors.size());
auto &tensor = graph->allTensors.at(graphInIdx);
if (tensor->dims.size() != kNHWCDimNumber) {
if (tensor->dims.size() != kNHWCDimNumber || tensor->dataType != kNumberTypeInt8) {
continue;
}

@@ -137,7 +131,7 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
// insert transNode before and after existNode
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
if (IsContain(GetUint8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) {
if (IsContain(GetInt8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) {
continue;
}
if (GetCNodeTType(**iter) == PrimitiveType_QuantDTypeCast) {
@@ -157,10 +151,16 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
for (size_t i = 0; i < (*iter)->inputIndex.size(); i++) {
MS_ASSERT(graph->allTensors.size() > (*iter)->inputIndex.at(i));
auto &preTensor = graph->allTensors.at((*iter)->inputIndex.at(i));
if (preTensor->dataType == TypeId::kNumberTypeInt || preTensor->dataType == TypeId::kNumberTypeInt32) {
continue;
}
auto &graphInIdxes = graph->inputIndex;
if (!preTensor->data.empty() && !IsContain(graphInIdxes, (*iter)->inputIndex.at(i))) {
continue;
}
if (IsContain(graphInIdxes, (*iter)->inputIndex.at(i))) {
continue;
}
iter = InsertDTypeTransNode(graph, iter, kBefore, i, kInt8ToFP32, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << nodeName.c_str() << " failed";
@@ -170,6 +170,10 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {

if (needInsertPost) {
for (size_t i = 0; i < (*iter)->outputIndex.size(); i++) {
auto &postTensor = graph->allTensors.at((*iter)->outputIndex.at(i));
if (postTensor->dataType == TypeId::kNumberTypeInt || postTensor->dataType == TypeId::kNumberTypeInt32) {
continue;
}
iter = InsertDTypeTransNode(graph, iter, kAfter, i, kFP32ToInt8, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertFloat32ToUint8Node after " << nodeName.c_str() << " failed";


+ 1
- 4
mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc View File

@@ -79,6 +79,7 @@ void TfliteModelParser::SetTensorQuantParam(const std::unique_ptr<tflite::Tensor
// change quant param min to 0 to fit ms-lite ops
if (GetTfliteDataType(tflite_tensor->type) == TypeId::kNumberTypeUInt8 && tensor->data.empty()) {
quant_param->zeroPoint = quant_param->zeroPoint - 128;
tensor->dataType = TypeId::kNumberTypeInt8;
}

if (!tflite_tensor->quantization->min.empty()) {
@@ -164,11 +165,7 @@ STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT>
MS_LOG(ERROR) << "obtain const tensor failed";
return status;
}
} else if (quantType == QuantType_AwareTraining && tensor->dataType == TypeId::kNumberTypeUInt8) {
// set in/out tensor to int8 to fit ms-lite op
tensor->dataType = TypeId::kNumberTypeInt8;
}

// set tensor attr
if (isInput || isConst) {
tensor->nodeType = schema::NodeType::NodeType_ValueNode;


+ 2
- 2
mindspore/lite/tools/converter/quantizer/aware_quantizer.cc View File

@@ -145,7 +145,7 @@ STATUS AwareQuantizer::GenerateQuantParam() {
STATUS AwareQuantizer::DoQuantize() {
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
auto &node = *iter;
if (!IsContain(GetUint8OpList(), GetCNodeTType(*node))) {
if (!IsContain(GetInt8OpList(), GetCNodeTType(*node))) {
continue;
}
if (node->quantType != schema::QuantType_AwareTraining) {
@@ -388,7 +388,7 @@ STATUS AwareQuantizer::DetermineNodeQuantType() {
}
}

if (canQuant && IsContain(GetUint8OpList(), GetCNodeTType(*node))) {
if (canQuant && IsContain(GetInt8OpList(), GetCNodeTType(*node))) {
node->quantType = schema::QuantType_AwareTraining;
} else {
node->quantType = schema::QuantType_QUANT_NONE;


Loading…
Cancel
Save