Browse Source

fix quant relative

tags/v1.0.0
yankai 5 years ago
parent
commit
184e935c88
11 changed files with 134 additions and 42 deletions
  1. +19
    -0
      mindspore/lite/src/ops/primitive_c.cc
  2. +2
    -0
      mindspore/lite/src/ops/primitive_c.h
  3. +36
    -0
      mindspore/lite/src/ops/resize.cc
  4. +1
    -0
      mindspore/lite/src/ops/resize.h
  5. +43
    -0
      mindspore/lite/src/ops/strided_slice.cc
  6. +1
    -0
      mindspore/lite/src/ops/strided_slice.h
  7. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/int8/gather_int8.cc
  8. +0
    -1
      mindspore/lite/tools/converter/converter.cc
  9. +1
    -1
      mindspore/lite/tools/converter/quantizer/aware_quantizer.cc
  10. +28
    -37
      mindspore/lite/tools/converter/quantizer/calc_quant_param.cc
  11. +2
    -2
      mindspore/lite/tools/converter/quantizer/calc_quant_param.h

+ 19
- 0
mindspore/lite/src/ops/primitive_c.cc View File

@@ -232,6 +232,25 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim,
vecOutputQuantParam->emplace_back(quants); vecOutputQuantParam->emplace_back(quants);
} }
} }

void PrimitiveC::GetAttrDataFromInput(const AnfNodePtr inputNode, std::vector<int> *data) {
if (inputNode->isa<ValueNode>()) {
auto valNode = inputNode->cast<ValueNodePtr>();
MS_ASSERT(valNode != nullptr);
auto val = valNode->value();
MS_ASSERT(val != nullptr);
if (val->isa<ValueTuple>()) {
auto tuple = val->cast<ValueTuplePtr>();
MS_ASSERT(tuple != nullptr);
for (size_t i = 0; i < tuple->size(); i++) {
auto elem = tuple->value()[i]->cast<Int32ImmPtr>();
MS_ASSERT(elem != nullptr);
data->emplace_back(static_cast<int>(elem->value()));
}
}
}
}

schema::PrimitiveT *PrimitiveC::GetPrimitiveT() const { return this->primitive_; } schema::PrimitiveT *PrimitiveC::GetPrimitiveT() const { return this->primitive_; }


void PrimitiveC::ClearPrimitiveT() { this->primitive_ = nullptr; } void PrimitiveC::ClearPrimitiveT() { this->primitive_ = nullptr; }


+ 2
- 0
mindspore/lite/src/ops/primitive_c.h View File

@@ -113,6 +113,8 @@ class PrimitiveC : public mindspore::Primitive {


static PrimitiveC *Create(mindspore::schema::PrimitiveT *primitive); static PrimitiveC *Create(mindspore::schema::PrimitiveT *primitive);


void GetAttrDataFromInput(const AnfNodePtr inputNode, std::vector<int> *data);

static std::shared_ptr<PrimitiveC> Create(const Primitive &prim, const std::vector<AnfNodePtr> &inputs, static std::shared_ptr<PrimitiveC> Create(const Primitive &prim, const std::vector<AnfNodePtr> &inputs,
const schema::QuantType &quantType); const schema::QuantType &quantType);
void PopulaterQuantParam(const Primitive &prim, std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam, void PopulaterQuantParam(const Primitive &prim, std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,


+ 36
- 0
mindspore/lite/src/ops/resize.cc View File

@@ -35,6 +35,42 @@ void Resize::SetPreserveAspectRatio(bool preserve_aspect_ratio) {
this->primitive_->value.AsResize()->preserveAspectRatio = preserve_aspect_ratio; this->primitive_->value.AsResize()->preserveAspectRatio = preserve_aspect_ratio;
} }


int Resize::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Resize;
}
if (this->primitive_->value.type != schema::PrimitiveType_Resize) {
MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::ResizeT();
if (prim.instance_name() == "ResizeNearestNeighbor") {
attr->method = schema::ResizeMethod_NEAREST_NEIGHBOR;
} else if (prim.instance_name() == "ResizeBilinear") {
attr->method = schema::ResizeMethod_BILINEAR;
} else {
MS_LOG(ERROR) << "wrong resize type";
return RET_ERROR;
}
std::vector<int> targetSize = GetValue<std::vector<int>>(prim.GetAttr("size"));
attr->newHeight = targetSize[0];
attr->newWidth = targetSize[1];
attr->alignCorners = GetValue<bool>(prim.GetAttr("align_corners"));

this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
}
return RET_OK;
}
#else #else


