Merge pull request !7823 from cjh9368/add_node_dtype_cast (tags/v1.1.0)
| @@ -67,23 +67,73 @@ static const std::vector<schema::PrimitiveType> fp32FullOpList = { | |||
// File-local list of primitive types; it is initialized empty here.
| static const std::vector<schema::PrimitiveType> int8NeedNhwcOpList = {}; | |||
// File-local list of primitive types named as int8 ops.
// NOTE(review): int8OpList is defined a second time directly after this definition with a longer
// list; presumably this listing shows both the pre-change and post-change text of a diff — confirm
// which definition is the live one.
// NOTE(review): schema::PrimitiveType_SoftMax appears twice in this list.
| static const std::vector<schema::PrimitiveType> int8OpList = { | |||
| schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, | |||
| schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D, | |||
| schema::PrimitiveType_Add, schema::PrimitiveType_Pooling, | |||
| schema::PrimitiveType_Concat, schema::PrimitiveType_SoftMax, | |||
| schema::PrimitiveType_Reshape, schema::PrimitiveType_Activation, | |||
| schema::PrimitiveType_Resize, schema::PrimitiveType_FullConnection, | |||
| schema::PrimitiveType_ArgMax, schema::PrimitiveType_ArgMin, | |||
| schema::PrimitiveType_BatchNorm, schema::PrimitiveType_FusedBatchNorm, | |||
| schema::PrimitiveType_BiasAdd, schema::PrimitiveType_Div, | |||
| schema::PrimitiveType_Mul, schema::PrimitiveType_Slice, | |||
| schema::PrimitiveType_SoftMax, schema::PrimitiveType_Split, | |||
| schema::PrimitiveType_Squeeze, schema::PrimitiveType_Sub, | |||
| schema::PrimitiveType_StridedSlice, schema::PrimitiveType_TopK, | |||
| schema::PrimitiveType_Unsqueeze, schema::PrimitiveType_MatMul, | |||
| schema::PrimitiveType_Pad, schema::PrimitiveType_DeConv2D, | |||
| schema::PrimitiveType_Scale}; | |||
// File-local list of primitive types named as int8 ops, one entry per row.
// Elsewhere in this listing, membership is tested via IsContain(GetInt8OpList(), ...);
// presumably GetInt8OpList() exposes this vector — confirm, its definition is not visible here.
// NOTE(review): schema::PrimitiveType_SoftMax is listed twice (after Concat and again after Slice).
| static const std::vector<schema::PrimitiveType> int8OpList = {schema::PrimitiveType_Nchw2Nhwc, | |||
| schema::PrimitiveType_Nhwc2Nchw, | |||
| schema::PrimitiveType_Conv2D, | |||
| schema::PrimitiveType_DepthwiseConv2D, | |||
| schema::PrimitiveType_Add, | |||
| schema::PrimitiveType_Pooling, | |||
| schema::PrimitiveType_Concat, | |||
| schema::PrimitiveType_SoftMax, | |||
| schema::PrimitiveType_Reshape, | |||
| schema::PrimitiveType_Activation, | |||
| schema::PrimitiveType_Resize, | |||
| schema::PrimitiveType_FullConnection, | |||
| schema::PrimitiveType_ArgMax, | |||
| schema::PrimitiveType_ArgMin, | |||
| schema::PrimitiveType_BatchNorm, | |||
| schema::PrimitiveType_FusedBatchNorm, | |||
| schema::PrimitiveType_BiasAdd, | |||
| schema::PrimitiveType_Div, | |||
| schema::PrimitiveType_Mul, | |||
| schema::PrimitiveType_Slice, | |||
| schema::PrimitiveType_SoftMax, | |||
| schema::PrimitiveType_Split, | |||
| schema::PrimitiveType_Squeeze, | |||
| schema::PrimitiveType_Sub, | |||
| schema::PrimitiveType_StridedSlice, | |||
| schema::PrimitiveType_TopK, | |||
| schema::PrimitiveType_Unsqueeze, | |||
| schema::PrimitiveType_MatMul, | |||
| schema::PrimitiveType_Pad, | |||
| schema::PrimitiveType_DeConv2D, | |||
| schema::PrimitiveType_Scale, | |||
| schema::PrimitiveType_Cast, | |||
| schema::PrimitiveType_Shape, | |||
| schema::PrimitiveType_ExpandDims, | |||
| schema::PrimitiveType_BatchToSpace, | |||
| schema::PrimitiveType_BatchToSpaceND, | |||
| schema::PrimitiveType_Reduce, | |||
| schema::PrimitiveType_Mean, | |||
| schema::PrimitiveType_Round, | |||
| schema::PrimitiveType_Floor, | |||
| schema::PrimitiveType_Ceil, | |||
| schema::PrimitiveType_Abs, | |||
| schema::PrimitiveType_Sin, | |||
| schema::PrimitiveType_Cos, | |||
| schema::PrimitiveType_Log, | |||
| schema::PrimitiveType_Sqrt, | |||
| schema::PrimitiveType_Rsqrt, | |||
| schema::PrimitiveType_Square, | |||
| schema::PrimitiveType_LogicalNot, | |||
| schema::PrimitiveType_SpaceToBatch, | |||
| schema::PrimitiveType_SpaceToBatchND, | |||
| schema::PrimitiveType_DepthToSpace, | |||
| schema::PrimitiveType_Power, | |||
| schema::PrimitiveType_GatherNd, | |||
| schema::PrimitiveType_LeakyReLU, | |||
| schema::PrimitiveType_Gather, | |||
| schema::PrimitiveType_Equal, | |||
| schema::PrimitiveType_NotEqual, | |||
| schema::PrimitiveType_LessEqual, | |||
| schema::PrimitiveType_Greater, | |||
| schema::PrimitiveType_GreaterEqual, | |||
| schema::PrimitiveType_Eltwise, | |||
| schema::PrimitiveType_DeDepthwiseConv2D, | |||
| schema::PrimitiveType_DetectionPostProcess, | |||
| schema::PrimitiveType_Crop, | |||
| schema::PrimitiveType_PriorBox, | |||
| schema::PrimitiveType_QuantDTypeCast}; | |||
| static const std::vector<schema::PrimitiveType> needInsertOpList = { | |||
| schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat, | |||
| @@ -40,6 +40,12 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) { | |||
| MS_LOG(ERROR) << "DoModelOutputDTypeTrans error: " << status; | |||
| return status; | |||
| } | |||
| status = DoNodeInoutDTypeTrans(graph); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DoNodeInoutDTypeTrans error: " << status; | |||
| return status; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -126,6 +132,51 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { | |||
| return RET_OK; | |||
| } | |||
// Inserts dtype-conversion nodes around nodes that are marked QuantType_AwareTraining but whose
// primitive type is not in GetInt8OpList().
// For each such node: every input tensor whose dataType is kNumberTypeInt8 gets a kInt8ToFP32 node
// inserted before it (kBefore), every output tensor whose dataType is kNumberTypeInt8 gets a
// kFP32ToInt8 node inserted after it (kAfter), and the node's quantType is then set to
// QuantType_QUANT_NONE.
// Returns RET_ERROR if such a node has no inputs or if an insertion reports a status other than
// RET_OK; otherwise returns RET_OK.
| STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| // insert transNode before and after existNode | |||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||
// Skip nodes whose type is in the int8 op list, and nodes not marked AwareTraining.
| if (IsContain(GetInt8OpList(), GetCNodeTType(**iter)) || (*iter)->quantType != QuantType_AwareTraining) { | |||
| continue; | |||
| } | |||
| auto nodeName = (*iter)->name; | |||
// The check is "no inputs at all"; the message prints kMinInputNum, whose value is not visible here.
| if ((*iter)->inputIndex.empty()) { | |||
| MS_LOG(ERROR) << "Op " << nodeName.c_str() << " should have " << kMinInputNum << " input tensor at least"; | |||
| return RET_ERROR; | |||
| } | |||
// Not initialized here; it is passed by address to InsertDTypeTransNode as its errorCode argument.
// NOTE(review): presumably InsertDTypeTransNode always writes it before returning — confirm.
| STATUS status; | |||
| // insert pre | |||
| for (size_t i = 0; i < (*iter)->inputIndex.size(); i++) { | |||
| MS_ASSERT(graph->allTensors.size() > (*iter)->inputIndex.at(i)); | |||
| auto &preTensor = graph->allTensors.at((*iter)->inputIndex.at(i)); | |||
// Only int8 inputs need a conversion node in front of this node.
| if (preTensor->dataType != TypeId::kNumberTypeInt8) { | |||
| continue; | |||
| } | |||
// iter is replaced by the iterator InsertDTypeTransNode returns.
// NOTE(review): presumably it still refers to the existing node after the insertion — confirm.
| iter = InsertDTypeTransNode(graph, iter, kBefore, i, kInt8ToFP32, &status); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << nodeName.c_str() << " failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| // insert post | |||
// NOTE(review): unlike the input loop, there is no MS_ASSERT on the tensor index here; only .at() is used.
| for (size_t i = 0; i < (*iter)->outputIndex.size(); i++) { | |||
| auto &postTensor = graph->allTensors.at((*iter)->outputIndex.at(i)); | |||
// Only int8 outputs need a conversion node behind this node.
| if (postTensor->dataType != TypeId::kNumberTypeInt8) { | |||
| continue; | |||
| } | |||
| iter = InsertDTypeTransNode(graph, iter, kAfter, i, kFP32ToInt8, &status); | |||
| if (status != RET_OK) { | |||
// NOTE(review): the message below says "Float32ToUint8" while the requested node type is kFP32ToInt8.
| MS_LOG(ERROR) << "InsertFloat32ToUint8Node after " << nodeName.c_str() << " failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
// The node itself is no longer marked AwareTraining once its inputs and outputs have been handled.
| (*iter)->quantType = QuantType_QUANT_NONE; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, | |||
| size_t inoutIdx, DTypeTransNodeType nodeType, STATUS *errorCode) { | |||
| MS_ASSERT((*existNodeIter) != nullptr); | |||
| @@ -45,6 +45,7 @@ class DTypeTransPass : public GraphPass { | |||
| STATUS DoModelOutputDTypeTrans(schema::MetaGraphT *graph); | |||
| STATUS DoNodeInoutDTypeTrans(schema::MetaGraphT *graph); | |||
| NodeIter InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx, | |||
| DTypeTransNodeType nodeType, STATUS *errorCode); | |||
| @@ -30,6 +30,10 @@ STATUS InferQuantParamPass::Run(schema::MetaGraphT *graph) { | |||
| if (node->quantType == schema::QuantType_WeightQuant) { | |||
| continue; | |||
| } | |||
| DetermineNodeQuantType(*graph, node.get()); | |||
| if (node->quantType == schema::QuantType_AwareTraining) { | |||
| continue; | |||
| } | |||
| if (GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMax || | |||
| GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) { | |||
| MS_ASSERT(false); | |||
| @@ -38,14 +42,14 @@ STATUS InferQuantParamPass::Run(schema::MetaGraphT *graph) { | |||
| if (quantParamCalcer == nullptr) { | |||
| MS_LOG(WARNING) << "Can not find QuantParamCalcer for " << node->name.c_str() | |||
| << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; | |||
| node->quantType = static_cast<schema::QuantType>(schema::QuantType_QUANT_NONE); | |||
| node->quantType = schema::QuantType_QUANT_NONE; | |||
| } else { | |||
| auto status = quantParamCalcer->Calc(graph, *node); | |||
| if (status != RET_OK) { | |||
| MS_LOG(WARNING) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str(); | |||
| node->quantType = schema::QuantType_QUANT_NONE; | |||
| } else { | |||
| DetermineNodeQuantType(*graph, node.get()); | |||
| node->quantType = schema::QuantType_AwareTraining; | |||
| } | |||
| } | |||
| } | |||
| @@ -77,7 +81,7 @@ void InferQuantParamPass::DetermineNodeQuantType(const schema::MetaGraphT &graph | |||
| } | |||
| } | |||
| if (canQuant && IsContain(GetInt8OpList(), GetCNodeTType(*cnode))) { | |||
| if (canQuant) { | |||
| cnode->quantType = schema::QuantType_AwareTraining; | |||
| } else { | |||
| cnode->quantType = schema::QuantType_QUANT_NONE; | |||