From: @zhengjun10 Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -1,509 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.h" | |||
| #include <cfloat> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "src/common/log_adapter.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "tools/common/tensor_util.h" | |||
| #include "include/errorcode.h" | |||
| #include "schema/inner/model_generated.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #define kBatchNormFoldFusionPathLen6 6 | |||
| #define kBatchNormFoldFusionPathLen7 7 | |||
| STATUS BatchNormFoldFusionPass::Run(MetaGraphT *graph) { return FusionPass::Run(graph); } | |||
| STATUS BatchNormFoldFusionPass::DefinePattern() { | |||
| // with preNode | |||
| { | |||
| auto inputOp = std::make_shared<PatternOp>(); | |||
| inputOp->id = inputOpName; | |||
| inputOp->types = {schema::PrimitiveType_NONE}; | |||
| inputOp->isPlaceHold = true; | |||
| auto convOp1 = std::make_shared<PatternOp>(); | |||
| convOp1->id = convPatternOpName1; | |||
| convOp1->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D}; | |||
| convOp1->left = inputOp; | |||
| auto bnFoldOp = std::make_shared<PatternOp>(); | |||
| bnFoldOp->id = bnFoldOpName; | |||
| bnFoldOp->types = {schema::PrimitiveType_BatchNormFold}; | |||
| bnFoldOp->left = convOp1; | |||
| auto mulFoldOp = std::make_shared<PatternOp>(); | |||
| mulFoldOp->id = mulFoldOpName; | |||
| mulFoldOp->types = {schema::PrimitiveType_MulFold}; | |||
| mulFoldOp->left = bnFoldOp; | |||
| auto fakeQuantOp = std::make_shared<PatternOp>(); | |||
| fakeQuantOp->id = fakeQuantOpName; | |||
| fakeQuantOp->types = {schema::PrimitiveType_FakeQuantWithMinMax}; | |||
| fakeQuantOp->left = mulFoldOp; | |||
| auto convOp2 = std::make_shared<PatternOp>(); | |||
| convOp2->id = convPatternOpName2; | |||
| convOp2->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D}; | |||
| convOp2->left = fakeQuantOp; | |||
| convOp2->right = inputOp; | |||
| auto addFoldOp = std::make_shared<PatternOp>(); | |||
| addFoldOp->id = addFoldOpName; | |||
| addFoldOp->types = {schema::PrimitiveType_AddFold}; | |||
| addFoldOp->left = convOp2; | |||
| addFoldOp->right = bnFoldOp; | |||
| std::unique_ptr<FusionPattern> fusionPattern(new (std::nothrow) FusionPattern(withPrePatternName)); | |||
| if (fusionPattern == nullptr) { | |||
| MS_LOG(ERROR) << "new fusionPattern failed"; | |||
| return RET_ERROR; | |||
| } | |||
| fusionPattern->AddPatternOp(inputOp); | |||
| fusionPattern->AddPatternOp(convOp1); | |||
| fusionPattern->AddPatternOp(bnFoldOp); | |||
| fusionPattern->AddPatternOp(mulFoldOp); | |||
| fusionPattern->AddPatternOp(fakeQuantOp); | |||
| fusionPattern->AddPatternOp(convOp2); | |||
| fusionPattern->AddPatternOp(addFoldOp); | |||
| fusionPattern->Finish(); | |||
| this->patterns.emplace_back(fusionPattern.release()); | |||
| } | |||
| // no preNode | |||
| { | |||
| auto convOp1 = std::make_shared<PatternOp>(); | |||
| convOp1->id = convPatternOpName1; | |||
| convOp1->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D}; | |||
| auto bnFoldOp = std::make_shared<PatternOp>(); | |||
| bnFoldOp->id = bnFoldOpName; | |||
| bnFoldOp->types = {schema::PrimitiveType_BatchNormFold}; | |||
| bnFoldOp->left = convOp1; | |||
| auto mulFoldOp = std::make_shared<PatternOp>(); | |||
| mulFoldOp->id = mulFoldOpName; | |||
| mulFoldOp->types = {schema::PrimitiveType_MulFold}; | |||
| mulFoldOp->left = bnFoldOp; | |||
| auto fakeQuantOp = std::make_shared<PatternOp>(); | |||
| fakeQuantOp->id = fakeQuantOpName; | |||
| fakeQuantOp->types = {schema::PrimitiveType_FakeQuantWithMinMax}; | |||
| fakeQuantOp->left = mulFoldOp; | |||
| auto convOp2 = std::make_shared<PatternOp>(); | |||
| convOp2->id = convPatternOpName2; | |||
| convOp2->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D}; | |||
| convOp2->left = fakeQuantOp; | |||
| auto addFoldOp = std::make_shared<PatternOp>(); | |||
| addFoldOp->id = addFoldOpName; | |||
| addFoldOp->types = {schema::PrimitiveType_AddFold}; | |||
| addFoldOp->left = convOp2; | |||
| addFoldOp->right = bnFoldOp; | |||
| std::unique_ptr<FusionPattern> fusionPattern(new (std::nothrow) FusionPattern(noPrePatternName)); | |||
| if (fusionPattern == nullptr) { | |||
| MS_LOG(ERROR) << "new fusionPattern failed"; | |||
| return RET_ERROR; | |||
| } | |||
| fusionPattern->AddPatternOp(convOp1); | |||
| fusionPattern->AddPatternOp(bnFoldOp); | |||
| fusionPattern->AddPatternOp(mulFoldOp); | |||
| fusionPattern->AddPatternOp(fakeQuantOp); | |||
| fusionPattern->AddPatternOp(convOp2); | |||
| fusionPattern->AddPatternOp(addFoldOp); | |||
| fusionPattern->Finish(); | |||
| this->patterns.emplace_back(fusionPattern.release()); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS BatchNormFoldFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternName, | |||
| std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) { | |||
| MS_ASSERT(graph != nullptr); | |||
| if (patternName == withPrePatternName) { | |||
| if (matchedPath.size() != kBatchNormFoldFusionPathLen7) { | |||
| MS_LOG(ERROR) << "BatchNormFold-Fusion should have seven NodeIndex in matchedPair"; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| } else if (patternName == noPrePatternName) { | |||
| if (matchedPath.size() != kBatchNormFoldFusionPathLen6) { | |||
| MS_LOG(ERROR) << "BatchNormFold-Fusion should have six NodeIndex in matchedPair"; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| } | |||
| auto status = FindNodes(graph, matchedPath); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "FindNodes failed: " << status; | |||
| return status; | |||
| } | |||
| status = CheckPath(graph, matchedPath); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "CheckPath failed: " << status; | |||
| return status; | |||
| } | |||
| status = FindTensors(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "FindTensors failed: " << status; | |||
| return status; | |||
| } | |||
| status = GenNewWeightTensor(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "GenNewWeightTensor failed: " << status; | |||
| return status; | |||
| } | |||
| status = GenNewBiasTensor(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "GenNewBiasTensor failed: " << status; | |||
| return status; | |||
| } | |||
| status = IsolateNodes(graph, matchedPath); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "IsolateNodes failed: " << status; | |||
| return status; | |||
| } | |||
| UpdateConvWeights(); | |||
| status = DeleteConstTensors(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DeleteConstTensors failed: " << status; | |||
| return status; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS BatchNormFoldFusionPass::FindNodes(MetaGraphT *graph, | |||
| const std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) { | |||
| MS_ASSERT(graph != nullptr); | |||
| auto preConvPath = matchedPath.at(convPatternOpName1); | |||
| auto bnFoldPath = matchedPath.at(bnFoldOpName); | |||
| auto mulFoldPath = matchedPath.at(mulFoldOpName); | |||
| auto fakeQuantPath = matchedPath.at(fakeQuantOpName); | |||
| auto convPath = matchedPath.at(convPatternOpName2); | |||
| auto addFoldPath = matchedPath.at(addFoldOpName); | |||
| MS_ASSERT(preConvPath != nullptr); | |||
| MS_ASSERT(bnFoldPath != nullptr); | |||
| MS_ASSERT(mulFoldPath != nullptr); | |||
| MS_ASSERT(fakeQuantPath != nullptr); | |||
| MS_ASSERT(convPath != nullptr); | |||
| MS_ASSERT(addFoldPath != nullptr); | |||
| if (preConvPath->subGraphIdx != bnFoldPath->subGraphIdx || preConvPath->subGraphIdx != mulFoldPath->subGraphIdx || | |||
| preConvPath->subGraphIdx != fakeQuantPath->subGraphIdx || preConvPath->subGraphIdx != convPath->subGraphIdx || | |||
| preConvPath->subGraphIdx != addFoldPath->subGraphIdx) { | |||
| MS_LOG(ERROR) << "matched nodes should from same subGraph"; | |||
| return RET_ERROR; | |||
| } | |||
| MS_ASSERT(graph->nodes.size() > preConvPath->nodeIdx); | |||
| MS_ASSERT(graph->nodes.size() > bnFoldPath->nodeIdx); | |||
| MS_ASSERT(graph->nodes.size() > mulFoldPath->nodeIdx); | |||
| MS_ASSERT(graph->nodes.size() > fakeQuantPath->nodeIdx); | |||
| MS_ASSERT(graph->nodes.size() > convPath->nodeIdx); | |||
| MS_ASSERT(graph->nodes.size() > addFoldPath->nodeIdx); | |||
| preConv = graph->nodes.at(preConvPath->nodeIdx).get(); | |||
| bnFold = graph->nodes.at(bnFoldPath->nodeIdx).get(); | |||
| mulFold = graph->nodes.at(mulFoldPath->nodeIdx).get(); | |||
| fakeNode = graph->nodes.at(fakeQuantPath->nodeIdx).get(); | |||
| convNode = graph->nodes.at(convPath->nodeIdx).get(); | |||
| addFold = graph->nodes.at(addFoldPath->nodeIdx).get(); | |||
| MS_ASSERT(preConv != nullptr); | |||
| MS_ASSERT(bnFold != nullptr); | |||
| MS_ASSERT(mulFold != nullptr); | |||
| MS_ASSERT(fakeNode != nullptr); | |||
| MS_ASSERT(convNode != nullptr); | |||
| MS_ASSERT(addFold != nullptr); | |||
| return RET_OK; | |||
| } | |||
| STATUS BatchNormFoldFusionPass::FindTensors() { | |||
| MS_ASSERT(graph != nullptr); | |||
| MS_ASSERT(bnFold != nullptr); | |||
| MS_ASSERT(addFold != nullptr); | |||
| if (bnFold->inputIndex.size() != 4) { | |||
| MS_LOG(ERROR) << "BatchNormFold node should have 4 inputTensor, got " << bnFold->inputIndex.size() | |||
| << " input tensors"; | |||
| return RET_ERROR; | |||
| } | |||
| if (addFold->inputIndex.size() != 5) { | |||
| MS_LOG(ERROR) << "AddFold node should have 5 inputTensor, got " << addFold->inputIndex.size() << " input tensors"; | |||
| return RET_ERROR; | |||
| } | |||
| MS_ASSERT(graph->allTensors.size() > bnFold->inputIndex.at(1)); | |||
| muTensor = graph->allTensors.at(bnFold->inputIndex.at(1)).get(); | |||
| MS_ASSERT(muTensor != nullptr); | |||
| MS_ASSERT(graph->allTensors.size() > bnFold->inputIndex.at(2)); | |||
| sigmaTensor = graph->allTensors.at(bnFold->inputIndex.at(2)).get(); | |||
| MS_ASSERT(sigmaTensor != nullptr); | |||
| MS_ASSERT(graph->allTensors.size() > addFold->inputIndex.at(1)); | |||
| betaTensor = graph->allTensors.at(addFold->inputIndex.at(1)).get(); | |||
| MS_ASSERT(betaTensor != nullptr); | |||
| MS_ASSERT(graph->allTensors.size() > addFold->inputIndex.at(2)); | |||
| gammaTensor = graph->allTensors.at(addFold->inputIndex.at(2)).get(); | |||
| MS_ASSERT(gammaTensor != nullptr); | |||
| if (betaTensor->dims.size() != 1) { | |||
| MS_LOG(ERROR) << "ConstTensor should have only one dim, got " << betaTensor->dims.size(); | |||
| return RET_ERROR; | |||
| } | |||
| if (betaTensor->dims != gammaTensor->dims || betaTensor->dims != sigmaTensor->dims || | |||
| betaTensor->dims != muTensor->dims) { | |||
| MS_LOG(ERROR) << "All ConstTensor should have same dims"; | |||
| return RET_ERROR; | |||
| } | |||
| channelOut = betaTensor->dims.front(); | |||
| MS_ASSERT(mulFold != nullptr); | |||
| if (mulFold->inputIndex.size() != 3) { | |||
| MS_LOG(ERROR) << "MulFold node should have 3 outputTensor, got " << addFold->inputIndex.size() << " output tensors"; | |||
| return RET_ERROR; | |||
| } | |||
| MS_ASSERT(graph->allTensors.size() > mulFold->inputIndex.front()); | |||
| oldWeightTensor = graph->allTensors.at(mulFold->inputIndex.front()).get(); | |||
| MS_ASSERT(oldWeightTensor != nullptr); | |||
| return RET_OK; | |||
| } | |||
| STATUS BatchNormFoldFusionPass::CheckPath(MetaGraphT *graph, | |||
| const std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) { | |||
| MS_ASSERT(preConv != nullptr); | |||
| MS_ASSERT(convNode != nullptr); | |||
| MS_ASSERT(mulFold != nullptr); | |||
| MS_ASSERT(preConv->inputIndex.size() == 2); | |||
| MS_ASSERT(convNode->inputIndex.size() == 2); | |||
| MS_ASSERT(mulFold->inputIndex.size() == 3); | |||
| MS_ASSERT(preConv->inputIndex.front() == convNode->inputIndex.front()); | |||
| MS_ASSERT(preConv->inputIndex.at(1) == mulFold->inputIndex.front()); | |||
| return RET_OK; | |||
| } | |||
| STATUS BatchNormFoldFusionPass::GenNewWeightTensor() { | |||
| MS_ASSERT(oldWeightTensor != nullptr); | |||
| MS_ASSERT(oldWeightTensor->dataType == DataType_DT_FLOAT); | |||
| MS_ASSERT(oldWeightTensor->refCount == schema::NodeType::NodeType_ValueNode); | |||
| auto weightShape = oldWeightTensor->dims; | |||
| if (weightShape.size() != 4) { | |||
| MS_LOG(ERROR) << "shape of weight should be 4 dims, got " << weightShape.size() << " dims"; | |||
| return RET_ERROR; | |||
| } | |||
| if (weightShape.front() != channelOut) { | |||
| MS_LOG(ERROR) << "weight should be in KCHW format, and outputChannel should be " << channelOut; | |||
| return RET_ERROR; | |||
| } | |||
| auto weightShapeSize = GetShapeSize(*oldWeightTensor); | |||
| newWeightTensor = std::unique_ptr<TensorT>(new (std::nothrow) TensorT); | |||
| if (newWeightTensor == nullptr) { | |||
| MS_LOG(ERROR) << "new weightTensor failed"; | |||
| return RET_ERROR; | |||
| } | |||
| newWeightTensor->dataType = oldWeightTensor->dataType; | |||
| newWeightTensor->format = oldWeightTensor->format; | |||
| newWeightTensor->refCount = schema::NodeType::NodeType_ValueNode; | |||
| newWeightTensor->dims = weightShape; | |||
| newWeightTensor->data.resize(weightShapeSize * sizeof(float)); | |||
| void *oldWeightData = oldWeightTensor->data.data(); | |||
| auto castedOldWeightData = static_cast<float *>(oldWeightData); | |||
| void *newWeightData = newWeightTensor->data.data(); | |||
| auto castedNewWeightData = static_cast<float *>(newWeightData); | |||
| MS_ASSERT(gammaTensor->dataType == DataType_DT_FLOAT); | |||
| void *gammaData = gammaTensor->data.data(); | |||
| auto *castedGammaData = static_cast<float *>(gammaData); | |||
| MS_ASSERT(muTensor->dataType == DataType_DT_FLOAT); | |||
| void *miData = muTensor->data.data(); | |||
| auto *castedMiData = static_cast<float *>(miData); | |||
| if (channelOut == 0) { | |||
| MS_LOG(ERROR) << "divisor 'channelOut' cannot be 0"; | |||
| return RET_ERROR; | |||
| } | |||
| size_t stride = weightShapeSize / channelOut; | |||
| for (int i = 0; i < channelOut; i++) { | |||
| for (size_t j = 0; j < stride; j++) { | |||
| if (fabs(castedMiData[i]) <= 0.0f) { | |||
| MS_LOG(ERROR) << "divisor 'castedMiData' cannot be 0"; | |||
| return RET_ERROR; | |||
| } | |||
| castedNewWeightData[i * stride + j] = castedOldWeightData[i * stride + j] * castedGammaData[i] / castedMiData[i]; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS BatchNormFoldFusionPass::GenNewBiasTensor() { // bias has no quant | |||
| std::vector<int32_t> biasShape = {channelOut}; | |||
| newBiasTensor = std::unique_ptr<TensorT>(new (std::nothrow) TensorT); | |||
| if (newBiasTensor == nullptr) { | |||
| MS_LOG(ERROR) << "new BiasTensor failed"; | |||
| return RET_ERROR; | |||
| } | |||
| newBiasTensor->dataType = 0; | |||
| newBiasTensor->format = schema::Format::Format_NUM_OF_FORMAT; | |||
| newBiasTensor->refCount = schema::NodeType::NodeType_ValueNode; | |||
| newBiasTensor->dims = biasShape; | |||
| newBiasTensor->data.resize(channelOut * sizeof(float)); | |||
| void *newBiasData = newBiasTensor->data.data(); | |||
| auto castedNewBiasData = static_cast<float *>(newBiasData); | |||
| MS_ASSERT(betaTensor->dataType == DataType_DT_FLOAT); | |||
| void *betaData = betaTensor->data.data(); | |||
| auto *castedBetaData = static_cast<float *>(betaData); | |||
| MS_ASSERT(gammaTensor->dataType == DataType_DT_FLOAT); | |||
| void *gammaData = gammaTensor->data.data(); | |||
| auto *castedGammaData = static_cast<float *>(gammaData); | |||
| MS_ASSERT(muTensor->dataType == DataType_DT_FLOAT); | |||
| void *miData = muTensor->data.data(); | |||
| auto *castedMiData = static_cast<float *>(miData); | |||
| MS_ASSERT(sigmaTensor->dataType == DataType_DT_FLOAT); | |||
| void *sigmaData = sigmaTensor->data.data(); | |||
| auto *castedSigmaData = static_cast<float *>(sigmaData); | |||
| for (int i = 0; i < channelOut; i++) { | |||
| if (fabs(castedSigmaData[i]) <= 0.0f) { | |||
| MS_LOG(ERROR) << "divisor 'castedSigmaData' cannot be 0"; | |||
| return RET_ERROR; | |||
| } | |||
| castedNewBiasData[i] = castedBetaData[i] - castedGammaData[i] * castedMiData[i] / castedSigmaData[i]; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS BatchNormFoldFusionPass::IsolateNodes( | |||
| MetaGraphT *graph, const std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) { | |||
| MS_ASSERT(graph != nullptr); | |||
| auto preConvPath = matchedPath.at(convPatternOpName1); | |||
| auto bnFoldPath = matchedPath.at(bnFoldOpName); | |||
| auto mulFoldPath = matchedPath.at(mulFoldOpName); | |||
| auto fakeQuantPath = matchedPath.at(fakeQuantOpName); | |||
| auto convPath = matchedPath.at(convPatternOpName2); | |||
| auto addFoldPath = matchedPath.at(addFoldOpName); | |||
| MS_ASSERT(preConvPath != nullptr); | |||
| MS_ASSERT(bnFoldPath != nullptr); | |||
| MS_ASSERT(mulFoldPath != nullptr); | |||
| MS_ASSERT(fakeQuantPath != nullptr); | |||
| MS_ASSERT(convPath != nullptr); | |||
| MS_ASSERT(addFoldPath != nullptr); | |||
| auto status = IsolateOneWayNode(graph, preConvPath->nodeIdx); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "IsolateOneWayNode " << preConv->name.c_str() << " failed, error: " << status; | |||
| return status; | |||
| } | |||
| std::vector<uint32_t> toDeleteTensorIdxes; | |||
| toDeleteTensorIdxes.emplace_back(bnFold->inputIndex.at(3)); | |||
| toDeleteTensorIdxes.insert(toDeleteTensorIdxes.end(), bnFold->outputIndex.begin(), bnFold->outputIndex.end()); | |||
| status = RemoveTensor(graph, toDeleteTensorIdxes, true); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Remove Tensors of BnFold " << bnFold->name.c_str() << " failed, error: " << status; | |||
| return RET_ERROR; | |||
| } | |||
| status = IsolateOneWayNode(graph, bnFoldPath->nodeIdx); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "IsolateOneWayNode " << bnFold->name.c_str() << " failed, error: " << status; | |||
| return status; | |||
| } | |||
| status = IsolateOneWayNode(graph, mulFoldPath->nodeIdx); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "IsolateOneWayNode " << mulFold->name.c_str() << " failed, error: " << status; | |||
| return status; | |||
| } | |||
| status = IsolateOneWayNode(graph, addFoldPath->nodeIdx); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "IsolateOneWayNode " << addFold->name.c_str() << " failed, error: " << status; | |||
| return status; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| void BatchNormFoldFusionPass::UpdateConvWeights() { | |||
| MS_ASSERT(graph != nullptr); | |||
| MS_ASSERT(convNode != nullptr); | |||
| MS_ASSERT(newWeightTensor != nullptr); | |||
| MS_ASSERT(newBiasTensor != nullptr); | |||
| MS_ASSERT(graph->allTensors.size() > fakeNode->inputIndex.at(0)); | |||
| graph->allTensors.at(fakeNode->inputIndex.at(0)).reset(); | |||
| graph->allTensors.at(fakeNode->inputIndex.at(0)) = std::move(this->newWeightTensor); | |||
| graph->allTensors.emplace_back(std::move(this->newBiasTensor)); | |||
| convNode->inputIndex.emplace_back(graph->allTensors.size() - 1); | |||
| if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) { | |||
| convNode->primitive->value.AsConv2D()->hasBias = true; | |||
| } else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { | |||
| convNode->primitive->value.AsDepthwiseConv2D()->hasBias = true; | |||
| } else { | |||
| MS_ASSERT(false); | |||
| } | |||
| this->oldWeightTensor = nullptr; | |||
| this->newWeightTensor = nullptr; | |||
| this->newBiasTensor = nullptr; | |||
| } | |||
| STATUS BatchNormFoldFusionPass::DeleteConstTensors() { | |||
| MS_ASSERT(graph != nullptr); | |||
| bool muFind = false; | |||
| bool sigmaFind = false; | |||
| bool betaFind = false; | |||
| bool gammaFind = false; | |||
| std::vector<uint32_t> toDeleteTensorIdxes; | |||
| for (size_t i = 0; i < graph->allTensors.size(); i++) { | |||
| auto &tensor = graph->allTensors.at(i); | |||
| if (tensor.get() == muTensor) { | |||
| toDeleteTensorIdxes.emplace_back(i); | |||
| muFind = true; | |||
| this->muTensor = nullptr; | |||
| } | |||
| if (tensor.get() == sigmaTensor) { | |||
| toDeleteTensorIdxes.emplace_back(i); | |||
| sigmaFind = true; | |||
| this->sigmaTensor = nullptr; | |||
| } | |||
| if (tensor.get() == gammaTensor) { | |||
| toDeleteTensorIdxes.emplace_back(i); | |||
| gammaFind = true; | |||
| this->gammaTensor = nullptr; | |||
| } | |||
| if (tensor.get() == betaTensor) { | |||
| toDeleteTensorIdxes.emplace_back(i); | |||
| betaFind = true; | |||
| this->betaTensor = nullptr; | |||
| } | |||
| } | |||
| if (!muFind || !sigmaFind || !betaFind || !gammaFind) { | |||
| MS_LOG(ERROR) << "Can not find muTensor or sigmaTensor or betaTensor or gammaTensor in graph"; | |||
| return RET_ERROR; | |||
| } | |||
| auto status = RemoveTensor(graph, toDeleteTensorIdxes); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Remove ConstTensors failed" << bnFold->name.c_str(); | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| BatchNormFoldFusionPass::~BatchNormFoldFusionPass() { | |||
| if (newWeightTensor == nullptr) { | |||
| newWeightTensor.reset(); | |||
| newWeightTensor = nullptr; | |||
| } | |||
| if (newBiasTensor == nullptr) { | |||
| newBiasTensor.reset(); | |||
| newBiasTensor = nullptr; | |||
| } | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -1,86 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_PREDICT_BATCHNORM_FOLD_FUSION_PASS_H | |||
| #define MINDSPORE_PREDICT_BATCHNORM_FOLD_FUSION_PASS_H | |||
| #include <unordered_map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "tools/converter/legacy_optimizer/fusion/fusion_pass.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| // input = input | |||
| // weight = SimQuantPerChannel(weight * gamma / sigma) | |||
| // bias = beta - gamma * mi / sigma | |||
| // MulFold: gamma sigma | |||
| // BatchNormFold: mi sigma | |||
| // AddFold: gamma beta mi sigma | |||
| class BatchNormFoldFusionPass : public FusionPass { | |||
| public: | |||
| BatchNormFoldFusionPass() = default; | |||
| ~BatchNormFoldFusionPass() override; | |||
| STATUS DefinePattern() override; | |||
| STATUS DoFusion(MetaGraphT *graph, const std::string &patternName, | |||
| std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override; | |||
| STATUS Run(MetaGraphT *graph) override; | |||
| protected: | |||
| STATUS FindNodes(MetaGraphT *graph, const std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath); | |||
| STATUS CheckPath(MetaGraphT *graph, const std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath); | |||
| STATUS FindTensors(); | |||
| STATUS GenNewWeightTensor(); | |||
| STATUS GenNewBiasTensor(); | |||
| STATUS IsolateNodes(MetaGraphT *graph, const std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath); | |||
| void UpdateConvWeights(); | |||
| STATUS DeleteConstTensors(); | |||
| protected: | |||
| MetaGraphT *graph = nullptr; | |||
| CNodeT *preConv = nullptr; | |||
| CNodeT *bnFold = nullptr; | |||
| CNodeT *mulFold = nullptr; | |||
| CNodeT *fakeNode = nullptr; | |||
| CNodeT *convNode = nullptr; | |||
| CNodeT *addFold = nullptr; | |||
| TensorT *muTensor = nullptr; | |||
| TensorT *sigmaTensor = nullptr; | |||
| TensorT *gammaTensor = nullptr; | |||
| TensorT *betaTensor = nullptr; | |||
| TensorT *oldWeightTensor = nullptr; | |||
| int32_t channelOut = 0; | |||
| std::unique_ptr<TensorT> newWeightTensor = nullptr; | |||
| std::unique_ptr<TensorT> newBiasTensor = nullptr; | |||
| std::string inputOpName = "Input"; | |||
| std::string convPatternOpName1 = "Convolution1"; | |||
| std::string bnFoldOpName = "BatchNormFold"; | |||
| std::string mulFoldOpName = "MulFold"; | |||
| std::string fakeQuantOpName = "FakeQuant"; | |||
| std::string convPatternOpName2 = "Convolution2"; | |||
| std::string addFoldOpName = "AddFold"; | |||
| std::string withPrePatternName = "BNFoldFusionWithPre"; | |||
| std::string noPrePatternName = "BNFoldFusionNoPre"; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_PREDICT_BATCHNORM_FOLD_FUSION_PASS_H | |||
| @@ -0,0 +1,63 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_logical_parser.h" | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFLogicalParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(INFO) << "TF LogicalParser"; | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "primitive is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (tf_op.op() == "LogicalAnd") { | |||
| auto attr = std::make_unique<schema::LogicalAndT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_LogicalAnd; | |||
| primitive->value.value = attr.release(); | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| } | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = 1; | |||
| for (int i = 0; i < tf_op.input_size(); i++) { | |||
| inputs->emplace_back(tf_op.input(i)); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| TFNodeRegistrar g_tfLogicalAndParser("LogicalAnd", new TFLogicalParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LOGICAL_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LOGICAL_PARSER_H_ | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFLogicalParser : public TFNodeParser { | |||
| public: | |||
| TFLogicalParser() = default; | |||
| ~TFLogicalParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LOGICAL_PARSER_H_ | |||
| @@ -17,37 +17,57 @@ | |||
| #include "tools/converter/parser/tf/tf_model_parser.h" | |||
| #include <functional> | |||
| #include <regex> | |||
| #include <set> | |||
| #include "src/common/utils.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/param_value_lite.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "tools/common/protobuf_utils.h" | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| AnfNodePtr TFModelParser::GetAnfNode(const std::string &name) { | |||
| namespace { | |||
| // subgraph node input may be a:output:0/a:z:0 | |||
| std::string GetFlattenNodeName(std::string input_name) { | |||
| std::regex re("\\:+"); | |||
| std::vector<std::string> input_splits(std::sregex_token_iterator(input_name.begin(), input_name.end(), re, -1), | |||
| std::sregex_token_iterator()); | |||
| if (input_splits.size() == 3) { | |||
| if (input_splits[2] == "0") { | |||
| input_name = input_splits[0]; | |||
| } else { | |||
| input_name = input_splits[0] + input_splits[2]; // multi output node | |||
| } | |||
| } | |||
| return input_name; | |||
| } | |||
| AnfNodePtr GetAnfNode(const std::string &name, const std::unordered_map<std::string, AnfNodePtr> &anf_node_map) { | |||
| AnfNodePtr ret = nullptr; | |||
| if (anf_node_map.find(name) != anf_node_map.end()) { | |||
| ret = anf_node_map[name]; | |||
| ret = anf_node_map.at(name); | |||
| } else if (anf_node_map.find(name + ":0") != anf_node_map.end()) { | |||
| ret = anf_node_map[name + ":0"]; | |||
| ret = anf_node_map.at(name + ":0"); | |||
| } | |||
| return ret; | |||
| } | |||
| std::string TFModelParser::GetOriginInputName(const tensorflow::NodeDef &node) { | |||
| std::string GetOriginInputName(const tensorflow::NodeDef &node, | |||
| const std::map<std::string, const tensorflow::NodeDef *> &tf_graph_nodes) { | |||
| if (node.op() != "Identity" && node.op() != "StopGradient") { | |||
| return node.name(); | |||
| } | |||
| auto tmp_node = &node; | |||
| while (tmp_node->op() == "Identity" || tmp_node->op() == "StopGradient") { | |||
| tmp_node = tf_node_map[tmp_node->input(0)]; | |||
| if (tf_graph_nodes.find(tmp_node->input(0)) == tf_graph_nodes.end()) { | |||
| return tmp_node->input(0); | |||
| } | |||
| tmp_node = tf_graph_nodes.at(tmp_node->input(0)); | |||
| } | |||
| return tmp_node->name(); | |||
| } | |||
| } // namespace | |||
| STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value, const TypeId &type, | |||
| const ParameterPtr ¶meter, std::vector<int64_t> *shape_vector) { | |||
| @@ -126,11 +146,11 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value | |||
| param_value->set_tensor_type(type); | |||
| param_value->set_format(schema::Format::Format_NHWC); | |||
| parameter->set_default_param(param_value); | |||
| parameter->set_name("const_" + std::to_string(anf_node_map.size()) + "_parameter"); | |||
| return RET_OK; | |||
| } | |||
| STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr ¶meter) { | |||
| STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr ¶meter, | |||
| std::unordered_map<std::string, AnfNodePtr> *anf_node_map) { | |||
| MS_ASSERT(node != nullptr); | |||
| MS_ASSERT(parameter != nullptr); | |||
| @@ -157,8 +177,7 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa | |||
| return status; | |||
| } | |||
| } else { | |||
| parameter->set_name("placeholder_" + std::to_string(anf_node_map.size())); | |||
| graph_input_names.emplace_back(parameter->name()); | |||
| graph_input_names_.emplace_back(node.name()); // only root graph need set graph input names | |||
| } | |||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||
| @@ -166,14 +185,19 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa | |||
| MS_LOG(ERROR) << "abstract_tensor is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| parameter->set_name(node.name()); | |||
| parameter->set_abstract(abstract_tensor); | |||
| anf_node_map[node.name()] = parameter; | |||
| (*anf_node_map)[node.name()] = parameter; | |||
| (*anf_node_map)[node.name() + ":0"] = parameter; | |||
| return RET_OK; | |||
| } | |||
| STATUS TFModelParser::ConvertGraphInputsAndConsts() { | |||
| for (auto &pair : tf_node_map) { | |||
| STATUS TFModelParser::ConvertGraphInputsAndConsts( | |||
| const std::map<std::string, const tensorflow::NodeDef *> &tf_graph_nodes, const FuncGraphPtr &anf_graph, | |||
| std::unordered_map<std::string, AnfNodePtr> *anf_node_map) { | |||
| for (auto &pair : tf_graph_nodes) { | |||
| bool have_data_depend = false; | |||
| for (int i = 0; i < pair.second->input_size(); ++i) { | |||
| auto name = pair.second->input(i); | |||
| @@ -183,8 +207,8 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts() { | |||
| } | |||
| } | |||
| if (!have_data_depend) { | |||
| auto parameter = funcGraphPtr->add_parameter(); | |||
| if (ConvertParameter(*pair.second, parameter) != RET_OK) { | |||
| auto parameter = anf_graph->add_parameter(); | |||
| if (ConvertParameter(*pair.second, parameter, anf_node_map) != RET_OK) { | |||
| MS_LOG(ERROR) << "convert Parameter Node failed"; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -192,7 +216,7 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts() { | |||
| } | |||
| return RET_OK; | |||
| } | |||
| FuncGraphPtr paserTfFuction() { return nullptr; } | |||
| FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::string &weightFile, | |||
| const QuantType &quantType) { | |||
| auto status = ValidateFileStr(modelFile, ".pb"); | |||
| @@ -201,51 +225,189 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| tf_graph_def = std::make_unique<tensorflow::GraphDef>(); | |||
| if (tf_graph_def == nullptr) { | |||
| MS_LOG(ERROR) << "tf_graph_def is nullptr"; | |||
| tf_root_graph_ = std::make_unique<tensorflow::GraphDef>(); | |||
| if (tf_root_graph_ == nullptr) { | |||
| MS_LOG(ERROR) << "tf_root_graph_ is nullptr"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||
| return nullptr; | |||
| } | |||
| status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_graph_def.get()); | |||
| status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_root_graph_.get()); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Open modelFile for TF converter failed!"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||
| return nullptr; | |||
| } | |||
| funcGraphPtr = std::make_shared<FuncGraph>(); | |||
| if (funcGraphPtr == nullptr) { | |||
| anf_root_graph_ = std::make_shared<FuncGraph>(); | |||
| if (anf_root_graph_ == nullptr) { | |||
| MS_LOG(ERROR) << "funGraphPtr is nullptr"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||
| return nullptr; | |||
| } | |||
| for (int i = 0; i < tf_graph_def->node_size(); i++) { | |||
| auto &node_def = tf_graph_def->node(i); | |||
| tf_node_map[node_def.name()] = &node_def; | |||
| for (int i = 0; i < tf_root_graph_->node_size(); i++) { | |||
| auto &node_def = tf_root_graph_->node(i); | |||
| tf_root_graph_nodes_[node_def.name()] = &node_def; | |||
| } | |||
| status = ConvertGraphInputsAndConsts(); | |||
| status = ConvertGraphInputsAndConsts(tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_); | |||
| if (status != RET_OK) { | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| status = ConvertOps(); | |||
| if (status != RET_OK) { | |||
| bool success_flag = true; | |||
| for (int i = 0; i < tf_root_graph_->node_size(); i++) { | |||
| auto &node_def = tf_root_graph_->node(i); | |||
| status = ConvertOps(node_def, tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_); | |||
| if (status != RET_OK) { | |||
| success_flag = false; | |||
| } | |||
| } | |||
| if (!success_flag) { | |||
| MS_LOG(ERROR) << "Convert ops failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| status = ConvertGraphOutputs(); | |||
| status = ConvertRootGraphOutputs(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Convert graph outputs failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| return funcGraphPtr; | |||
| status = ConvertSubgraph(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Convert subgraph failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| return anf_root_graph_; | |||
| } | |||
| STATUS TFModelParser::ConvertSubgraph() { | |||
| auto graph_def_liarary = tf_root_graph_->library(); | |||
| auto subgraph_size = graph_def_liarary.function_size(); | |||
| std::map<CNodePtr, FuncGraphPtr> while_cond_map; | |||
| std::map<CNodePtr, FuncGraphPtr> while_body_map; | |||
| std::vector<ParameterPtr> sub_graph_inputs; | |||
| for (int i = 0; i < subgraph_size; i++) { | |||
| auto &tf_sub_fuction = graph_def_liarary.function(i); | |||
| auto &tf_sub_signature = tf_sub_fuction.signature(); | |||
| auto input_arg_size = tf_sub_signature.input_arg_size(); | |||
| auto &sub_graph_name = tf_sub_signature.name(); | |||
| if (!function_while_map_.count(sub_graph_name)) { | |||
| MS_LOG(ERROR) << "function map not contains sub graph name." << sub_graph_name; | |||
| return RET_ERROR; | |||
| } | |||
| auto while_cnode = function_while_map_[sub_graph_name]->cast<CNodePtr>(); | |||
| if (while_cnode == nullptr || static_cast<int>(while_cnode->inputs().size()) != input_arg_size + 1) { | |||
| MS_LOG(ERROR) << "while cnode not equal input arg size"; | |||
| return RET_ERROR; | |||
| } | |||
| FuncGraphPtr sub_func_graph = std::make_shared<FuncGraph>(); | |||
| std::unordered_map<std::string, AnfNodePtr> anf_sub_node_map; | |||
| // convert sub graph inputs | |||
| for (int j = 0; j < input_arg_size; j++) { | |||
| auto &input_arg = tf_sub_signature.input_arg(j); | |||
| auto paramter = sub_func_graph->add_parameter(); | |||
| paramter->set_name(input_arg.name()); | |||
| anf_sub_node_map[input_arg.name()] = paramter; | |||
| sub_graph_inputs.emplace_back(paramter); | |||
| } | |||
| std::map<std::string, const tensorflow::NodeDef *> tf_sub_node_map; | |||
| for (int j = 0; j < tf_sub_fuction.node_def_size(); j++) { | |||
| auto &node_def = tf_sub_fuction.node_def(j); | |||
| tf_sub_node_map[node_def.name()] = &node_def; | |||
| } | |||
| STATUS status = RET_OK; | |||
| status = ConvertGraphInputsAndConsts(tf_sub_node_map, sub_func_graph, &anf_sub_node_map); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Convert subgraph consts failed"; | |||
| return status; | |||
| } | |||
| // convert sub graph ops | |||
| for (int j = 0; j < tf_sub_fuction.node_def_size(); j++) { | |||
| auto &node_def = tf_sub_fuction.node_def(j); | |||
| status = ConvertOps(node_def, tf_sub_node_map, sub_func_graph, &anf_sub_node_map); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Convert subgraph ops failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| // convert subgraph outputs | |||
| std::vector<AnfNodePtr> sub_output_nodes; | |||
| auto &subgraph_ret = tf_sub_fuction.ret(); | |||
| for (auto &t : subgraph_ret) { | |||
| MS_LOG(INFO) << "subret " << t.first << " " << t.second; | |||
| auto tf_output_name = GetFlattenNodeName(t.second); | |||
| AnfNodePtr anf_node = nullptr; | |||
| if (tf_sub_node_map.find(tf_output_name) == tf_sub_node_map.end()) { | |||
| anf_node = GetAnfNode(tf_output_name, anf_sub_node_map); | |||
| } else { | |||
| auto tf_real_name = GetOriginInputName(*tf_sub_node_map[tf_output_name], tf_sub_node_map); | |||
| anf_node = GetAnfNode(tf_real_name, anf_sub_node_map); | |||
| } | |||
| if (anf_node == nullptr) { | |||
| MS_LOG(ERROR) << "can't find anf node,tf node flatten name" << tf_output_name; | |||
| return RET_ERROR; | |||
| } | |||
| sub_output_nodes.push_back(anf_node); | |||
| } | |||
| status = MakeAnfGraphOutputs(&sub_output_nodes, sub_func_graph); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "cmake anf graph outputs node error"; | |||
| return status; | |||
| } | |||
| // add while cond body function to while node input | |||
| if (sub_graph_name.find("cond") != std::string::npos) { | |||
| while_cond_map[while_cnode] = sub_func_graph; | |||
| } else { | |||
| while_body_map[while_cnode] = sub_func_graph; | |||
| } | |||
| // hardcode subgraph inputs name | |||
| for (size_t j = 0; j < sub_graph_inputs.size(); j++) { | |||
| sub_graph_inputs[j]->set_name("graph" + std::to_string(i) + "_input_" + std::to_string(j) + "parameter"); | |||
| } | |||
| MS_LOG(INFO) << "parse subgraph end:" << sub_graph_name; | |||
| } | |||
| auto status = WhileNodePostProcess(while_cond_map, while_body_map); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "while node post process failed"; | |||
| return status; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS TFModelParser::WhileNodePostProcess(const std::map<CNodePtr, FuncGraphPtr> &while_cond_map, | |||
| const std::map<CNodePtr, FuncGraphPtr> &while_body_map) { | |||
| if (while_cond_map.size() != while_body_map.size()) { | |||
| MS_LOG(ERROR) << "while cond body size error"; | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<FuncGraphPtr> roots = {anf_root_graph_}; | |||
| auto root_func_manager = std::make_shared<FuncGraphManager>(roots); | |||
| anf_root_graph_->set_manager(root_func_manager); | |||
| for (auto &kv : while_cond_map) { | |||
| auto while_node = kv.first; | |||
| auto &cond_sub_graph = kv.second; | |||
| auto &body_sub_graph = while_body_map.at(while_node); | |||
| cond_sub_graph->set_manager(root_func_manager); | |||
| body_sub_graph->set_manager(root_func_manager); | |||
| auto cond_value_node = NewValueNode(cond_sub_graph); | |||
| auto body_value_node = NewValueNode(body_sub_graph); | |||
| auto new_while_inputs = while_node->cast<CNodePtr>()->inputs(); | |||
| new_while_inputs[0] = cond_value_node; | |||
| new_while_inputs.insert(new_while_inputs.begin() + 1, body_value_node); | |||
| auto new_while_node = anf_root_graph_->NewCNode(new_while_inputs); | |||
| new_while_node->set_abstract(while_node->abstract()); | |||
| root_func_manager->Replace(while_node, new_while_node); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| schema::MetaGraphT *TFModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile, | |||
| const QuantType &quantType) { | |||
| MS_LOG(ERROR) << "TF Model Parser not return MetaGraph, use TFModelParser::Parse instead"; | |||
| @@ -253,15 +415,21 @@ schema::MetaGraphT *TFModelParser::ParseToFb(const std::string &modelFile, const | |||
| } | |||
| STATUS TFModelParser::ConvertInputNodes(const tensorflow::NodeDef &node_def, | |||
| const std::vector<std::string> &input_names, std::vector<AnfNodePtr> *inputs) { | |||
| const std::vector<std::string> &input_names, | |||
| const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map, | |||
| const std::unordered_map<std::string, AnfNodePtr> &anf_node_map, | |||
| std::vector<AnfNodePtr> *inputs) { | |||
| MS_ASSERT(node_def != nullptr); | |||
| // parse inputs | |||
| for (size_t j = 0; j < input_names.size(); j++) { | |||
| std::string input_name = input_names[j]; // input may be produced by multi-outputs node | |||
| if (tf_node_map.find(input_name) != tf_node_map.end()) { | |||
| auto input_node = tf_node_map[input_name]; | |||
| input_name = GetOriginInputName(*input_node); | |||
| // subgraph input name x:output:index,need flatten | |||
| auto flatten_input_name = GetFlattenNodeName(input_name); | |||
| if (tf_node_map.find(flatten_input_name) != tf_node_map.end()) { | |||
| auto input_node = tf_node_map.at(flatten_input_name); | |||
| flatten_input_name = GetOriginInputName(*input_node, tf_node_map); | |||
| } | |||
| auto input = GetAnfNode(input_name); | |||
| auto input = GetAnfNode(flatten_input_name, anf_node_map); | |||
| if (input == nullptr) { | |||
| MS_LOG(ERROR) << node_def.name() << " input " << j << ": " << input_name << " can't find parsed in_nodes"; | |||
| return RET_ERROR; | |||
| @@ -271,11 +439,16 @@ STATUS TFModelParser::ConvertInputNodes(const tensorflow::NodeDef &node_def, | |||
| return RET_OK; | |||
| } | |||
| STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const CNodePtr &anf_node, int output_size) { | |||
| STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const CNodePtr &anf_node, | |||
| std::unordered_map<std::string, AnfNodePtr> *anf_node_map, | |||
| const FuncGraphPtr &anf_graph, int output_size) { | |||
| MS_ASSERT(op != nullptr); | |||
| MS_ASSERT(anf_node != nullptr); | |||
| MS_ASSERT(anf_graph != nullptr); | |||
| if (output_size == 1) { | |||
| std::vector<int64_t> shape_vector; | |||
| anf_node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector)); | |||
| anf_node_map.insert(std::pair(op.name(), anf_node)); | |||
| anf_node_map->insert(std::pair(op.name(), anf_node)); | |||
| } else { | |||
| AbstractBasePtrList abstractList; | |||
| for (int output_idx = 0; output_idx < output_size; output_idx++) { | |||
| @@ -289,104 +462,125 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C | |||
| auto tupleGetItemPrim = NewValueNode(tupleGetItemPrimPtr); | |||
| auto getItemValue = NewValueNode(MakeValue<int>(output_idx)); | |||
| std::vector<AnfNodePtr> inputs{tupleGetItemPrim, anf_node, getItemValue}; | |||
| CNodePtr getItemCNode = funcGraphPtr->NewCNode(inputs); | |||
| CNodePtr getItemCNode = anf_graph->NewCNode(inputs); | |||
| std::string output_item_name = anf_node->fullname_with_scope() + "_getitem_" + std::to_string(output_idx); | |||
| getItemCNode->set_fullname_with_scope(output_item_name); | |||
| anf_node_map.insert(std::pair(op.name() + ":" + std::to_string(output_idx), getItemCNode)); | |||
| anf_node_map->insert(std::pair(op.name() + ":" + std::to_string(output_idx), getItemCNode)); | |||
| } | |||
| anf_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstractList)); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS TFModelParser::ConvertOps() { | |||
| STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, | |||
| const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map, | |||
| const FuncGraphPtr &func_graph_ptr, | |||
| std::unordered_map<std::string, AnfNodePtr> *anf_node_map) { | |||
| MS_ASSERT(node_def != nullptr); | |||
| MS_ASSERT(func_graph_ptr != nullptr); | |||
| NoSupportOp::GetInstance()->SetFmkType("TF"); | |||
| STATUS status = RET_OK; | |||
| int op_idx = 0; | |||
| for (int i = 0; i < tf_graph_def->node_size(); i++) { | |||
| auto &node_def = tf_graph_def->node(i); | |||
| const auto &op_type = node_def.op(); | |||
| if (op_type == "Placeholder" || op_type == "Const" || op_type == "Identity" || op_type == "StopGradient") { | |||
| continue; | |||
| } | |||
| auto node_parser = TFNodeParserRegistry::GetInstance()->GetNodeParser(op_type); | |||
| if (node_parser == nullptr) { | |||
| NoSupportOp::GetInstance()->InsertOp(op_type); | |||
| status = (status == RET_OK ? RET_NOT_FIND_OP : status); | |||
| MS_LOG(ERROR) << "cannot find node parser:" << op_type; | |||
| continue; | |||
| } | |||
| if (status != RET_OK) { | |||
| continue; | |||
| } | |||
| PrimitiveC *primitiveC = nullptr; | |||
| int output_size; | |||
| std::vector<std::string> input_names; | |||
| status = node_parser->Parse(node_def, tf_node_map, &primitiveC, &input_names, &output_size); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "node " << op_type << " parser failed"; | |||
| continue; | |||
| } | |||
| const auto &op_type = node_def.op(); | |||
| if (op_type == "Placeholder" || op_type == "Const" || op_type == "Identity" || op_type == "StopGradient") { | |||
| return RET_OK; | |||
| } | |||
| auto value_node = NewValueNode(std::shared_ptr<PrimitiveC>(primitiveC)); | |||
| if (value_node == nullptr) { | |||
| MS_LOG(ERROR) << "value_node is nullptr"; | |||
| status = RET_ERROR; | |||
| continue; | |||
| } | |||
| std::vector<AnfNodePtr> inputs = {value_node}; | |||
| status = ConvertInputNodes(node_def, input_names, &inputs); | |||
| if (status != RET_OK) { | |||
| continue; | |||
| auto node_parser = TFNodeParserRegistry::GetInstance()->GetNodeParser(op_type); | |||
| if (node_parser == nullptr) { | |||
| NoSupportOp::GetInstance()->InsertOp(op_type); | |||
| MS_LOG(ERROR) << "cannot find node parser:" << op_type; | |||
| return RET_NOT_FIND_OP; | |||
| } | |||
| PrimitiveC *primitiveC = nullptr; | |||
| int output_size; | |||
| std::vector<std::string> input_names; | |||
| status = node_parser->Parse(node_def, tf_node_map, &primitiveC, &input_names, &output_size); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "node " << op_type << " parser failed"; | |||
| return RET_ERROR; | |||
| } | |||
| auto value_node = NewValueNode(std::shared_ptr<PrimitiveC>(primitiveC)); | |||
| if (value_node == nullptr) { | |||
| MS_LOG(ERROR) << "value_node is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<AnfNodePtr> inputs = {value_node}; | |||
| status = ConvertInputNodes(node_def, input_names, tf_node_map, *anf_node_map, &inputs); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| // control_depends are not processed currently | |||
| auto anf_node = func_graph_ptr->NewCNode(inputs); | |||
| anf_node->set_fullname_with_scope(node_def.name()); | |||
| if (op_type == "StatelessWhile" || op_type == "while") { | |||
| MS_LOG(INFO) << "find while node:" << node_def.name(); | |||
| tensorflow::AttrValue attr_value; | |||
| if (TensorFlowUtils::FindAttrValue(node_def, "body", &attr_value)) { | |||
| auto body_name = attr_value.func().name(); | |||
| function_while_map_[body_name] = anf_node; | |||
| MS_LOG(DEBUG) << "parse body name:" << body_name; | |||
| } | |||
| // control_depends are not processed currently | |||
| auto anf_node = funcGraphPtr->NewCNode(inputs); | |||
| anf_node->set_fullname_with_scope(op_type + "-" + std::to_string(op_idx++)); | |||
| status = ConvertOutputTensor(node_def, anf_node, output_size); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| continue; | |||
| if (TensorFlowUtils::FindAttrValue(node_def, "cond", &attr_value)) { | |||
| auto cond_name = attr_value.func().name(); | |||
| function_while_map_[cond_name] = anf_node; | |||
| MS_LOG(DEBUG) << "parse cond name:" << cond_name; | |||
| } | |||
| } | |||
| status = ConvertOutputTensor(node_def, anf_node, anf_node_map, func_graph_ptr, output_size); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return RET_ERROR; | |||
| } | |||
| return status; | |||
| } | |||
| STATUS TFModelParser::ConvertGraphOutputs() { | |||
| STATUS TFModelParser::ConvertRootGraphOutputs() { | |||
| // because output of intermediate node in anf graph may also be output tensors, we search output tensors in | |||
| // tf_node_map but not anf_node_map | |||
| // tf_root_graph_nodes_ but not anf_root_node_map_ | |||
| std::set<std::string> all_node_inputs; | |||
| std::vector<AnfNodePtr> output_nodes; | |||
| for (auto &pair : tf_node_map) { | |||
| for (auto &pair : tf_root_graph_nodes_) { | |||
| for (int i = 0; i < pair.second->input_size(); ++i) { | |||
| all_node_inputs.insert(pair.second->input(i)); | |||
| } | |||
| } | |||
| for (auto &pair : tf_node_map) { | |||
| for (auto &pair : tf_root_graph_nodes_) { | |||
| auto it = all_node_inputs.find(pair.first); | |||
| if (it == all_node_inputs.end() && pair.second->input_size() > 0) { // output node not constraint to Identity | |||
| auto origin_name = GetOriginInputName(*(pair.second)); | |||
| auto anf_node = GetAnfNode(origin_name); | |||
| auto origin_name = GetOriginInputName(*(pair.second), tf_root_graph_nodes_); | |||
| auto anf_node = GetAnfNode(origin_name, anf_root_node_map_); | |||
| if (anf_node == nullptr) { | |||
| MS_LOG(ERROR) << "can't find anf node"; | |||
| return RET_ERROR; | |||
| } | |||
| output_nodes.push_back(anf_node); | |||
| graph_output_names.push_back(anf_node->fullname_with_scope()); | |||
| graph_output_names_.push_back(anf_node->fullname_with_scope()); | |||
| } | |||
| } | |||
| if (output_nodes.size() > 1) { | |||
| std::vector<AnfNodePtr> &make_tuple_inputs = output_nodes; | |||
| auto status = MakeAnfGraphOutputs(&output_nodes, anf_root_graph_); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "make anf graph outputs node error"; | |||
| return status; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS TFModelParser::MakeAnfGraphOutputs(std::vector<AnfNodePtr> *output_nodes, const FuncGraphPtr &anf_graph) { | |||
| if (output_nodes->empty() || anf_graph == nullptr) { | |||
| MS_LOG(ERROR) << "anf output nodes empty or null anf graph"; | |||
| return RET_ERROR; | |||
| } | |||
| if (output_nodes->size() > 1) { | |||
| std::vector<AnfNodePtr> *make_tuple_inputs = output_nodes; | |||
| auto make_tuple_prim_ptr = GetMakeTuplePrim(); | |||
| if (make_tuple_prim_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr); | |||
| make_tuple_inputs.insert(output_nodes.begin(), make_tuple_prim); | |||
| auto make_tuple_cnode = funcGraphPtr->NewCNode(make_tuple_inputs); | |||
| make_tuple_inputs->insert(make_tuple_inputs->begin(), make_tuple_prim); | |||
| auto make_tuple_cnode = anf_graph->NewCNode(*make_tuple_inputs); | |||
| make_tuple_cnode->set_fullname_with_scope("return tuple"); | |||
| auto return_prim_ptr = GetReturnPrim(); | |||
| @@ -396,20 +590,20 @@ STATUS TFModelParser::ConvertGraphOutputs() { | |||
| } | |||
| auto value_node = NewValueNode(return_prim_ptr); | |||
| std::vector<AnfNodePtr> op_inputs = {value_node, make_tuple_cnode}; | |||
| auto cnode = funcGraphPtr->NewCNode(op_inputs); | |||
| auto cnode = anf_graph->NewCNode(op_inputs); | |||
| cnode->set_fullname_with_scope("return"); | |||
| funcGraphPtr->set_return(cnode); | |||
| } else if (output_nodes.size() == 1) { | |||
| anf_graph->set_return(cnode); | |||
| } else { | |||
| auto return_prim_ptr = GetReturnPrim(); | |||
| if (return_prim_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "GetReturnPrim return nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto value_node = NewValueNode(return_prim_ptr); | |||
| std::vector<AnfNodePtr> op_inputs{value_node, output_nodes.front()}; | |||
| auto return_cnode = funcGraphPtr->NewCNode(op_inputs); | |||
| std::vector<AnfNodePtr> op_inputs{value_node, output_nodes->front()}; | |||
| auto return_cnode = anf_graph->NewCNode(op_inputs); | |||
| return_cnode->set_fullname_with_scope("return"); | |||
| funcGraphPtr->set_return(return_cnode); | |||
| anf_graph->set_return(return_cnode); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -17,17 +17,17 @@ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_MODEL_PARSER_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_MODEL_PARSER_H | |||
| #include <string> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "proto/graph.pb.h" | |||
| #include "proto/node_def.pb.h" | |||
| #include "schema/inner/model_generated.h" | |||
| #include "securec/include/securec.h" | |||
| #include "tools/common/tensor_util.h" | |||
| #include "tools/converter/model_parser.h" | |||
| #include "schema/inner/model_generated.h" | |||
| #include "proto/node_def.pb.h" | |||
| #include "proto/graph.pb.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -43,24 +43,39 @@ class TFModelParser : public ModelParser { | |||
| const QuantType &quantType = QuantType_QUANT_NONE) override; | |||
| private: | |||
| AnfNodePtr GetAnfNode(const std::string &name); | |||
| std::string GetOriginInputName(const tensorflow::NodeDef &node); | |||
| STATUS ConvertConstTensor(const tensorflow::AttrValue &attr_value, const TypeId &type, const ParameterPtr ¶meter, | |||
| std::vector<int64_t> *shape_vector); | |||
| STATUS ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr ¶meter); | |||
| STATUS ConvertGraphInputsAndConsts(); | |||
| STATUS ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr ¶meter, | |||
| std::unordered_map<std::string, AnfNodePtr> *anf_node_map); | |||
| STATUS ConvertGraphInputsAndConsts(const std::map<std::string, const tensorflow::NodeDef *> &tf_graph_nodes, | |||
| const FuncGraphPtr &anf_graph, | |||
| std::unordered_map<std::string, AnfNodePtr> *anf_node_map); | |||
| STATUS ConvertInputNodes(const tensorflow::NodeDef &node_def, const std::vector<std::string> &input_names, | |||
| const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map, | |||
| const std::unordered_map<std::string, AnfNodePtr> &anf_node_map, | |||
| std::vector<AnfNodePtr> *inputs); | |||
| STATUS ConvertOutputTensor(const tensorflow::NodeDef &op, const CNodePtr &anf_node, int output_size); | |||
| STATUS ConvertOps(); | |||
| STATUS ConvertGraphOutputs(); | |||
| STATUS ConvertOutputTensor(const tensorflow::NodeDef &op, const CNodePtr &anf_node, | |||
| std::unordered_map<std::string, AnfNodePtr> *anf_node_map, const FuncGraphPtr &anf_graph, | |||
| int output_size); | |||
| STATUS ConvertOps(const tensorflow::NodeDef &node_def, | |||
| const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map, | |||
| const FuncGraphPtr &func_graph_ptr, std::unordered_map<std::string, AnfNodePtr> *anf_node_map); | |||
| STATUS ConvertRootGraphOutputs(); | |||
| STATUS ConvertSubgraph(); | |||
| STATUS WhileNodePostProcess(const std::map<CNodePtr, FuncGraphPtr> &while_cond_map, | |||
| const std::map<CNodePtr, FuncGraphPtr> &while_body_map); | |||
| STATUS MakeAnfGraphOutputs(std::vector<AnfNodePtr> *output_nodes, const FuncGraphPtr &anf_graph); | |||
| FuncGraphPtr funcGraphPtr; | |||
| std::unique_ptr<tensorflow::GraphDef> tf_graph_def; | |||
| std::map<std::string, const tensorflow::NodeDef *> tf_node_map; | |||
| std::unordered_map<std::string, AnfNodePtr> anf_node_map; | |||
| std::vector<std::string> graph_input_names; | |||
| std::vector<std::string> graph_output_names; | |||
| FuncGraphPtr anf_root_graph_; | |||
| std::unique_ptr<tensorflow::GraphDef> tf_root_graph_; // tf root graph def | |||
| std::map<std::string, const tensorflow::NodeDef *> tf_root_graph_nodes_; // tf root graph node map | |||
| std::unordered_map<std::string, AnfNodePtr> anf_root_node_map_; | |||
| std::vector<std::string> graph_input_names_; | |||
| std::vector<std::string> graph_output_names_; | |||
| std::map<std::string, AnfNodePtr> function_while_map_; // tf function name->while_node_name | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,62 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_while_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFWhileParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(INFO) << "TF WhileParser"; | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "primitive is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto attr = std::make_unique<schema::WhileT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_While; | |||
| primitive->value.value = attr.release(); | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = tf_op.input_size(); | |||
| for (int i = 0; i < tf_op.input_size(); i++) { | |||
| inputs->emplace_back(tf_op.input(i)); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| TFNodeRegistrar g_tfStatelessWhileParser("StatelessWhile", new TFWhileParser()); | |||
| TFNodeRegistrar g_tfWhileParser("While", new TFWhileParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_WHILE_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_WHILE_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFWhileParser : public TFNodeParser { | |||
| public: | |||
| TFWhileParser() = default; | |||
| ~TFWhileParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_WHILE_PARSER_H_ | |||