int Resize::GetFormat() const { return this->primitive_->value_as_Resize()->format(); } int Resize::GetFormat() const { return this->primitive_->value_as_Resize()->format(); }


+ 1
- 0
mindspore/lite/src/ops/resize.h View File

@@ -37,6 +37,7 @@ class Resize : public PrimitiveC {
void SetNewWidth(int64_t new_width); void SetNewWidth(int64_t new_width);
void SetAlignCorners(bool align_corners); void SetAlignCorners(bool align_corners);
void SetPreserveAspectRatio(bool preserve_aspect_ratio); void SetPreserveAspectRatio(bool preserve_aspect_ratio);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
#else #else
Resize() = default; Resize() = default;




+ 43
- 0
mindspore/lite/src/ops/strided_slice.cc View File

@@ -49,6 +49,49 @@ void StridedSlice::SetIsScale(const std::vector<int> &is_scale) {
this->primitive_->value.AsStridedSlice()->isScale = is_scale; this->primitive_->value.AsStridedSlice()->isScale = is_scale;
} }


int StridedSlice::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_StridedSlice;
}
if (this->primitive_->value.type != schema::PrimitiveType_StridedSlice) {
MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::StridedSliceT();
attr->beginMask = GetValue<int>(prim.GetAttr("begin_mask"));
attr->endMask = GetValue<int>(prim.GetAttr("end_mask"));
attr->ellipsisMask = GetValue<int>(prim.GetAttr("ellipsis_mask"));
attr->newAxisMask = GetValue<int>(prim.GetAttr("new_axis_mask"));
attr->shrinkAxisMask = GetValue<int>(prim.GetAttr("shrink_axis_mask"));
auto inputNodeFirst = inputs[kAnfPopulaterOne];
std::vector<int> beginVec;
GetAttrDataFromInput(inputNodeFirst, &beginVec);
attr->begin = beginVec;

auto inputNodeSecond = inputs[kAnfPopulaterTwo];
std::vector<int> endVec;
GetAttrDataFromInput(inputNodeSecond, &endVec);
attr->end = endVec;

auto inputNodeThird = inputs[kAnfPopulaterThree];
std::vector<int> strideVec;
GetAttrDataFromInput(inputNodeThird, &strideVec);
attr->stride = strideVec;
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
}
return RET_OK;
}

#else #else


int StridedSlice::GetBeginMask() const { return this->primitive_->value_as_StridedSlice()->beginMask(); } int StridedSlice::GetBeginMask() const { return this->primitive_->value_as_StridedSlice()->beginMask(); }


+ 1
- 0
mindspore/lite/src/ops/strided_slice.h View File

@@ -41,6 +41,7 @@ class StridedSlice : public PrimitiveC {
void SetEnd(const std::vector<int> &end); void SetEnd(const std::vector<int> &end);
void SetStride(const std::vector<int> &stride); void SetStride(const std::vector<int> &stride);
void SetIsScale(const std::vector<int> &is_scale); void SetIsScale(const std::vector<int> &is_scale);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
#else #else
StridedSlice() = default; StridedSlice() = default;




+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/int8/gather_int8.cc View File

@@ -55,7 +55,7 @@ int GatherInt8CPUKernel::DoGather(int task_id) {


auto input_ptr = reinterpret_cast<int8_t *>(input_tensor->MutableData()); auto input_ptr = reinterpret_cast<int8_t *>(input_tensor->MutableData());
auto output_ptr = reinterpret_cast<int8_t *>(out_tensor->MutableData()); auto output_ptr = reinterpret_cast<int8_t *>(out_tensor->MutableData());
auto indices_ptr = reinterpret_cast<int32_t *>(out_tensor->MutableData());
auto indices_ptr = reinterpret_cast<int32_t *>(indices_tensor->MutableData());


auto in_shape = input_tensor->shape(); auto in_shape = input_tensor->shape();
int in_rank = in_shape.size(); int in_rank = in_shape.size();


+ 0
- 1
mindspore/lite/tools/converter/converter.cc View File

@@ -90,7 +90,6 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
MS_LOG(ERROR) << "Export to meta graph return nullptr"; MS_LOG(ERROR) << "Export to meta graph return nullptr";
return nullptr; return nullptr;
} }

