diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 33dfd55ff7..8cd54de178 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -232,6 +232,25 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim, vecOutputQuantParam->emplace_back(quants); } } + +void PrimitiveC::GetAttrDataFromInput(const AnfNodePtr inputNode, std::vector *data) { + if (inputNode->isa()) { + auto valNode = inputNode->cast(); + MS_ASSERT(valNode != nullptr); + auto val = valNode->value(); + MS_ASSERT(val != nullptr); + if (val->isa()) { + auto tuple = val->cast(); + MS_ASSERT(tuple != nullptr); + for (size_t i = 0; i < tuple->size(); i++) { + auto elem = tuple->value()[i]->cast(); + MS_ASSERT(elem != nullptr); + data->emplace_back(static_cast(elem->value())); + } + } + } +} + schema::PrimitiveT *PrimitiveC::GetPrimitiveT() const { return this->primitive_; } void PrimitiveC::ClearPrimitiveT() { this->primitive_ = nullptr; } diff --git a/mindspore/lite/src/ops/primitive_c.h b/mindspore/lite/src/ops/primitive_c.h index 2163d1f057..80728033c5 100644 --- a/mindspore/lite/src/ops/primitive_c.h +++ b/mindspore/lite/src/ops/primitive_c.h @@ -113,6 +113,8 @@ class PrimitiveC : public mindspore::Primitive { static PrimitiveC *Create(mindspore::schema::PrimitiveT *primitive); + void GetAttrDataFromInput(const AnfNodePtr inputNode, std::vector *data); + static std::shared_ptr Create(const Primitive &prim, const std::vector &inputs, const schema::QuantType &quantType); void PopulaterQuantParam(const Primitive &prim, std::vector> *vecInputQuantParam, diff --git a/mindspore/lite/src/ops/resize.cc b/mindspore/lite/src/ops/resize.cc index 6f1c4912d2..6070a28874 100644 --- a/mindspore/lite/src/ops/resize.cc +++ b/mindspore/lite/src/ops/resize.cc @@ -35,6 +35,42 @@ void Resize::SetPreserveAspectRatio(bool preserve_aspect_ratio) { this->primitive_->value.AsResize()->preserveAspectRatio = preserve_aspect_ratio; } +int Resize::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_Resize; + } + if (this->primitive_->value.type != schema::PrimitiveType_Resize) { + MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::ResizeT(); + if (prim.instance_name() == "ResizeNearestNeighbor") { + attr->method = schema::ResizeMethod_NEAREST_NEIGHBOR; + } else if (prim.instance_name() == "ResizeBilinear") { + attr->method = schema::ResizeMethod_BILINEAR; + } else { + MS_LOG(ERROR) << "wrong resize type"; + return RET_ERROR; + } + std::vector targetSize = GetValue>(prim.GetAttr("size")); + attr->newHeight = targetSize[0]; + attr->newWidth = targetSize[1]; + attr->alignCorners = GetValue(prim.GetAttr("align_corners")); + + this->primitive_->value.value = attr; + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + } + return RET_OK; +} #else int Resize::GetFormat() const { return this->primitive_->value_as_Resize()->format(); } diff --git a/mindspore/lite/src/ops/resize.h b/mindspore/lite/src/ops/resize.h index 84b22debac..9547bf58d9 100644 --- a/mindspore/lite/src/ops/resize.h +++ b/mindspore/lite/src/ops/resize.h @@ -37,6 +37,7 @@ class Resize : public PrimitiveC { void SetNewWidth(int64_t new_width); void SetAlignCorners(bool align_corners); void SetPreserveAspectRatio(bool preserve_aspect_ratio); + int UnPackAttr(const Primitive &prim, const std::vector &inputs); #else Resize() = default; diff --git a/mindspore/lite/src/ops/strided_slice.cc b/mindspore/lite/src/ops/strided_slice.cc index bd56848c12..9e0dafe8e6 100644 --- a/mindspore/lite/src/ops/strided_slice.cc +++ b/mindspore/lite/src/ops/strided_slice.cc @@ -49,6 +49,49 @@ void StridedSlice::SetIsScale(const std::vector &is_scale) { this->primitive_->value.AsStridedSlice()->isScale = is_scale; } +int StridedSlice::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_StridedSlice; + } + if (this->primitive_->value.type != schema::PrimitiveType_StridedSlice) { + MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::StridedSliceT(); + attr->beginMask = GetValue(prim.GetAttr("begin_mask")); + attr->endMask = GetValue(prim.GetAttr("end_mask")); + attr->ellipsisMask = GetValue(prim.GetAttr("ellipsis_mask")); + attr->newAxisMask = GetValue(prim.GetAttr("new_axis_mask")); + attr->shrinkAxisMask = GetValue(prim.GetAttr("shrink_axis_mask")); + auto inputNodeFirst = inputs[kAnfPopulaterOne]; + std::vector beginVec; + GetAttrDataFromInput(inputNodeFirst, &beginVec); + attr->begin = beginVec; + + auto inputNodeSecond = inputs[kAnfPopulaterTwo]; + std::vector endVec; + GetAttrDataFromInput(inputNodeSecond, &endVec); + attr->end = endVec; + + auto inputNodeThird = inputs[kAnfPopulaterThree]; + std::vector strideVec; + GetAttrDataFromInput(inputNodeThird, &strideVec); + attr->stride = strideVec; + this->primitive_->value.value = attr; + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + } + return RET_OK; +} + #else int StridedSlice::GetBeginMask() const { return this->primitive_->value_as_StridedSlice()->beginMask(); } diff --git a/mindspore/lite/src/ops/strided_slice.h b/mindspore/lite/src/ops/strided_slice.h index 06b1f6e311..c135aab665 100644 --- a/mindspore/lite/src/ops/strided_slice.h +++ b/mindspore/lite/src/ops/strided_slice.h @@ -41,6 +41,7 @@ class StridedSlice : public PrimitiveC { void SetEnd(const std::vector &end); void SetStride(const std::vector &stride); void SetIsScale(const std::vector &is_scale); + int UnPackAttr(const Primitive &prim, const std::vector &inputs); #else StridedSlice() = default; diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/gather_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/gather_int8.cc index 34ab40cd02..9ee8723fb6 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/gather_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/gather_int8.cc @@ -55,7 +55,7 @@ int GatherInt8CPUKernel::DoGather(int task_id) { auto input_ptr = reinterpret_cast(input_tensor->MutableData()); auto output_ptr = reinterpret_cast(out_tensor->MutableData()); - auto indices_ptr = reinterpret_cast(out_tensor->MutableData()); + auto indices_ptr = reinterpret_cast(indices_tensor->MutableData()); auto in_shape = input_tensor->shape(); int in_rank = in_shape.size(); diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index 620ecdc30f..b744d7d63a 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -90,7 +90,6 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { MS_LOG(ERROR) << "Export to meta graph return nullptr"; return nullptr; } - // transform transform->SetGraphDef(meta_graph); transform->CreateQuantizer(flag); diff --git a/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc b/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc index a19d8b61eb..9846087828 100644 --- a/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc @@ -129,7 +129,7 @@ STATUS AwareQuantizer::GenerateQuantParam() { GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) { MS_ASSERT(false); } - auto *quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node)); + auto quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node)); if (quantParamCalcer == nullptr) { MS_LOG(ERROR) << "Can not find QuantParamCalcer for " << node->name.c_str() << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; diff --git a/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc b/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc index aea5e6f98e..da3a9c84ef 100644 --- a/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc +++ b/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc @@ -439,61 +439,52 @@ class CalcActivation : public QuantParamCalcer { } } }; -QuantParamCalcRegister::~QuantParamCalcRegister() { - for (auto ite : _registerMap) { - if (ite.second != nullptr) { - delete ite.second; - ite.second = nullptr; - } - } -} +QuantParamCalcRegister::~QuantParamCalcRegister() {} + QuantParamCalcRegister::QuantParamCalcRegister() { bool hasError = false; - std::unique_ptr baseCalcer(new (std::nothrow) QuantParamCalcer()); + std::shared_ptr baseCalcer = std::make_shared(); if (baseCalcer == nullptr) { MS_LOG(ERROR) << "new QuantParamCalcer failed"; hasError = true; } - std::unique_ptr commonCalcer(new (std::nothrow) CommonCalcer()); + std::shared_ptr commonCalcer = std::make_shared(); if (commonCalcer == nullptr) { MS_LOG(ERROR) << "new commonCalcer failed"; hasError = true; } - std::unique_ptr linearCalcer(new (std::nothrow) LinearCalcer()); + std::shared_ptr linearCalcer = std::make_shared(); if (linearCalcer == nullptr) { MS_LOG(ERROR) << "new linearCalcer failed"; hasError = true; } if (!hasError) { - _registerMap[schema::PrimitiveType_Concat] = new CalcConcat(); - _registerMap[schema::PrimitiveType_Activation] = new CalcActivation(); - _registerMap[schema::PrimitiveType_Add] = new CalcAdd(); - _registerMap[schema::PrimitiveType_Mul] = commonCalcer.get(); - _registerMap[schema::PrimitiveType_Conv2D] = commonCalcer.get(); - _registerMap[schema::PrimitiveType_DepthwiseConv2D] = commonCalcer.get(); - _registerMap[schema::PrimitiveType_Pooling] = linearCalcer.get(); - _registerMap[schema::PrimitiveType_Resize] = linearCalcer.get(); - _registerMap[schema::PrimitiveType_Reshape] = linearCalcer.get(); - _registerMap[schema::PrimitiveType_Shape] = linearCalcer.get(); - _registerMap[schema::PrimitiveType_SoftMax] = new CalcToSet(0, 1); - _registerMap[schema::PrimitiveType_Squeeze] = linearCalcer.get(); - _registerMap[schema::PrimitiveType_RealDiv] = new CalcRealDiv(); - _registerMap[schema::PrimitiveType_Reduce] = commonCalcer.get(); - _registerMap[schema::PrimitiveType_BiasAdd] = commonCalcer.get(); - _registerMap[schema::PrimitiveType_Mean] = linearCalcer.get(); - _registerMap[schema::PrimitiveType_Transpose] = linearCalcer.get(); - _registerMap[schema::PrimitiveType_MatMul] = commonCalcer.get(); - _registerMap[schema::PrimitiveType_FullConnection] = commonCalcer.get(); - _registerMap[schema::PrimitiveType_Nchw2Nhwc] = linearCalcer.get(); - _registerMap[schema::PrimitiveType_Nhwc2Nchw] = linearCalcer.get(); + _registerMap[schema::PrimitiveType_Concat] = std::make_shared(); + _registerMap[schema::PrimitiveType_Activation] = std::make_shared(); + _registerMap[schema::PrimitiveType_Add] = std::make_shared(); + _registerMap[schema::PrimitiveType_Mul] = commonCalcer; + _registerMap[schema::PrimitiveType_Conv2D] = commonCalcer; + _registerMap[schema::PrimitiveType_DepthwiseConv2D] = commonCalcer; + _registerMap[schema::PrimitiveType_Pooling] = linearCalcer; + _registerMap[schema::PrimitiveType_Resize] = linearCalcer; + _registerMap[schema::PrimitiveType_Reshape] = linearCalcer; + _registerMap[schema::PrimitiveType_Shape] = linearCalcer; + _registerMap[schema::PrimitiveType_SoftMax] = std::make_shared(0, 1); + _registerMap[schema::PrimitiveType_Squeeze] = linearCalcer; + _registerMap[schema::PrimitiveType_RealDiv] = std::make_shared(); + _registerMap[schema::PrimitiveType_Reduce] = commonCalcer; + _registerMap[schema::PrimitiveType_BiasAdd] = commonCalcer; + _registerMap[schema::PrimitiveType_Mean] = linearCalcer; + _registerMap[schema::PrimitiveType_Transpose] = linearCalcer; + _registerMap[schema::PrimitiveType_MatMul] = commonCalcer; + _registerMap[schema::PrimitiveType_FullConnection] = commonCalcer; + _registerMap[schema::PrimitiveType_Nchw2Nhwc] = linearCalcer; + _registerMap[schema::PrimitiveType_Nhwc2Nchw] = linearCalcer; // detection_postprocess op's quant param will not infer only fetch from preNode or postNode // because we will not insert quantTransNode after this node in tflite_graph_8bit model if input data is float. // if quantTransNode is inserted after detection_postprocess node, there will be some errors - _registerMap[schema::PrimitiveType_DetectionPostProcess] = baseCalcer.get(); - baseCalcer.release(); - linearCalcer.release(); - commonCalcer.release(); + _registerMap[schema::PrimitiveType_DetectionPostProcess] = baseCalcer; } } @@ -502,7 +493,7 @@ QuantParamCalcRegister *QuantParamCalcRegister::GetInstance() { return &instance; } -QuantParamCalcer *QuantParamCalcRegister::GetQuantParamCalcer(schema::PrimitiveType opType) { +std::shared_ptr QuantParamCalcRegister::GetQuantParamCalcer(schema::PrimitiveType opType) { auto it = _registerMap.find(opType); if (it != _registerMap.end()) { return it->second; diff --git a/mindspore/lite/tools/converter/quantizer/calc_quant_param.h b/mindspore/lite/tools/converter/quantizer/calc_quant_param.h index 4eae30e4d7..441a116b82 100644 --- a/mindspore/lite/tools/converter/quantizer/calc_quant_param.h +++ b/mindspore/lite/tools/converter/quantizer/calc_quant_param.h @@ -56,12 +56,12 @@ class LinearCalcer : public QuantParamCalcer { class QuantParamCalcRegister { public: virtual ~QuantParamCalcRegister(); - QuantParamCalcer *GetQuantParamCalcer(schema::PrimitiveType opType); + std::shared_ptr GetQuantParamCalcer(schema::PrimitiveType opType); static QuantParamCalcRegister *GetInstance(); private: QuantParamCalcRegister(); - std::unordered_map _registerMap; + std::unordered_map> _registerMap; }; } // namespace lite } // namespace mindspore