diff --git a/mindspore/lite/tools/common/graph_util.cc b/mindspore/lite/tools/common/graph_util.cc index da0660b973..09c29174dc 100644 --- a/mindspore/lite/tools/common/graph_util.cc +++ b/mindspore/lite/tools/common/graph_util.cc @@ -446,6 +446,9 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si MS_ASSERT(prim != nullptr); preTensor->dataType = prim->srcT; toAddTensor->dataType = prim->dstT; + if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) { + preTensor->quantParams.front()->zeroPoint += 128; + } } graphT->allTensors.emplace_back(std::move(toAddTensor)); size_t toAddTensorIdx = graphT->allTensors.size() - 1; @@ -486,6 +489,9 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si MS_ASSERT(prim != nullptr); preTensor->dataType = prim->srcT; toAddTensor->dataType = prim->dstT; + if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) { + preTensor->quantParams.front()->zeroPoint += 128; + } } graphT->allTensors.emplace_back(std::move(toAddTensor)); size_t toAddTensorIdx = graphT->allTensors.size() - 1; @@ -546,6 +552,9 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz MS_ASSERT(prim != nullptr); postTensor->dataType = prim->srcT; toAddTensor->dataType = prim->dstT; + if (prim->dstT == TypeId::kNumberTypeUInt8 && prim->srcT == TypeId::kNumberTypeInt8) { + toAddTensor->quantParams.front()->zeroPoint += 128; + } } graphT->allTensors.emplace_back(std::move(toAddTensor)); size_t toAddTensorIdx = graphT->allTensors.size() - 1; @@ -613,6 +622,9 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz MS_ASSERT(prim != nullptr); postTensor->dataType = prim->srcT; toAddTensor->dataType = prim->dstT; + if (prim->dstT == TypeId::kNumberTypeUInt8 && prim->srcT == TypeId::kNumberTypeInt8) { + toAddTensor->quantParams.front()->zeroPoint += 128; + } } graphT->allTensors.emplace_back(std::move(toAddTensor)); size_t toAddTensorIdx = graphT->allTensors.size() - 1;