| @@ -232,6 +232,25 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim, | |||||
| vecOutputQuantParam->emplace_back(quants); | vecOutputQuantParam->emplace_back(quants); | ||||
| } | } | ||||
| } | } | ||||
| void PrimitiveC::GetAttrDataFromInput(const AnfNodePtr inputNode, std::vector<int> *data) { | |||||
| if (inputNode->isa<ValueNode>()) { | |||||
| auto valNode = inputNode->cast<ValueNodePtr>(); | |||||
| MS_ASSERT(valNode != nullptr); | |||||
| auto val = valNode->value(); | |||||
| MS_ASSERT(val != nullptr); | |||||
| if (val->isa<ValueTuple>()) { | |||||
| auto tuple = val->cast<ValueTuplePtr>(); | |||||
| MS_ASSERT(tuple != nullptr); | |||||
| for (size_t i = 0; i < tuple->size(); i++) { | |||||
| auto elem = tuple->value()[i]->cast<Int32ImmPtr>(); | |||||
| MS_ASSERT(elem != nullptr); | |||||
| data->emplace_back(static_cast<int>(elem->value())); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| schema::PrimitiveT *PrimitiveC::GetPrimitiveT() const { return this->primitive_; } | schema::PrimitiveT *PrimitiveC::GetPrimitiveT() const { return this->primitive_; } | ||||
| void PrimitiveC::ClearPrimitiveT() { this->primitive_ = nullptr; } | void PrimitiveC::ClearPrimitiveT() { this->primitive_ = nullptr; } | ||||
| @@ -113,6 +113,8 @@ class PrimitiveC : public mindspore::Primitive { | |||||
| static PrimitiveC *Create(mindspore::schema::PrimitiveT *primitive); | static PrimitiveC *Create(mindspore::schema::PrimitiveT *primitive); | ||||
| void GetAttrDataFromInput(const AnfNodePtr inputNode, std::vector<int> *data); | |||||
| static std::shared_ptr<PrimitiveC> Create(const Primitive &prim, const std::vector<AnfNodePtr> &inputs, | static std::shared_ptr<PrimitiveC> Create(const Primitive &prim, const std::vector<AnfNodePtr> &inputs, | ||||
| const schema::QuantType &quantType); | const schema::QuantType &quantType); | ||||
| void PopulaterQuantParam(const Primitive &prim, std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam, | void PopulaterQuantParam(const Primitive &prim, std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam, | ||||
| @@ -35,6 +35,42 @@ void Resize::SetPreserveAspectRatio(bool preserve_aspect_ratio) { | |||||
| this->primitive_->value.AsResize()->preserveAspectRatio = preserve_aspect_ratio; | this->primitive_->value.AsResize()->preserveAspectRatio = preserve_aspect_ratio; | ||||
| } | } | ||||
| int Resize::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_Resize; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_Resize) { | |||||
| MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| auto attr = new (std::nothrow) schema::ResizeT(); | |||||
| if (prim.instance_name() == "ResizeNearestNeighbor") { | |||||
| attr->method = schema::ResizeMethod_NEAREST_NEIGHBOR; | |||||
| } else if (prim.instance_name() == "ResizeBilinear") { | |||||
| attr->method = schema::ResizeMethod_BILINEAR; | |||||
| } else { | |||||
| MS_LOG(ERROR) << "wrong resize type"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| std::vector<int> targetSize = GetValue<std::vector<int>>(prim.GetAttr("size")); | |||||
| attr->newHeight = targetSize[0]; | |||||
| attr->newWidth = targetSize[1]; | |||||
| attr->alignCorners = GetValue<bool>(prim.GetAttr("align_corners")); | |||||
| this->primitive_->value.value = attr; | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | #else | ||||
| int Resize::GetFormat() const { return this->primitive_->value_as_Resize()->format(); } | int Resize::GetFormat() const { return this->primitive_->value_as_Resize()->format(); } | ||||
| @@ -37,6 +37,7 @@ class Resize : public PrimitiveC { | |||||
| void SetNewWidth(int64_t new_width); | void SetNewWidth(int64_t new_width); | ||||
| void SetAlignCorners(bool align_corners); | void SetAlignCorners(bool align_corners); | ||||
| void SetPreserveAspectRatio(bool preserve_aspect_ratio); | void SetPreserveAspectRatio(bool preserve_aspect_ratio); | ||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs); | |||||
| #else | #else | ||||
| Resize() = default; | Resize() = default; | ||||
| @@ -49,6 +49,49 @@ void StridedSlice::SetIsScale(const std::vector<int> &is_scale) { | |||||
| this->primitive_->value.AsStridedSlice()->isScale = is_scale; | this->primitive_->value.AsStridedSlice()->isScale = is_scale; | ||||
| } | } | ||||
| int StridedSlice::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_StridedSlice; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_StridedSlice) { | |||||
| MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| auto attr = new (std::nothrow) schema::StridedSliceT(); | |||||
| attr->beginMask = GetValue<int>(prim.GetAttr("begin_mask")); | |||||
| attr->endMask = GetValue<int>(prim.GetAttr("end_mask")); | |||||
| attr->ellipsisMask = GetValue<int>(prim.GetAttr("ellipsis_mask")); | |||||
| attr->newAxisMask = GetValue<int>(prim.GetAttr("new_axis_mask")); | |||||
| attr->shrinkAxisMask = GetValue<int>(prim.GetAttr("shrink_axis_mask")); | |||||
| auto inputNodeFirst = inputs[kAnfPopulaterOne]; | |||||
| std::vector<int> beginVec; | |||||
| GetAttrDataFromInput(inputNodeFirst, &beginVec); | |||||
| attr->begin = beginVec; | |||||
| auto inputNodeSecond = inputs[kAnfPopulaterTwo]; | |||||
| std::vector<int> endVec; | |||||
| GetAttrDataFromInput(inputNodeSecond, &endVec); | |||||
| attr->end = endVec; | |||||
| auto inputNodeThird = inputs[kAnfPopulaterThree]; | |||||
| std::vector<int> strideVec; | |||||
| GetAttrDataFromInput(inputNodeThird, &strideVec); | |||||
| attr->stride = strideVec; | |||||
| this->primitive_->value.value = attr; | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | #else | ||||
| int StridedSlice::GetBeginMask() const { return this->primitive_->value_as_StridedSlice()->beginMask(); } | int StridedSlice::GetBeginMask() const { return this->primitive_->value_as_StridedSlice()->beginMask(); } | ||||
| @@ -41,6 +41,7 @@ class StridedSlice : public PrimitiveC { | |||||
| void SetEnd(const std::vector<int> &end); | void SetEnd(const std::vector<int> &end); | ||||
| void SetStride(const std::vector<int> &stride); | void SetStride(const std::vector<int> &stride); | ||||
| void SetIsScale(const std::vector<int> &is_scale); | void SetIsScale(const std::vector<int> &is_scale); | ||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs); | |||||
| #else | #else | ||||
| StridedSlice() = default; | StridedSlice() = default; | ||||
| @@ -55,7 +55,7 @@ int GatherInt8CPUKernel::DoGather(int task_id) { | |||||
| auto input_ptr = reinterpret_cast<int8_t *>(input_tensor->MutableData()); | auto input_ptr = reinterpret_cast<int8_t *>(input_tensor->MutableData()); | ||||
| auto output_ptr = reinterpret_cast<int8_t *>(out_tensor->MutableData()); | auto output_ptr = reinterpret_cast<int8_t *>(out_tensor->MutableData()); | ||||
| auto indices_ptr = reinterpret_cast<int32_t *>(out_tensor->MutableData()); | |||||
| auto indices_ptr = reinterpret_cast<int32_t *>(indices_tensor->MutableData()); | |||||
| auto in_shape = input_tensor->shape(); | auto in_shape = input_tensor->shape(); | ||||
| int in_rank = in_shape.size(); | int in_rank = in_shape.size(); | ||||
| @@ -90,7 +90,6 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { | |||||
| MS_LOG(ERROR) << "Export to meta graph return nullptr"; | MS_LOG(ERROR) << "Export to meta graph return nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| // transform | // transform | ||||
| transform->SetGraphDef(meta_graph); | transform->SetGraphDef(meta_graph); | ||||
| transform->CreateQuantizer(flag); | transform->CreateQuantizer(flag); | ||||
| @@ -129,7 +129,7 @@ STATUS AwareQuantizer::GenerateQuantParam() { | |||||
| GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) { | GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) { | ||||
| MS_ASSERT(false); | MS_ASSERT(false); | ||||
| } | } | ||||
| auto *quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node)); | |||||
| auto quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node)); | |||||
| if (quantParamCalcer == nullptr) { | if (quantParamCalcer == nullptr) { | ||||
| MS_LOG(ERROR) << "Can not find QuantParamCalcer for " << node->name.c_str() | MS_LOG(ERROR) << "Can not find QuantParamCalcer for " << node->name.c_str() | ||||
| << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; | << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; | ||||
| @@ -439,61 +439,52 @@ class CalcActivation : public QuantParamCalcer { | |||||
| } | } | ||||
| } | } | ||||
| }; | }; | ||||
| QuantParamCalcRegister::~QuantParamCalcRegister() { | |||||
| for (auto ite : _registerMap) { | |||||
| if (ite.second != nullptr) { | |||||
| delete ite.second; | |||||
| ite.second = nullptr; | |||||
| } | |||||
| } | |||||
| } | |||||
| QuantParamCalcRegister::~QuantParamCalcRegister() {} | |||||
| QuantParamCalcRegister::QuantParamCalcRegister() { | QuantParamCalcRegister::QuantParamCalcRegister() { | ||||
| bool hasError = false; | bool hasError = false; | ||||
| std::unique_ptr<QuantParamCalcer> baseCalcer(new (std::nothrow) QuantParamCalcer()); | |||||
| std::shared_ptr<QuantParamCalcer> baseCalcer = std::make_shared<QuantParamCalcer>(); | |||||
| if (baseCalcer == nullptr) { | if (baseCalcer == nullptr) { | ||||
| MS_LOG(ERROR) << "new QuantParamCalcer failed"; | MS_LOG(ERROR) << "new QuantParamCalcer failed"; | ||||
| hasError = true; | hasError = true; | ||||
| } | } | ||||
| std::unique_ptr<CommonCalcer> commonCalcer(new (std::nothrow) CommonCalcer()); | |||||
| std::shared_ptr<QuantParamCalcer> commonCalcer = std::make_shared<CommonCalcer>(); | |||||
| if (commonCalcer == nullptr) { | if (commonCalcer == nullptr) { | ||||
| MS_LOG(ERROR) << "new commonCalcer failed"; | MS_LOG(ERROR) << "new commonCalcer failed"; | ||||
| hasError = true; | hasError = true; | ||||
| } | } | ||||
| std::unique_ptr<LinearCalcer> linearCalcer(new (std::nothrow) LinearCalcer()); | |||||
| std::shared_ptr<QuantParamCalcer> linearCalcer = std::make_shared<LinearCalcer>(); | |||||
| if (linearCalcer == nullptr) { | if (linearCalcer == nullptr) { | ||||
| MS_LOG(ERROR) << "new linearCalcer failed"; | MS_LOG(ERROR) << "new linearCalcer failed"; | ||||
| hasError = true; | hasError = true; | ||||
| } | } | ||||
| if (!hasError) { | if (!hasError) { | ||||
| _registerMap[schema::PrimitiveType_Concat] = new CalcConcat(); | |||||
| _registerMap[schema::PrimitiveType_Activation] = new CalcActivation(); | |||||
| _registerMap[schema::PrimitiveType_Add] = new CalcAdd(); | |||||
| _registerMap[schema::PrimitiveType_Mul] = commonCalcer.get(); | |||||
| _registerMap[schema::PrimitiveType_Conv2D] = commonCalcer.get(); | |||||
| _registerMap[schema::PrimitiveType_DepthwiseConv2D] = commonCalcer.get(); | |||||
| _registerMap[schema::PrimitiveType_Pooling] = linearCalcer.get(); | |||||
| _registerMap[schema::PrimitiveType_Resize] = linearCalcer.get(); | |||||
| _registerMap[schema::PrimitiveType_Reshape] = linearCalcer.get(); | |||||
| _registerMap[schema::PrimitiveType_Shape] = linearCalcer.get(); | |||||
| _registerMap[schema::PrimitiveType_SoftMax] = new CalcToSet(0, 1); | |||||
| _registerMap[schema::PrimitiveType_Squeeze] = linearCalcer.get(); | |||||
| _registerMap[schema::PrimitiveType_RealDiv] = new CalcRealDiv(); | |||||
| _registerMap[schema::PrimitiveType_Reduce] = commonCalcer.get(); | |||||
| _registerMap[schema::PrimitiveType_BiasAdd] = commonCalcer.get(); | |||||
| _registerMap[schema::PrimitiveType_Mean] = linearCalcer.get(); | |||||
| _registerMap[schema::PrimitiveType_Transpose] = linearCalcer.get(); | |||||
| _registerMap[schema::PrimitiveType_MatMul] = commonCalcer.get(); | |||||
| _registerMap[schema::PrimitiveType_FullConnection] = commonCalcer.get(); | |||||
| _registerMap[schema::PrimitiveType_Nchw2Nhwc] = linearCalcer.get(); | |||||
| _registerMap[schema::PrimitiveType_Nhwc2Nchw] = linearCalcer.get(); | |||||
| _registerMap[schema::PrimitiveType_Concat] = std::make_shared<CalcConcat>(); | |||||
| _registerMap[schema::PrimitiveType_Activation] = std::make_shared<CalcActivation>(); | |||||
| _registerMap[schema::PrimitiveType_Add] = std::make_shared<CalcAdd>(); | |||||
| _registerMap[schema::PrimitiveType_Mul] = commonCalcer; | |||||
| _registerMap[schema::PrimitiveType_Conv2D] = commonCalcer; | |||||
| _registerMap[schema::PrimitiveType_DepthwiseConv2D] = commonCalcer; | |||||
| _registerMap[schema::PrimitiveType_Pooling] = linearCalcer; | |||||
| _registerMap[schema::PrimitiveType_Resize] = linearCalcer; | |||||
| _registerMap[schema::PrimitiveType_Reshape] = linearCalcer; | |||||
| _registerMap[schema::PrimitiveType_Shape] = linearCalcer; | |||||
| _registerMap[schema::PrimitiveType_SoftMax] = std::make_shared<CalcToSet>(0, 1); | |||||
| _registerMap[schema::PrimitiveType_Squeeze] = linearCalcer; | |||||
| _registerMap[schema::PrimitiveType_RealDiv] = std::make_shared<CalcRealDiv>(); | |||||
| _registerMap[schema::PrimitiveType_Reduce] = commonCalcer; | |||||
| _registerMap[schema::PrimitiveType_BiasAdd] = commonCalcer; | |||||
| _registerMap[schema::PrimitiveType_Mean] = linearCalcer; | |||||
| _registerMap[schema::PrimitiveType_Transpose] = linearCalcer; | |||||
| _registerMap[schema::PrimitiveType_MatMul] = commonCalcer; | |||||
| _registerMap[schema::PrimitiveType_FullConnection] = commonCalcer; | |||||
| _registerMap[schema::PrimitiveType_Nchw2Nhwc] = linearCalcer; | |||||
| _registerMap[schema::PrimitiveType_Nhwc2Nchw] = linearCalcer; | |||||
| // detection_postprocess op's quant param will not infer only fetch from preNode or postNode | // detection_postprocess op's quant param will not infer only fetch from preNode or postNode | ||||
| // because we will not insert quantTransNode after this node in tflite_graph_8bit model if input data is float. | // because we will not insert quantTransNode after this node in tflite_graph_8bit model if input data is float. | ||||
| // if quantTransNode is inserted after detection_postprocess node, there will be some errors | // if quantTransNode is inserted after detection_postprocess node, there will be some errors | ||||
| _registerMap[schema::PrimitiveType_DetectionPostProcess] = baseCalcer.get(); | |||||
| baseCalcer.release(); | |||||
| linearCalcer.release(); | |||||
| commonCalcer.release(); | |||||
| _registerMap[schema::PrimitiveType_DetectionPostProcess] = baseCalcer; | |||||
| } | } | ||||
| } | } | ||||
| @@ -502,7 +493,7 @@ QuantParamCalcRegister *QuantParamCalcRegister::GetInstance() { | |||||
| return &instance; | return &instance; | ||||
| } | } | ||||
| QuantParamCalcer *QuantParamCalcRegister::GetQuantParamCalcer(schema::PrimitiveType opType) { | |||||
| std::shared_ptr<QuantParamCalcer> QuantParamCalcRegister::GetQuantParamCalcer(schema::PrimitiveType opType) { | |||||
| auto it = _registerMap.find(opType); | auto it = _registerMap.find(opType); | ||||
| if (it != _registerMap.end()) { | if (it != _registerMap.end()) { | ||||
| return it->second; | return it->second; | ||||
| @@ -56,12 +56,12 @@ class LinearCalcer : public QuantParamCalcer { | |||||
| class QuantParamCalcRegister { | class QuantParamCalcRegister { | ||||
| public: | public: | ||||
| virtual ~QuantParamCalcRegister(); | virtual ~QuantParamCalcRegister(); | ||||
| QuantParamCalcer *GetQuantParamCalcer(schema::PrimitiveType opType); | |||||
| std::shared_ptr<QuantParamCalcer> GetQuantParamCalcer(schema::PrimitiveType opType); | |||||
| static QuantParamCalcRegister *GetInstance(); | static QuantParamCalcRegister *GetInstance(); | ||||
| private: | private: | ||||
| QuantParamCalcRegister(); | QuantParamCalcRegister(); | ||||
| std::unordered_map<schema::PrimitiveType, QuantParamCalcer *> _registerMap; | |||||
| std::unordered_map<schema::PrimitiveType, std::shared_ptr<QuantParamCalcer>> _registerMap; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||