
bug fix: correct uint8 model conversion and a kernel-registry leak

Free the OpParameter when no kernel creator is registered; record each
tensor's original data type during the uint8-to-int8 rewrite so
DTypeTransPass can insert the right casts; derive TFLite
quantize/dequantize src/dst types from the actual tensors instead of
hard-coding int8.

tags/v1.1.0
cjh9368 · 5 years ago
parent · commit c688a25647
8 changed files with 79 additions and 38 deletions
  1. mindspore/lite/src/kernel_registry.cc (+2 -0)
  2. mindspore/lite/tools/converter/converter_context.h (+22 -0)
  3. mindspore/lite/tools/converter/graphdef_transform.cc (+5 -3)
  4. mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc (+13 -16)
  5. mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc (+35 -10)
  6. mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc (+1 -1)
  7. mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc (+0 -7)
  8. mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc (+1 -1)

mindspore/lite/src/kernel_registry.cc (+2 -0)

@@ -120,6 +120,8 @@ kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector<Tensor *> &in_te
       kernel->set_desc(key);
     }
     return kernel;
+  } else {
+    free(parameter);
   }
   return nullptr;
 }
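Why the new else branch is needed: the OpParameter handed to the kernel creator is a malloc-allocated C struct, and if no creator is registered for the key, nothing ever takes ownership of it, so it leaks. A self-contained sketch of the pattern, with simplified stand-in types rather than the real KernelRegistry signatures:

#include <cstdlib>

struct OpParameter {  // stand-in; the real one is a malloc-allocated C struct
  int type_;
};

using KernelCreator = void *(*)(OpParameter *parameter);

// Stand-in creator: on success the returned "kernel" takes ownership.
void *PassthroughKernel(OpParameter *parameter) { return parameter; }

// Stand-in registry lookup: may return nullptr if nothing is registered.
KernelCreator GetCreator(bool registered) { return registered ? PassthroughKernel : nullptr; }

void *GetKernel(bool registered) {
  auto *parameter = static_cast<OpParameter *>(malloc(sizeof(OpParameter)));
  if (parameter == nullptr) {
    return nullptr;
  }
  auto creator = GetCreator(registered);
  if (creator != nullptr) {
    return creator(parameter);  // ownership moves into the created kernel
  }
  free(parameter);  // the fix: no creator registered, nothing took ownership
  return nullptr;
}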


mindspore/lite/tools/converter/converter_context.h (+22 -0)

@@ -19,8 +19,10 @@

 #include <string>
 #include <set>
+#include <map>
 #include "include/errorcode.h"
 #include "src/common/log_adapter.h"
+#include "ir/dtype/type_id.h"

 namespace mindspore {
 namespace lite {
@@ -68,6 +70,26 @@ class NoSupportOp {
   std::set<std::string> noSupportOps;
   std::string fmkType;
 };
+
+class TensorDataType {
+ public:
+  ~TensorDataType() = default;
+  static TensorDataType *GetInstance() {
+    static TensorDataType tensorDataType;
+    return &tensorDataType;
+  }
+  void UpdateTensorType(int32_t index, int32_t type) { tensorDataTypeMap[index] = type; }
+  int32_t GetTensorType(int32_t index) const {
+    if (tensorDataTypeMap.find(index) == tensorDataTypeMap.end()) {
+      return TypeId::kTypeUnknown;
+    }
+    return tensorDataTypeMap.at(index);
+  }
+
+ private:
+  TensorDataType() {}
+  std::map<int32_t, int32_t> tensorDataTypeMap;
+};
 }  // namespace lite
 }  // namespace mindspore
 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_RETURN_CODE_H
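How the new singleton is used by the rest of this commit: tensor_quant_pass.cc records a tensor's original data type when it rewrites uint8 weights to int8, and dtype_trans_pass.cc reads it back when no explicit input/output dtype was configured. A compilable sketch; the TypeId stub below uses placeholder values (the real enum lives in ir/dtype/type_id.h):

#include <cstdint>
#include <iostream>
#include <map>

// Placeholder TypeId stub; real values live in ir/dtype/type_id.h.
enum TypeId : int32_t { kTypeUnknown = 0, kNumberTypeUInt8 = 1 };

// The class added in converter_context.h, verbatim apart from the stub enum.
class TensorDataType {
 public:
  ~TensorDataType() = default;
  static TensorDataType *GetInstance() {
    static TensorDataType tensorDataType;
    return &tensorDataType;
  }
  void UpdateTensorType(int32_t index, int32_t type) { tensorDataTypeMap[index] = type; }
  int32_t GetTensorType(int32_t index) const {
    if (tensorDataTypeMap.find(index) == tensorDataTypeMap.end()) {
      return TypeId::kTypeUnknown;
    }
    return tensorDataTypeMap.at(index);
  }

 private:
  TensorDataType() {}
  std::map<int32_t, int32_t> tensorDataTypeMap;
};

int main() {
  // TensorQuantPass: remember that tensor 3 was uint8 before the int8 rewrite.
  TensorDataType::GetInstance()->UpdateTensorType(3, TypeId::kNumberTypeUInt8);
  // DTypeTransPass: query it when inputDataDType/outputDataDType is kTypeUnknown.
  std::cout << TensorDataType::GetInstance()->GetTensorType(3) << "\n";  // 1 (uint8)
  std::cout << TensorDataType::GetInstance()->GetTensorType(7) << "\n";  // 0 (unknown)
  return 0;
}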

mindspore/lite/tools/converter/graphdef_transform.cc (+5 -3)

@@ -132,9 +132,11 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {

   // do quantization
   {
-    Optimizer fusionOptimizer;
-    fusionOptimizer.AddPass(new (std::nothrow) TensorQuantPass());
-    status = fusionOptimizer.Run(graphDefT);
+    Optimizer tensorQuantOptimizer;
+    tensorQuantOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
+    tensorQuantOptimizer.AddPass(new (std::nothrow) InferShapePass());
+    tensorQuantOptimizer.AddPass(new (std::nothrow) TensorQuantPass());
+    status = tensorQuantOptimizer.Run(graphDefT);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "DoQuantize failed!";
       return status;
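The quantization stage now runs TopologicalSortPass and InferShapePass ahead of TensorQuantPass, which plausibly reflects that the extended TensorQuantPass (see the tensor_quant_pass.cc hunks below) reads tensor shapes via GetShapeSize and walks nodes in graph order. A generic sketch of the run-in-registration-order pass pattern, using stand-in names rather than the converter's real Optimizer API:

#include <memory>
#include <vector>

struct MetaGraph {};  // stand-in for schema::MetaGraphT

struct Pass {
  virtual ~Pass() = default;
  virtual int Run(MetaGraph *graph) = 0;  // returns 0 (RET_OK) on success
};

// Stand-in pipeline: passes execute in the order they were added, so the
// sort and shape-inference results exist before quantization starts.
struct PassPipeline {
  std::vector<std::unique_ptr<Pass>> passes;
  void AddPass(Pass *pass) {
    if (pass != nullptr) {  // tolerate new (std::nothrow) returning nullptr
      passes.emplace_back(pass);
    }
  }
  int Run(MetaGraph *graph) {
    for (auto &pass : passes) {
      int status = pass->Run(graph);
      if (status != 0) {
        return status;  // first failing pass aborts the stage
      }
    }
    return 0;
  }
};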


mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc (+13 -16)

@@ -18,6 +18,7 @@
 #include <string>
 #include <set>
 #include "tools/common/node_util.h"
+#include "tools/converter/converter_context.h"
 #include "src/common/common.h"
 #include "src/common/utils.h"
@@ -52,12 +53,8 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) {
 STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
   MS_ASSERT(graph != nullptr);
   auto &graphInIdxes = graph->inputIndex;
-
-  if (this->inputDataDType == TypeId::kTypeUnknown) {
-    return RET_OK;
-  }
   if (this->inputDataDType != TypeId::kNumberTypeFloat32 && this->inputDataDType != TypeId::kNumberTypeUInt8 &&
-      this->inputDataDType != TypeId::kNumberTypeInt8) {
+      this->inputDataDType != TypeId::kNumberTypeInt8 && this->inputDataDType != TypeId::kTypeUnknown) {
     MS_LOG(ERROR) << "Invalid inputDataType: " << this->inputDataDType;
     return RET_ERROR;
   }
@@ -67,7 +64,9 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
     if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) {
       continue;
     }
-
+    int32_t tensorDataType = this->inputDataDType != TypeId::kTypeUnknown
+                               ? this->inputDataDType
+                               : TensorDataType::GetInstance()->GetTensorType(graphInIdx);
     for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
       auto nodeName = (*iter)->name;
       for (size_t inputIndexIdx = 0; inputIndexIdx < (*iter)->inputIndex.size(); inputIndexIdx++) {
@@ -75,9 +74,8 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
         STATUS status = RET_OK;

         // insert dtype cast node between input tensor and input node
-        if (this->inputDataDType != tensor->dataType) {
-          iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, this->inputDataDType, tensor->dataType,
-                                      &status);
+        if (tensorDataType != tensor->dataType && tensorDataType != kTypeUnknown) {
+          iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, tensorDataType, tensor->dataType, &status);
         }

         if (status != RET_OK) {
@@ -93,11 +91,8 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {

 STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
   MS_ASSERT(graph != nullptr);
-  if (outputDataDType == TypeId::kTypeUnknown) {
-    return RET_OK;
-  }
   if (this->outputDataDType != TypeId::kNumberTypeFloat32 && this->outputDataDType != TypeId::kNumberTypeUInt8 &&
-      this->outputDataDType != TypeId::kNumberTypeInt8) {
+      this->outputDataDType != TypeId::kNumberTypeInt8 && this->outputDataDType != TypeId::kTypeUnknown) {
     MS_LOG(ERROR) << "Invalid outputDataType: " << this->outputDataDType;
     return RET_ERROR;
   }
@@ -108,6 +103,9 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
     if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) {
       continue;
     }
+    int32_t tensorDataType = this->outputDataDType != TypeId::kTypeUnknown
+                               ? this->inputDataDType
+                               : TensorDataType::GetInstance()->GetTensorType(graphOutIdx);
     for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
       auto nodeName = (*iter)->name;
       MS_ASSERT(node != nullptr);
@@ -115,9 +113,8 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
       if ((*iter)->outputIndex.at(outputIndexIdx) == graphOutIdx) {
         // insert transNode
         STATUS status = RET_OK;
-        if (this->outputDataDType != tensor->dataType) {
-          iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, tensor->dataType, this->outputDataDType,
-                                      &status);
+        if (tensorDataType != tensor->dataType && tensorDataType != kTypeUnknown) {
+          iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, tensor->dataType, tensorDataType, &status);
         }
         if (status != RET_OK) {
           MS_LOG(ERROR) << "InsertDTypeTransNode after " << nodeName.c_str() << " failed";
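Both DoModelInputDTypeTrans and DoModelOutputDTypeTrans now share one rule: take the explicitly configured dtype if one was set, otherwise fall back to the dtype recorded by TensorQuantPass, and insert a QuantDTypeCast node only when the chosen type is known and differs from the tensor's current type. A compact restatement of that rule (stub TypeId value, hypothetical helper names):

#include <cstdint>

enum TypeId : int32_t { kTypeUnknown = 0 };  // placeholder stub

// Prefer the dtype from the converter flags; otherwise use the type recorded
// by TensorQuantPass via TensorDataType::GetInstance()->GetTensorType(index).
int32_t SelectTransDType(int32_t configuredDType, int32_t recordedDType) {
  return configuredDType != TypeId::kTypeUnknown ? configuredDType : recordedDType;
}

// A cast node is inserted only when the selected type is known and differs
// from the tensor's current dataType.
bool NeedsCastNode(int32_t selectedDType, int32_t tensorDType) {
  return selectedDType != TypeId::kTypeUnknown && selectedDType != tensorDType;
}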


mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc (+35 -10)

@@ -17,23 +17,42 @@
 #include <vector>
 #include <cmath>
 #include "tools/converter/legacy_optimizer/graph/tensor_quant_pass.h"
+#include "tools/converter/converter_context.h"
 #include "tools/converter/quantizer/quantize_util.h"
 #include "tools/common/tensor_util.h"

 namespace mindspore::lite {
 STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) {
+  for (auto &node : graph->nodes) {
+    if (node->primitive->value.type == PrimitiveType_QuantDTypeCast) {
+      auto attr = node->primitive->value.AsQuantDTypeCast();
+      auto &inputTensor = graph->allTensors.at(node->inputIndex.front());
+      inputTensor->dataType = attr->srcT;
+      auto &outputTensor = graph->allTensors.at(node->outputIndex.front());
+      outputTensor->dataType = attr->dstT;
+
+      if (attr->srcT == TypeId::kNumberTypeUInt8) {
+        attr->srcT = TypeId::kNumberTypeInt8;
+      }
+      if (attr->dstT == TypeId::kNumberTypeUInt8) {
+        attr->dstT = TypeId::kNumberTypeInt8;
+      }
+    }
+  }
+  int index = -1;
   for (auto &tensor : graph->allTensors) {
-    if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited || tensor->data.empty()) {
+    index++;
+    if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) {
       continue;
     }
     if (tensor->dataType != TypeId::kNumberTypeFloat32 && tensor->dataType != TypeId::kNumberTypeFloat &&
-        tensor->dataType != TypeId::kNumberTypeUInt8) {
+        tensor->dataType != TypeId::kNumberTypeUInt8 && tensor->dataType != TypeId::kTypeUnknown) {
       continue;
     }
     // perlayer
     if (tensor->quantParams.size() == 1) {
       auto &quantParam = tensor->quantParams.front();
-      size_t wShapeSize = GetShapeSize(*(tensor.get()));
+      size_t wShapeSize = tensor->data.empty() ? 0 : GetShapeSize(*(tensor.get()));
       void *oriWeightData = tensor->data.data();
       if (quantParam->dstDtype == TypeId::kNumberTypeInt8) {
         std::vector<int8_t> qDatas(wShapeSize);
@@ -41,6 +60,9 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) {
         if (tensor->dataType == TypeId::kNumberTypeFloat ||
             tensor->dataType == TypeId::kNumberTypeFloat32) {  // normal awareing quant
           auto *weightData = static_cast<float *>(oriWeightData);
+          if (weightData == nullptr) {
+            continue;
+          }
           for (size_t j = 0; j < wShapeSize; j++) {
             qDatas[j] = quant::QuantizeData<int8_t>(weightData[j], weightQauntParam.get());
           }
@@ -52,15 +74,18 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) {
           weightQauntParam->zeroPoint -= 128;
           tensor->quantParams.clear();
           tensor->quantParams.emplace_back(weightQauntParam.release());
+          TensorDataType::GetInstance()->UpdateTensorType(index, TypeId::kNumberTypeUInt8);
         }
         tensor->dataType = TypeId::kNumberTypeInt8;
-        tensor->data.clear();
-        tensor->data.resize(wShapeSize * sizeof(int8_t));
-        auto ret =
-            memcpy_s(tensor->data.data(), wShapeSize * sizeof(int8_t), qDatas.data(), wShapeSize * sizeof(int8_t));
-        if (ret != EOK) {
-          MS_LOG(ERROR) << "memcpy_s failed: " << ret;
-          return RET_ERROR;
+        if (!tensor->data.empty()) {
+          tensor->data.clear();
+          tensor->data.resize(wShapeSize * sizeof(int8_t));
+          auto ret =
+              memcpy_s(tensor->data.data(), wShapeSize * sizeof(int8_t), qDatas.data(), wShapeSize * sizeof(int8_t));
+          if (ret != EOK) {
+            MS_LOG(ERROR) << "memcpy_s failed: " << ret;
+            return RET_ERROR;
+          }
         }
       } else if (quantParam->dstDtype == TypeId::kNumberTypeInt32) {
         // quant bias data
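The zeroPoint -= 128 above is the usual uint8-to-int8 rebasing: since real = scale * (q - zp), shifting the stored values and the zero point by the same -128 leaves decoded values unchanged. A self-contained sketch of that shift (hypothetical helper, not code from this commit):

#include <cstdint>
#include <vector>

// Shift uint8 quantized data and its zero point into int8 range.
// real = scale * (q - zp) is preserved because q and zp both move by -128.
void Uint8ToInt8(const std::vector<uint8_t> &src, std::vector<int8_t> *dst,
                 int32_t *zeroPoint) {
  dst->clear();
  dst->reserve(src.size());
  for (uint8_t v : src) {
    dst->push_back(static_cast<int8_t>(static_cast<int32_t>(v) - 128));
  }
  *zeroPoint -= 128;  // mirrors weightQauntParam->zeroPoint -= 128 above
}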


mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc (+1 -1)

@@ -53,7 +53,7 @@ STATUS TfliteDequantizeParser::Parse(TfliteTensorsInfo *tensors_info,
     MS_LOG(ERROR) << "new op failed";
     return RET_NULL_PTR;
   }
-  attr->srcT = kNumberTypeInt8;
+  attr->srcT = GetTfliteDataType(in_tensor->type);
   attr->dstT = GetTfliteDataType(out_tensor->type);
   op->primitive->value.value = attr.release();
   op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast;


mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc (+0 -7)

@@ -76,12 +76,6 @@ void TfliteModelParser::SetTensorQuantParam(const std::unique_ptr<tflite::Tensor
       quant_param->zeroPoint = tflite_tensor->quantization->zero_point[i];
     }

-    // change quant param min to 0 to fit ms-lite ops
-    if (GetTfliteDataType(tflite_tensor->type) == TypeId::kNumberTypeUInt8 && tensor->data.empty()) {
-      quant_param->zeroPoint = quant_param->zeroPoint - 128;
-      tensor->dataType = TypeId::kNumberTypeInt8;
-    }
-
     if (!tflite_tensor->quantization->min.empty()) {
       quant_param->min = tflite_tensor->quantization->min[i];
     }
@@ -127,7 +121,6 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit
       }
       continue;
     }
-
     sub_graph->nodes.emplace_back(op.release());
     opMap[sub_graph->nodes.back()->name] = sub_graph->nodes.back().get();
     tfliteOpMap[tflite_op.get()] = sub_graph->nodes.back().get();


mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc (+1 -1)

@@ -53,7 +53,7 @@ STATUS TfliteQuantizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::u
       return RET_NULL_PTR;
     }
     attr->srcT = GetTfliteDataType(in_tensor->type);
-    attr->dstT = kNumberTypeInt8;
+    attr->dstT = GetTfliteDataType(out_tensor->type);
     op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast;
     op->primitive->value.value = attr.release();
   } else {