// transform // transform
transform->SetGraphDef(meta_graph); transform->SetGraphDef(meta_graph);
transform->CreateQuantizer(flag); transform->CreateQuantizer(flag);


+ 1
- 1
mindspore/lite/tools/converter/quantizer/aware_quantizer.cc View File

@@ -129,7 +129,7 @@ STATUS AwareQuantizer::GenerateQuantParam() {
GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) { GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) {
MS_ASSERT(false); MS_ASSERT(false);
} }
auto *quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node));
auto quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node));
if (quantParamCalcer == nullptr) { if (quantParamCalcer == nullptr) {
MS_LOG(ERROR) << "Can not find QuantParamCalcer for " << node->name.c_str() MS_LOG(ERROR) << "Can not find QuantParamCalcer for " << node->name.c_str()
<< ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip";


+ 28
- 37
mindspore/lite/tools/converter/quantizer/calc_quant_param.cc View File

@@ -439,61 +439,52 @@ class CalcActivation : public QuantParamCalcer {
} }
} }
}; };
QuantParamCalcRegister::~QuantParamCalcRegister() {
for (auto ite : _registerMap) {
if (ite.second != nullptr) {
delete ite.second;
ite.second = nullptr;
}
}
}
QuantParamCalcRegister::~QuantParamCalcRegister() {}

QuantParamCalcRegister::QuantParamCalcRegister() { QuantParamCalcRegister::QuantParamCalcRegister() {
bool hasError = false; bool hasError = false;
std::unique_ptr<QuantParamCalcer> baseCalcer(new (std::nothrow) QuantParamCalcer());
std::shared_ptr<QuantParamCalcer> baseCalcer = std::make_shared<QuantParamCalcer>();
if (baseCalcer == nullptr) { if (baseCalcer == nullptr) {
MS_LOG(ERROR) << "new QuantParamCalcer failed"; MS_LOG(ERROR) << "new QuantParamCalcer failed";
hasError = true; hasError = true;
} }
std::unique_ptr<CommonCalcer> commonCalcer(new (std::nothrow) CommonCalcer());
std::shared_ptr<QuantParamCalcer> commonCalcer = std::make_shared<CommonCalcer>();
if (commonCalcer == nullptr) { if (commonCalcer == nullptr) {
MS_LOG(ERROR) << "new commonCalcer failed"; MS_LOG(ERROR) << "new commonCalcer failed";
hasError = true; hasError = true;
} }


std::unique_ptr<LinearCalcer> linearCalcer(new (std::nothrow) LinearCalcer());
std::shared_ptr<QuantParamCalcer> linearCalcer = std::make_shared<LinearCalcer>();
if (linearCalcer == nullptr) { if (linearCalcer == nullptr) {
MS_LOG(ERROR) << "new linearCalcer failed"; MS_LOG(ERROR) << "new linearCalcer failed";
hasError = true; hasError = true;
} }
if (!hasError) { if (!hasError) {
_registerMap[schema::PrimitiveType_Concat] = new CalcConcat();
_registerMap[schema::PrimitiveType_Activation] = new CalcActivation();
_registerMap[schema::PrimitiveType_Add] = new CalcAdd();
_registerMap[schema::PrimitiveType_Mul] = commonCalcer.get();
_registerMap[schema::PrimitiveType_Conv2D] = commonCalcer.get();
_registerMap[schema::PrimitiveType_DepthwiseConv2D] = commonCalcer.get();
_registerMap[schema::PrimitiveType_Pooling] = linearCalcer.get();
_registerMap[schema::PrimitiveType_Resize] = linearCalcer.get();
_registerMap[schema::PrimitiveType_Reshape] = linearCalcer.get();
_registerMap[schema::PrimitiveType_Shape] = linearCalcer.get();
_registerMap[schema::PrimitiveType_SoftMax] = new CalcToSet(0, 1);
_registerMap[schema::PrimitiveType_Squeeze] = linearCalcer.get();
_registerMap[schema::PrimitiveType_RealDiv] = new CalcRealDiv();
_registerMap[schema::PrimitiveType_Reduce] = commonCalcer.get();
_registerMap[schema::PrimitiveType_BiasAdd] = commonCalcer.get();
_registerMap[schema::PrimitiveType_Mean] = linearCalcer.get();
_registerMap[schema::PrimitiveType_Transpose] = linearCalcer.get();
_registerMap[schema::PrimitiveType_MatMul] = commonCalcer.get();
_registerMap[schema::PrimitiveType_FullConnection] = commonCalcer.get();
_registerMap[schema::PrimitiveType_Nchw2Nhwc] = linearCalcer.get();
_registerMap[schema::PrimitiveType_Nhwc2Nchw] = linearCalcer.get();
_registerMap[schema::PrimitiveType_Concat] = std::make_shared<CalcConcat>();
_registerMap[schema::PrimitiveType_Activation] = std::make_shared<CalcActivation>();
_registerMap[schema::PrimitiveType_Add] = std::make_shared<CalcAdd>();
_registerMap[schema::PrimitiveType_Mul] = commonCalcer;
_registerMap[schema::PrimitiveType_Conv2D] = commonCalcer;
_registerMap[schema::PrimitiveType_DepthwiseConv2D] = commonCalcer;
_registerMap[schema::PrimitiveType_Pooling] = linearCalcer;
_registerMap[schema::PrimitiveType_Resize] = linearCalcer;
_registerMap[schema::PrimitiveType_Reshape] = linearCalcer;
_registerMap[schema::PrimitiveType_Shape] = linearCalcer;
_registerMap[schema::PrimitiveType_SoftMax] = std::make_shared<CalcToSet>(0, 1);
_registerMap[schema::PrimitiveType_Squeeze] = linearCalcer;
_registerMap[schema::PrimitiveType_RealDiv] = std::make_shared<CalcRealDiv>();
_registerMap[schema::PrimitiveType_Reduce] = commonCalcer;
_registerMap[schema::PrimitiveType_BiasAdd] = commonCalcer;
_registerMap[schema::PrimitiveType_Mean] = linearCalcer;
_registerMap[schema::PrimitiveType_Transpose] = linearCalcer;
_registerMap[schema::PrimitiveType_MatMul] = commonCalcer;
_registerMap[schema::PrimitiveType_FullConnection] = commonCalcer;
_registerMap[schema::PrimitiveType_Nchw2Nhwc] = linearCalcer;
_registerMap[schema::PrimitiveType_Nhwc2Nchw] = linearCalcer;
// detection_postprocess op's quant param will not infer only fetch from preNode or postNode // detection_postprocess op's quant param will not infer only fetch from preNode or postNode
// because we will not insert quantTransNode after this node in tflite_graph_8bit model if input data is float. // because we will not insert quantTransNode after this node in tflite_graph_8bit model if input data is float.
// if quantTransNode is inserted after detection_postprocess node, there will be some errors // if quantTransNode is inserted after detection_postprocess node, there will be some errors
_registerMap[schema::PrimitiveType_DetectionPostProcess] = baseCalcer.get();
baseCalcer.release();
linearCalcer.release();
commonCalcer.release();
_registerMap[schema::PrimitiveType_DetectionPostProcess] = baseCalcer;
} }
} }


