| @@ -46,6 +46,12 @@ int Add::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| } | } | ||||
| if (GetQuantType() == schema::QuantType_AwareTraining) { | |||||
| std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam; | |||||
| std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam; | |||||
| PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs); | |||||
| SetOutputQuantParam(vecOutputQuantParam); | |||||
| } | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -260,7 +260,7 @@ int Conv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp | |||||
| if (GetQuantType() == schema::QuantType_AwareTraining) { | if (GetQuantType() == schema::QuantType_AwareTraining) { | ||||
| std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam; | std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam; | ||||
| std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam; | std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam; | ||||
| PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam); | |||||
| PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs); | |||||
| SetInputQuantParam(vecInputQuantParam); | SetInputQuantParam(vecInputQuantParam); | ||||
| SetOutputQuantParam(vecOutputQuantParam); | SetOutputQuantParam(vecOutputQuantParam); | ||||
| } | } | ||||
| @@ -130,7 +130,7 @@ int DeConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &i | |||||
| if (GetQuantType() == schema::QuantType_AwareTraining) { | if (GetQuantType() == schema::QuantType_AwareTraining) { | ||||
| std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam; | std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam; | ||||
| std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam; | std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam; | ||||
| PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam); | |||||
| PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs); | |||||
| SetInputQuantParam(vecInputQuantParam); | SetInputQuantParam(vecInputQuantParam); | ||||
| SetOutputQuantParam(vecOutputQuantParam); | SetOutputQuantParam(vecOutputQuantParam); | ||||
| } | } | ||||
| @@ -140,7 +140,7 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode | |||||
| if (GetQuantType() == schema::QuantType_AwareTraining) { | if (GetQuantType() == schema::QuantType_AwareTraining) { | ||||
| std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam; | std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam; | ||||
| std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam; | std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam; | ||||
| PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam); | |||||
| PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs); | |||||
| SetInputQuantParam(vecInputQuantParam); | SetInputQuantParam(vecInputQuantParam); | ||||
| SetOutputQuantParam(vecOutputQuantParam); | SetOutputQuantParam(vecOutputQuantParam); | ||||
| } | } | ||||
| @@ -60,7 +60,7 @@ int MatMul::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp | |||||
| if (GetQuantType() == schema::QuantType_AwareTraining) { | if (GetQuantType() == schema::QuantType_AwareTraining) { | ||||
| std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam; | std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam; | ||||
| std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam; | std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam; | ||||
| PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam); | |||||
| PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs); | |||||
| SetInputQuantParam(vecInputQuantParam); | SetInputQuantParam(vecInputQuantParam); | ||||
| SetOutputQuantParam(vecOutputQuantParam); | SetOutputQuantParam(vecOutputQuantParam); | ||||
| } | } | ||||
| @@ -158,7 +158,8 @@ void PrimitiveC::CalQuantParam(const double &mean, const double &stdDev, float * | |||||
| void PrimitiveC::PopulaterQuantParam(const Primitive &prim, | void PrimitiveC::PopulaterQuantParam(const Primitive &prim, | ||||
| std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam, | std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam, | ||||
| std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam) { | |||||
| std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam, | |||||
| const std::vector<AnfNodePtr> &inputs) { | |||||
| auto narrow_range = prim.GetAttr("narrow_range"); | auto narrow_range = prim.GetAttr("narrow_range"); | ||||
| bool narrowRangeQuantParam = GetValue<bool>(narrow_range); | bool narrowRangeQuantParam = GetValue<bool>(narrow_range); | ||||
| auto num_bits = prim.GetAttr("num_bits"); | auto num_bits = prim.GetAttr("num_bits"); | ||||
| @@ -179,12 +180,14 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim, | |||||
| } else { | } else { | ||||
| auto inputMin = prim.GetAttr("input_minq"); | auto inputMin = prim.GetAttr("input_minq"); | ||||
| auto inputMax = prim.GetAttr("input_maxq"); | auto inputMax = prim.GetAttr("input_maxq"); | ||||
| auto inputMinPtr = inputMin->cast<TensorPtr>(); | |||||
| auto inputMaxPtr = inputMax->cast<TensorPtr>(); | |||||
| float *minBuf = static_cast<float *>(inputMinPtr->data_c()); | |||||
| float *maxBuf = static_cast<float *>(inputMaxPtr->data_c()); | |||||
| quantParam.min = *minBuf; | |||||
| quantParam.max = *maxBuf; | |||||
| if (inputMin != nullptr && inputMax != nullptr) { | |||||
| auto inputMinPtr = inputMin->cast<TensorPtr>(); | |||||
| auto inputMaxPtr = inputMax->cast<TensorPtr>(); | |||||
| float *minBuf = static_cast<float *>(inputMinPtr->data_c()); | |||||
| float *maxBuf = static_cast<float *>(inputMaxPtr->data_c()); | |||||
| quantParam.min = *minBuf; | |||||
| quantParam.max = *maxBuf; | |||||
| } | |||||
| } | } | ||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, | quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, | ||||
| numbitsRangeQuantParam); | numbitsRangeQuantParam); | ||||
| @@ -212,13 +215,15 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim, | |||||
| vecInputQuantParam->emplace_back(quants); | vecInputQuantParam->emplace_back(quants); | ||||
| } | } | ||||
| quants.clear(); | |||||
| quantParam.min = 0.0; | |||||
| quantParam.max = 0.0; | |||||
| quantParam.zeroPoint = 0; | |||||
| quantParam.scale = vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(0).scale; | |||||
| quants.emplace_back(quantParam); | |||||
| vecInputQuantParam->emplace_back(quants); | |||||
| if (vecInputQuantParam->size() == kDoubleNum) { | |||||
| quants.clear(); | |||||
| quantParam.min = 0.0; | |||||
| quantParam.max = 0.0; | |||||
| quantParam.zeroPoint = 0; | |||||
| quantParam.scale = vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(0).scale; | |||||
| quants.emplace_back(quantParam); | |||||
| vecInputQuantParam->emplace_back(quants); | |||||
| } | |||||
| quants.clear(); | quants.clear(); | ||||
| auto outputMin = prim.GetAttr("output_minq"); | auto outputMin = prim.GetAttr("output_minq"); | ||||
| @@ -39,8 +39,8 @@ constexpr uint32_t kDoubleNum = 2; | |||||
| constexpr uint32_t kMultiNum = 3; | constexpr uint32_t kMultiNum = 3; | ||||
| constexpr uint32_t kDimension_4d = 4; | constexpr uint32_t kDimension_4d = 4; | ||||
| const std::set<int> kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt8, kNumberTypeInt32, | |||||
| kNumberTypeFloat32, kNumberTypeFloat16}; | |||||
| const std::set<int> kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat32, | |||||
| kNumberTypeFloat16}; | |||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| using TensorPtr = std::shared_ptr<mindspore::tensor::Tensor>; | using TensorPtr = std::shared_ptr<mindspore::tensor::Tensor>; | ||||
| @@ -119,7 +119,8 @@ class PrimitiveC : public mindspore::Primitive { | |||||
| static std::shared_ptr<PrimitiveC> Create(const Primitive &prim, const std::vector<AnfNodePtr> &inputs, | static std::shared_ptr<PrimitiveC> Create(const Primitive &prim, const std::vector<AnfNodePtr> &inputs, | ||||
| const schema::QuantType &quantType); | const schema::QuantType &quantType); | ||||
| void PopulaterQuantParam(const Primitive &prim, std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam, | void PopulaterQuantParam(const Primitive &prim, std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam, | ||||
| std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam); | |||||
| std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam, | |||||
| const std::vector<AnfNodePtr> &inputs); | |||||
| void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax); | void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax); | ||||
| protected: | protected: | ||||
| @@ -98,29 +98,28 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me | |||||
| // activation | // activation | ||||
| auto input_quant_params = primitive->GetInputQuantParams(); | auto input_quant_params = primitive->GetInputQuantParams(); | ||||
| auto node_type = (schema::PrimitiveType)primitive->Type(); | auto node_type = (schema::PrimitiveType)primitive->Type(); | ||||
| if (input_quant_params.empty()) { | |||||
| MS_LOG(DEBUG) << "node: " << dst_node->name << " input quant params is empty"; | |||||
| return RET_OK; | |||||
| } | |||||
| for (size_t i = 0; i < input_quant_params.size(); i++) { | |||||
| if (i >= dst_node->inputIndex.size()) { | |||||
| MS_LOG(ERROR) << "node: " << dst_node->name << " input has " << input_quant_params.size() | |||||
| << " quant_params; but only " << dst_node->inputIndex.size() << " input"; | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| auto activate_index = dst_node->inputIndex[i]; | |||||
| auto tensor_input = meta_graph->allTensors[activate_index].get(); | |||||
| if (tensor_input->quantParams.empty()) { | |||||
| for (auto input_quant_param : input_quant_params[i]) { | |||||
| std::unique_ptr<schema::QuantParamT> input_quant_param_ptr = | |||||
| std::make_unique<schema::QuantParamT>(input_quant_param); | |||||
| MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale | |||||
| << " zp: " << input_quant_param_ptr->zeroPoint; | |||||
| tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr)); | |||||
| if (!input_quant_params.empty()) { | |||||
| for (size_t i = 0; i < input_quant_params.size(); i++) { | |||||
| if (i >= dst_node->inputIndex.size()) { | |||||
| MS_LOG(ERROR) << "node: " << dst_node->name << " input has " << input_quant_params.size() | |||||
| << " quant_params; but only " << dst_node->inputIndex.size() << " input"; | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| auto activate_index = dst_node->inputIndex[i]; | |||||
| auto tensor_input = meta_graph->allTensors[activate_index].get(); | |||||
| if (tensor_input->quantParams.empty()) { | |||||
| for (auto input_quant_param : input_quant_params[i]) { | |||||
| std::unique_ptr<schema::QuantParamT> input_quant_param_ptr = | |||||
| std::make_unique<schema::QuantParamT>(input_quant_param); | |||||
| MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale | |||||
| << " zp: " << input_quant_param_ptr->zeroPoint; | |||||
| tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr)); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } else { | |||||
| MS_LOG(DEBUG) << "node: " << dst_node->name << " input quant params is empty"; | |||||
| } | } | ||||
| // output | // output | ||||
| auto output_index = dst_node->outputIndex[0]; | auto output_index = dst_node->outputIndex[0]; | ||||
| auto tensor_output = meta_graph->allTensors[output_index].get(); | auto tensor_output = meta_graph->allTensors[output_index].get(); | ||||
| @@ -171,7 +170,7 @@ void AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> & | |||||
| } | } | ||||
| int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | ||||
| schema::CNodeT *return_node) { | |||||
| schema::CNodeT *return_node) { | |||||
| MS_ASSERT(nullptr != meta_graph); | MS_ASSERT(nullptr != meta_graph); | ||||
| MS_ASSERT(nullptr != return_node); | MS_ASSERT(nullptr != return_node); | ||||
| for (size_t i = 1; i < cnode->inputs().size(); i++) { | for (size_t i = 1; i < cnode->inputs().size(); i++) { | ||||
| @@ -210,9 +209,9 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee | |||||
| if (primitive_c->Type() == schema::PrimitiveType_TupleGetItem || | if (primitive_c->Type() == schema::PrimitiveType_TupleGetItem || | ||||
| primitive_c->Type() == schema::PrimitiveType_MakeTuple | primitive_c->Type() == schema::PrimitiveType_MakeTuple | ||||
| #ifdef SUPPORT_TRAIN | #ifdef SUPPORT_TRAIN | ||||
| || primitive_c->Type() == schema::PrimitiveType_Depend | |||||
| || primitive_c->Type() == schema::PrimitiveType_Depend | |||||
| #endif | #endif | ||||
| ) { | |||||
| ) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| RemoveIfMakeTuple(cnode); | RemoveIfMakeTuple(cnode); | ||||
| @@ -403,8 +402,7 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode, | |||||
| if (value_track->isa<Int32Imm>()) { | if (value_track->isa<Int32Imm>()) { | ||||
| shape.push_back((GetValue<int>(value_track))); | shape.push_back((GetValue<int>(value_track))); | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Value type is ValueSequence is not integer, it is " | |||||
| << value_track->ToString() << "."; | |||||
| MS_LOG(ERROR) << "Value type is ValueSequence is not integer, it is " << value_track->ToString() << "."; | |||||
| } | } | ||||
| } | } | ||||
| if (shape.size()) { | if (shape.size()) { | ||||
| @@ -417,10 +415,10 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode, | |||||
| node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); | node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); | ||||
| output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); | output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); | ||||
| meta_graphT->allTensors.emplace_back(std::move(paramTensor)); | meta_graphT->allTensors.emplace_back(std::move(paramTensor)); | ||||
| } | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Value type is ValueSequence not supported - " << valueAbstract->type_name() << "."; | |||||
| } | } | ||||
| } else { | |||||
| MS_LOG(ERROR) << "Value type is ValueSequence not supported - " << valueAbstract->type_name() << "."; | |||||
| } | |||||
| #endif | #endif | ||||
| } else if (value->isa<Number>()) { | } else if (value->isa<Number>()) { | ||||
| MS_LOG(INFO) << "Value is a number."; | MS_LOG(INFO) << "Value is a number."; | ||||
| @@ -54,8 +54,7 @@ static const std::vector<schema::PrimitiveType> nhwcOpDualInputList = { | |||||
| static const std::vector<schema::PrimitiveType> nhwcOpAllInputList = { | static const std::vector<schema::PrimitiveType> nhwcOpAllInputList = { | ||||
| #ifdef SUPPORT_TRAIN | #ifdef SUPPORT_TRAIN | ||||
| schema::PrimitiveType_PoolingGrad, | |||||
| schema::PrimitiveType_ActivationGrad | |||||
| schema::PrimitiveType_PoolingGrad, schema::PrimitiveType_ActivationGrad | |||||
| #endif | #endif | ||||
| }; | }; | ||||
| @@ -66,20 +65,21 @@ static const std::vector<schema::PrimitiveType> fp32FullOpList = { | |||||
| static const std::vector<schema::PrimitiveType> int8NeedNhwcOpList = {}; | static const std::vector<schema::PrimitiveType> int8NeedNhwcOpList = {}; | ||||
| static const std::vector<schema::PrimitiveType> int8OpList = { | static const std::vector<schema::PrimitiveType> int8OpList = { | ||||
| schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, | |||||
| schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D, | |||||
| schema::PrimitiveType_Add, schema::PrimitiveType_Pooling, | |||||
| schema::PrimitiveType_Concat, schema::PrimitiveType_SoftMax, | |||||
| schema::PrimitiveType_Reshape, schema::PrimitiveType_Activation, | |||||
| schema::PrimitiveType_Resize, schema::PrimitiveType_FullConnection, | |||||
| schema::PrimitiveType_ArgMax, schema::PrimitiveType_ArgMin, | |||||
| schema::PrimitiveType_BatchNorm, schema::PrimitiveType_FusedBatchNorm, | |||||
| schema::PrimitiveType_BiasAdd, schema::PrimitiveType_Div, | |||||
| schema::PrimitiveType_Mul, schema::PrimitiveType_Slice, | |||||
| schema::PrimitiveType_SoftMax, schema::PrimitiveType_Split, | |||||
| schema::PrimitiveType_Squeeze, schema::PrimitiveType_Sub, | |||||
| schema::PrimitiveType_TopK, schema::PrimitiveType_Unsqueeze, | |||||
| schema::PrimitiveType_MatMul, schema::PrimitiveType_Pad}; | |||||
| schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, | |||||
| schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D, | |||||
| schema::PrimitiveType_Add, schema::PrimitiveType_Pooling, | |||||
| schema::PrimitiveType_Concat, schema::PrimitiveType_SoftMax, | |||||
| schema::PrimitiveType_Reshape, schema::PrimitiveType_Activation, | |||||
| schema::PrimitiveType_Resize, schema::PrimitiveType_FullConnection, | |||||
| schema::PrimitiveType_ArgMax, schema::PrimitiveType_ArgMin, | |||||
| schema::PrimitiveType_BatchNorm, schema::PrimitiveType_FusedBatchNorm, | |||||
| schema::PrimitiveType_BiasAdd, schema::PrimitiveType_Div, | |||||
| schema::PrimitiveType_Mul, schema::PrimitiveType_Slice, | |||||
| schema::PrimitiveType_SoftMax, schema::PrimitiveType_Split, | |||||
| schema::PrimitiveType_Squeeze, schema::PrimitiveType_Sub, | |||||
| schema::PrimitiveType_StridedSlice, schema::PrimitiveType_TopK, | |||||
| schema::PrimitiveType_Unsqueeze, schema::PrimitiveType_MatMul, | |||||
| schema::PrimitiveType_Pad}; | |||||
| static const std::vector<schema::PrimitiveType> needInsertOpList = { | static const std::vector<schema::PrimitiveType> needInsertOpList = { | ||||
| schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat, | schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat, | ||||
| @@ -16,6 +16,7 @@ | |||||
| #include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h" | #include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h" | ||||
| #include <string> | #include <string> | ||||
| #include <set> | |||||
| #include "tools/common/converter_op_utils.h" | #include "tools/common/converter_op_utils.h" | ||||
| #include "tools/common/node_util.h" | #include "tools/common/node_util.h" | ||||
| #include "src/common/common.h" | #include "src/common/common.h" | ||||
| @@ -26,6 +27,9 @@ namespace lite { | |||||
| #define kMinInputNum 1 | #define kMinInputNum 1 | ||||
| #define kOutputNum 1 | #define kOutputNum 1 | ||||
| static const std::set<schema::PrimitiveType> NoNeedDtypeTransList = { | |||||
| PrimitiveType_QuantDTypeCast, PrimitiveType_Nchw2Nhwc, PrimitiveType_Nhwc2Nchw}; | |||||
| STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) { | STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) { | ||||
| MS_ASSERT(graph != nullptr); | MS_ASSERT(graph != nullptr); | ||||
| @@ -134,7 +138,8 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { | |||||
| if (IsContain(GetInt8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) { | if (IsContain(GetInt8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (GetCNodeTType(**iter) == PrimitiveType_QuantDTypeCast) { | |||||
| auto iterType = GetCNodeTType(**iter); | |||||
| if (NoNeedDtypeTransList.find(iterType) != NoNeedDtypeTransList.end()) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| bool needInsertPost = true; | bool needInsertPost = true; | ||||
| @@ -167,7 +167,11 @@ int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { | |||||
| auto &outTensor = graph->allTensors.at(node.outputIndex.at(i)); | auto &outTensor = graph->allTensors.at(node.outputIndex.at(i)); | ||||
| MS_ASSERT(outTensor != nullptr); | MS_ASSERT(outTensor != nullptr); | ||||
| auto outQuantParam = GetTensorQuantParam(outTensor); | auto outQuantParam = GetTensorQuantParam(outTensor); | ||||
| if (outQuantParam == nullptr || outQuantParam->inited) { | |||||
| if (outQuantParam == nullptr) { | |||||
| outTensor->quantParams.emplace_back(std::move(inQuantParam)); | |||||
| continue; | |||||
| } | |||||
| if (outQuantParam->inited) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| outTensor->quantParams.front() = std::move(inQuantParam); | outTensor->quantParams.front() = std::move(inQuantParam); | ||||
| @@ -232,7 +236,7 @@ class CalcConcat : public QuantParamCalcer { | |||||
| MS_LOG(WARNING) << "in aware quantization run CalQuantizationParams failed!"; | MS_LOG(WARNING) << "in aware quantization run CalQuantizationParams failed!"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| outTensor->quantParams.front() = std::move(outQuantParam); | |||||
| outTensor->quantParams.emplace_back(std::move(outQuantParam)); | |||||
| outputParamDone++; | outputParamDone++; | ||||
| } | } | ||||
| @@ -417,7 +421,7 @@ class CalcToSet : public QuantParamCalcer { | |||||
| MS_ASSERT(graph->allTensors.size() > node.outputIndex.front()); | MS_ASSERT(graph->allTensors.size() > node.outputIndex.front()); | ||||
| auto &outTensor = graph->allTensors.at(node.outputIndex.front()); | auto &outTensor = graph->allTensors.at(node.outputIndex.front()); | ||||
| MS_ASSERT(outTensor != nullptr); | MS_ASSERT(outTensor != nullptr); | ||||
| outTensor->quantParams.front() = std::move(quantParam); | |||||
| outTensor->quantParams.emplace_back(std::move(quantParam)); | |||||
| outputParamDone++; | outputParamDone++; | ||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| @@ -475,6 +479,7 @@ QuantParamCalcRegister::QuantParamCalcRegister() { | |||||
| _registerMap[schema::PrimitiveType_Pooling] = linearCalcer; | _registerMap[schema::PrimitiveType_Pooling] = linearCalcer; | ||||
| _registerMap[schema::PrimitiveType_Resize] = linearCalcer; | _registerMap[schema::PrimitiveType_Resize] = linearCalcer; | ||||
| _registerMap[schema::PrimitiveType_Reshape] = linearCalcer; | _registerMap[schema::PrimitiveType_Reshape] = linearCalcer; | ||||
| _registerMap[schema::PrimitiveType_StridedSlice] = linearCalcer; | |||||
| _registerMap[schema::PrimitiveType_Shape] = linearCalcer; | _registerMap[schema::PrimitiveType_Shape] = linearCalcer; | ||||
| _registerMap[schema::PrimitiveType_SoftMax] = std::make_shared<CalcToSet>(0, 1); | _registerMap[schema::PrimitiveType_SoftMax] = std::make_shared<CalcToSet>(0, 1); | ||||
| _registerMap[schema::PrimitiveType_Squeeze] = linearCalcer; | _registerMap[schema::PrimitiveType_Squeeze] = linearCalcer; | ||||