Browse Source

fix quant bug

tags/v1.1.0
cjh9368 5 years ago
parent
commit
c66dc1f58e
1 changed files with 12 additions and 0 deletions
  1. +12
    -0
      mindspore/lite/tools/common/graph_util.cc

+ 12
- 0
mindspore/lite/tools/common/graph_util.cc View File

@@ -446,6 +446,9 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si
MS_ASSERT(prim != nullptr); MS_ASSERT(prim != nullptr);
preTensor->dataType = prim->srcT; preTensor->dataType = prim->srcT;
toAddTensor->dataType = prim->dstT; toAddTensor->dataType = prim->dstT;
if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) {
preTensor->quantParams.front()->zeroPoint += 128;
}
} }
graphT->allTensors.emplace_back(std::move(toAddTensor)); graphT->allTensors.emplace_back(std::move(toAddTensor));
size_t toAddTensorIdx = graphT->allTensors.size() - 1; size_t toAddTensorIdx = graphT->allTensors.size() - 1;
@@ -486,6 +489,9 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si
MS_ASSERT(prim != nullptr); MS_ASSERT(prim != nullptr);
preTensor->dataType = prim->srcT; preTensor->dataType = prim->srcT;
toAddTensor->dataType = prim->dstT; toAddTensor->dataType = prim->dstT;
if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) {
preTensor->quantParams.front()->zeroPoint += 128;
}
} }
graphT->allTensors.emplace_back(std::move(toAddTensor)); graphT->allTensors.emplace_back(std::move(toAddTensor));
size_t toAddTensorIdx = graphT->allTensors.size() - 1; size_t toAddTensorIdx = graphT->allTensors.size() - 1;
@@ -546,6 +552,9 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
MS_ASSERT(prim != nullptr); MS_ASSERT(prim != nullptr);
postTensor->dataType = prim->srcT; postTensor->dataType = prim->srcT;
toAddTensor->dataType = prim->dstT; toAddTensor->dataType = prim->dstT;
if (prim->dstT == TypeId::kNumberTypeUInt8 && prim->srcT == TypeId::kNumberTypeInt8) {
toAddTensor->quantParams.front()->zeroPoint += 128;
}
} }
graphT->allTensors.emplace_back(std::move(toAddTensor)); graphT->allTensors.emplace_back(std::move(toAddTensor));
size_t toAddTensorIdx = graphT->allTensors.size() - 1; size_t toAddTensorIdx = graphT->allTensors.size() - 1;
@@ -613,6 +622,9 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
MS_ASSERT(prim != nullptr); MS_ASSERT(prim != nullptr);
postTensor->dataType = prim->srcT; postTensor->dataType = prim->srcT;
toAddTensor->dataType = prim->dstT; toAddTensor->dataType = prim->dstT;
if (prim->dstT == TypeId::kNumberTypeUInt8 && prim->srcT == TypeId::kNumberTypeInt8) {
toAddTensor->quantParams.front()->zeroPoint += 128;
}
} }
graphT->allTensors.emplace_back(std::move(toAddTensor)); graphT->allTensors.emplace_back(std::move(toAddTensor));
size_t toAddTensorIdx = graphT->allTensors.size() - 1; size_t toAddTensorIdx = graphT->allTensors.size() - 1;


Loading…
Cancel
Save