@@ -502,7 +493,7 @@ QuantParamCalcRegister *QuantParamCalcRegister::GetInstance() {
return &instance; return &instance;
} }


QuantParamCalcer *QuantParamCalcRegister::GetQuantParamCalcer(schema::PrimitiveType opType) {
std::shared_ptr<QuantParamCalcer> QuantParamCalcRegister::GetQuantParamCalcer(schema::PrimitiveType opType) {
auto it = _registerMap.find(opType); auto it = _registerMap.find(opType);
if (it != _registerMap.end()) { if (it != _registerMap.end()) {
return it->second; return it->second;


+ 2
- 2
mindspore/lite/tools/converter/quantizer/calc_quant_param.h View File

@@ -56,12 +56,12 @@ class LinearCalcer : public QuantParamCalcer {
class QuantParamCalcRegister { class QuantParamCalcRegister {
public: public:
virtual ~QuantParamCalcRegister(); virtual ~QuantParamCalcRegister();
QuantParamCalcer *GetQuantParamCalcer(schema::PrimitiveType opType);
std::shared_ptr<QuantParamCalcer> GetQuantParamCalcer(schema::PrimitiveType opType);
static QuantParamCalcRegister *GetInstance(); static QuantParamCalcRegister *GetInstance();


private: private:
QuantParamCalcRegister(); QuantParamCalcRegister();
std::unordered_map<schema::PrimitiveType, QuantParamCalcer *> _registerMap;
std::unordered_map<schema::PrimitiveType, std::shared_ptr<QuantParamCalcer>> _registerMap;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


Loading…
Cancel
Save