Browse Source

!9489 [MSLITE] add tf subgraph parser

From: @zhengjun10
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
cc4f6a30e4
8 changed files with 534 additions and 721 deletions
  1. +0
    -509
      mindspore/lite/tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.cc
  2. +0
    -86
      mindspore/lite/tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.h
  3. +63
    -0
      mindspore/lite/tools/converter/parser/tf/tf_logical_parser.cc
  4. +37
    -0
      mindspore/lite/tools/converter/parser/tf/tf_logical_parser.h
  5. +301
    -107
      mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc
  6. +34
    -19
      mindspore/lite/tools/converter/parser/tf/tf_model_parser.h
  7. +62
    -0
      mindspore/lite/tools/converter/parser/tf/tf_while_parser.cc
  8. +37
    -0
      mindspore/lite/tools/converter/parser/tf/tf_while_parser.h

+ 0
- 509
mindspore/lite/tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.cc View File

@@ -1,509 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.h"
#include <cfloat>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "src/common/log_adapter.h"
#include "tools/common/graph_util.h"
#include "tools/common/tensor_util.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"

namespace mindspore {
namespace lite {
// Number of matched pattern nodes for the two fusion variants: with and
// without a preceding placeholder input node.
// (Replaces the original #define macros: constexpr is typed and scoped.)
constexpr size_t kBatchNormFoldFusionPathLen6 = 6;
constexpr size_t kBatchNormFoldFusionPathLen7 = 7;

// Entry point: delegates to the generic FusionPass driver, which matches the
// patterns from DefinePattern() and calls DoFusion() for each occurrence.
STATUS BatchNormFoldFusionPass::Run(MetaGraphT *graph) { return FusionPass::Run(graph); }

STATUS BatchNormFoldFusionPass::DefinePattern() {
  // Both fusion patterns share the same core chain:
  //   conv1 -> BatchNormFold -> MulFold -> FakeQuantWithMinMax -> conv2 -> AddFold
  // where AddFold additionally consumes BatchNormFold's output. The
  // "with pre-node" variant also anchors conv1 and conv2 on a shared
  // placeholder input node. The original duplicated ~50 lines per variant;
  // build both through one parameterized helper instead.
  auto buildPattern = [this](const std::string &patternName, bool withPreNode) -> STATUS {
    std::shared_ptr<PatternOp> inputOp = nullptr;
    if (withPreNode) {
      inputOp = std::make_shared<PatternOp>();
      inputOp->id = inputOpName;
      inputOp->types = {schema::PrimitiveType_NONE};
      inputOp->isPlaceHold = true;
    }

    auto convOp1 = std::make_shared<PatternOp>();
    convOp1->id = convPatternOpName1;
    convOp1->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D};
    if (withPreNode) {
      convOp1->left = inputOp;
    }

    auto bnFoldOp = std::make_shared<PatternOp>();
    bnFoldOp->id = bnFoldOpName;
    bnFoldOp->types = {schema::PrimitiveType_BatchNormFold};
    bnFoldOp->left = convOp1;

    auto mulFoldOp = std::make_shared<PatternOp>();
    mulFoldOp->id = mulFoldOpName;
    mulFoldOp->types = {schema::PrimitiveType_MulFold};
    mulFoldOp->left = bnFoldOp;

    auto fakeQuantOp = std::make_shared<PatternOp>();
    fakeQuantOp->id = fakeQuantOpName;
    fakeQuantOp->types = {schema::PrimitiveType_FakeQuantWithMinMax};
    fakeQuantOp->left = mulFoldOp;

    auto convOp2 = std::make_shared<PatternOp>();
    convOp2->id = convPatternOpName2;
    convOp2->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D};
    convOp2->left = fakeQuantOp;
    if (withPreNode) {
      convOp2->right = inputOp;
    }

    auto addFoldOp = std::make_shared<PatternOp>();
    addFoldOp->id = addFoldOpName;
    addFoldOp->types = {schema::PrimitiveType_AddFold};
    addFoldOp->left = convOp2;
    addFoldOp->right = bnFoldOp;

    std::unique_ptr<FusionPattern> fusionPattern(new (std::nothrow) FusionPattern(patternName));
    if (fusionPattern == nullptr) {
      MS_LOG(ERROR) << "new fusionPattern failed";
      return RET_ERROR;
    }
    if (withPreNode) {
      fusionPattern->AddPatternOp(inputOp);
    }
    fusionPattern->AddPatternOp(convOp1);
    fusionPattern->AddPatternOp(bnFoldOp);
    fusionPattern->AddPatternOp(mulFoldOp);
    fusionPattern->AddPatternOp(fakeQuantOp);
    fusionPattern->AddPatternOp(convOp2);
    fusionPattern->AddPatternOp(addFoldOp);
    fusionPattern->Finish();

    this->patterns.emplace_back(fusionPattern.release());
    return RET_OK;
  };

  // Same order as the original: "with pre-node" first, then "no pre-node".
  auto status = buildPattern(withPrePatternName, true);
  if (status != RET_OK) {
    return status;
  }
  return buildPattern(noPrePatternName, false);
}

// Rewrites one matched occurrence of the pattern. Steps, in order:
// resolve nodes -> sanity-check topology -> resolve tensors -> generate the
// folded weight/bias -> detach the *Fold nodes -> patch conv2's weights.
STATUS BatchNormFoldFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternName,
                                         std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
  MS_ASSERT(graph != nullptr);
  // The matched path size depends on which variant fired.
  if (patternName == withPrePatternName && matchedPath.size() != kBatchNormFoldFusionPathLen7) {
    MS_LOG(ERROR) << "BatchNormFold-Fusion should have seven NodeIndex in matchedPair";
    return RET_PARAM_INVALID;
  }
  if (patternName == noPrePatternName && matchedPath.size() != kBatchNormFoldFusionPathLen6) {
    MS_LOG(ERROR) << "BatchNormFold-Fusion should have six NodeIndex in matchedPair";
    return RET_PARAM_INVALID;
  }

  STATUS status;
  if ((status = FindNodes(graph, matchedPath)) != RET_OK) {
    MS_LOG(ERROR) << "FindNodes failed: " << status;
    return status;
  }
  if ((status = CheckPath(graph, matchedPath)) != RET_OK) {
    MS_LOG(ERROR) << "CheckPath failed: " << status;
    return status;
  }
  if ((status = FindTensors()) != RET_OK) {
    MS_LOG(ERROR) << "FindTensors failed: " << status;
    return status;
  }
  if ((status = GenNewWeightTensor()) != RET_OK) {
    MS_LOG(ERROR) << "GenNewWeightTensor failed: " << status;
    return status;
  }
  if ((status = GenNewBiasTensor()) != RET_OK) {
    MS_LOG(ERROR) << "GenNewBiasTensor failed: " << status;
    return status;
  }
  if ((status = IsolateNodes(graph, matchedPath)) != RET_OK) {
    MS_LOG(ERROR) << "IsolateNodes failed: " << status;
    return status;
  }
  UpdateConvWeights();
  if ((status = DeleteConstTensors()) != RET_OK) {
    MS_LOG(ERROR) << "DeleteConstTensors failed: " << status;
    return status;
  }
  return RET_OK;
}

// Resolves the matched Path entries into the CNodeT member pointers
// (preConv, bnFold, mulFold, fakeNode, convNode, addFold) used by the
// subsequent fusion steps. Fails if the matched nodes span sub-graphs.
STATUS BatchNormFoldFusionPass::FindNodes(MetaGraphT *graph,
                                          const std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
  MS_ASSERT(graph != nullptr);
  auto preConvPath = matchedPath.at(convPatternOpName1);
  auto bnFoldPath = matchedPath.at(bnFoldOpName);
  auto mulFoldPath = matchedPath.at(mulFoldOpName);
  auto fakeQuantPath = matchedPath.at(fakeQuantOpName);
  auto convPath = matchedPath.at(convPatternOpName2);
  auto addFoldPath = matchedPath.at(addFoldOpName);
  const std::shared_ptr<Path> allPaths[] = {preConvPath, bnFoldPath, mulFoldPath, fakeQuantPath, convPath, addFoldPath};
  for (const auto &path : allPaths) {
    MS_ASSERT(path != nullptr);
  }
  // All six nodes must come from the same sub-graph.
  for (const auto &path : allPaths) {
    if (path->subGraphIdx != preConvPath->subGraphIdx) {
      MS_LOG(ERROR) << "matched nodes should from same subGraph";
      return RET_ERROR;
    }
  }
  for (const auto &path : allPaths) {
    MS_ASSERT(graph->nodes.size() > path->nodeIdx);
  }
  preConv = graph->nodes.at(preConvPath->nodeIdx).get();
  bnFold = graph->nodes.at(bnFoldPath->nodeIdx).get();
  mulFold = graph->nodes.at(mulFoldPath->nodeIdx).get();
  fakeNode = graph->nodes.at(fakeQuantPath->nodeIdx).get();
  convNode = graph->nodes.at(convPath->nodeIdx).get();
  addFold = graph->nodes.at(addFoldPath->nodeIdx).get();
  MS_ASSERT(preConv != nullptr);
  MS_ASSERT(bnFold != nullptr);
  MS_ASSERT(mulFold != nullptr);
  MS_ASSERT(fakeNode != nullptr);
  MS_ASSERT(convNode != nullptr);
  MS_ASSERT(addFold != nullptr);
  return RET_OK;
}

// Resolves the BN constant tensors (mu, sigma, beta, gamma) and the old
// convolution weight tensor, and derives channelOut from their shared shape.
// Tensor positions: bnFold inputs {_, mu, sigma, _}; addFold inputs
// {_, beta, gamma, _, _}; mulFold input 0 is the weight being folded.
STATUS BatchNormFoldFusionPass::FindTensors() {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(bnFold != nullptr);
  MS_ASSERT(addFold != nullptr);
  if (bnFold->inputIndex.size() != 4) {
    MS_LOG(ERROR) << "BatchNormFold node should have 4 inputTensor, got " << bnFold->inputIndex.size()
                  << " input tensors";
    return RET_ERROR;
  }
  if (addFold->inputIndex.size() != 5) {
    MS_LOG(ERROR) << "AddFold node should have 5 inputTensor, got " << addFold->inputIndex.size() << " input tensors";
    return RET_ERROR;
  }
  MS_ASSERT(graph->allTensors.size() > bnFold->inputIndex.at(1));
  muTensor = graph->allTensors.at(bnFold->inputIndex.at(1)).get();
  MS_ASSERT(muTensor != nullptr);
  MS_ASSERT(graph->allTensors.size() > bnFold->inputIndex.at(2));
  sigmaTensor = graph->allTensors.at(bnFold->inputIndex.at(2)).get();
  MS_ASSERT(sigmaTensor != nullptr);
  MS_ASSERT(graph->allTensors.size() > addFold->inputIndex.at(1));
  betaTensor = graph->allTensors.at(addFold->inputIndex.at(1)).get();
  MS_ASSERT(betaTensor != nullptr);
  MS_ASSERT(graph->allTensors.size() > addFold->inputIndex.at(2));
  gammaTensor = graph->allTensors.at(addFold->inputIndex.at(2)).get();
  MS_ASSERT(gammaTensor != nullptr);

  // All four BN constants must be 1-D vectors of the same length, which is
  // the output-channel count.
  if (betaTensor->dims.size() != 1) {
    MS_LOG(ERROR) << "ConstTensor should have only one dim, got " << betaTensor->dims.size();
    return RET_ERROR;
  }
  if (betaTensor->dims != gammaTensor->dims || betaTensor->dims != sigmaTensor->dims ||
      betaTensor->dims != muTensor->dims) {
    MS_LOG(ERROR) << "All ConstTensor should have same dims";
    return RET_ERROR;
  }
  channelOut = betaTensor->dims.front();

  MS_ASSERT(mulFold != nullptr);
  if (mulFold->inputIndex.size() != 3) {
    // BUGFIX: the original logged addFold's input count and said
    // "outputTensor" even though mulFold's *inputs* are being checked.
    MS_LOG(ERROR) << "MulFold node should have 3 inputTensor, got " << mulFold->inputIndex.size() << " input tensors";
    return RET_ERROR;
  }
  MS_ASSERT(graph->allTensors.size() > mulFold->inputIndex.front());
  oldWeightTensor = graph->allTensors.at(mulFold->inputIndex.front()).get();
  MS_ASSERT(oldWeightTensor != nullptr);
  return RET_OK;
}

// Sanity-checks the topology of the matched sub-path before any rewriting:
// both convolutions must share the same activation input, and the first
// convolution's weight must be the tensor fed into MulFold.
// NOTE(review): every check here is an MS_ASSERT, so this compiles to an
// unconditional RET_OK in release builds; `graph` and `matchedPath` are
// unused — the checks read members populated earlier by FindNodes().
STATUS BatchNormFoldFusionPass::CheckPath(MetaGraphT *graph,
                                          const std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
  MS_ASSERT(preConv != nullptr);
  MS_ASSERT(convNode != nullptr);
  MS_ASSERT(mulFold != nullptr);
  MS_ASSERT(preConv->inputIndex.size() == 2);
  MS_ASSERT(convNode->inputIndex.size() == 2);
  MS_ASSERT(mulFold->inputIndex.size() == 3);
  // conv1 and conv2 consume the same first input (the shared activation).
  MS_ASSERT(preConv->inputIndex.front() == convNode->inputIndex.front());
  // conv1's weight tensor is the one MulFold folds.
  MS_ASSERT(preConv->inputIndex.at(1) == mulFold->inputIndex.front());
  return RET_OK;
}

// Builds the folded weight tensor: for each output channel i,
//   newWeight[i][...] = oldWeight[i][...] * gamma[i] / divisor[i]
// where the divisor is read from muTensor.
// NOTE(review): the header's formula is weight = weight * gamma / sigma, but
// this divides by the tensor taken from bnFold input 1 (muTensor). Confirm
// bnFold's input ordering actually places the std-dev there — otherwise this
// divides by the batch mean.
STATUS BatchNormFoldFusionPass::GenNewWeightTensor() {
  MS_ASSERT(oldWeightTensor != nullptr);
  MS_ASSERT(oldWeightTensor->dataType == DataType_DT_FLOAT);
  MS_ASSERT(oldWeightTensor->refCount == schema::NodeType::NodeType_ValueNode);
  auto weightShape = oldWeightTensor->dims;
  if (weightShape.size() != 4) {
    MS_LOG(ERROR) << "shape of weight should be 4 dims, got " << weightShape.size() << " dims";
    return RET_ERROR;
  }
  if (weightShape.front() != channelOut) {
    MS_LOG(ERROR) << "weight should be in KCHW format, and outputChannel should be " << channelOut;
    return RET_ERROR;
  }
  auto weightShapeSize = GetShapeSize(*oldWeightTensor);
  newWeightTensor = std::unique_ptr<TensorT>(new (std::nothrow) TensorT);
  if (newWeightTensor == nullptr) {
    MS_LOG(ERROR) << "new weightTensor failed";
    return RET_ERROR;
  }
  newWeightTensor->dataType = oldWeightTensor->dataType;
  newWeightTensor->format = oldWeightTensor->format;
  newWeightTensor->refCount = schema::NodeType::NodeType_ValueNode;
  newWeightTensor->dims = weightShape;
  newWeightTensor->data.resize(weightShapeSize * sizeof(float));
  void *oldWeightData = oldWeightTensor->data.data();
  auto castedOldWeightData = static_cast<float *>(oldWeightData);
  void *newWeightData = newWeightTensor->data.data();
  auto castedNewWeightData = static_cast<float *>(newWeightData);
  MS_ASSERT(gammaTensor->dataType == DataType_DT_FLOAT);
  void *gammaData = gammaTensor->data.data();
  auto *castedGammaData = static_cast<float *>(gammaData);
  MS_ASSERT(muTensor->dataType == DataType_DT_FLOAT);
  void *miData = muTensor->data.data();
  auto *castedMiData = static_cast<float *>(miData);
  if (channelOut == 0) {
    MS_LOG(ERROR) << "divisor 'channelOut' cannot be 0";
    return RET_ERROR;
  }
  size_t stride = weightShapeSize / channelOut;
  for (int i = 0; i < channelOut; i++) {
    // PERF: the divisor check is per-channel, not per-element — the original
    // re-evaluated it inside the inner loop on every element.
    if (fabs(castedMiData[i]) <= 0.0f) {
      MS_LOG(ERROR) << "divisor 'castedMiData' cannot be 0";
      return RET_ERROR;
    }
    for (size_t j = 0; j < stride; j++) {
      castedNewWeightData[i * stride + j] = castedOldWeightData[i * stride + j] * castedGammaData[i] / castedMiData[i];
    }
  }
  return RET_OK;
}

// Builds the folded bias tensor, one value per output channel:
//   bias[c] = beta[c] - gamma[c] * mu[c] / sigma[c]
STATUS BatchNormFoldFusionPass::GenNewBiasTensor() {  // bias has no quant
  std::vector<int32_t> biasShape = {channelOut};
  newBiasTensor = std::unique_ptr<TensorT>(new (std::nothrow) TensorT);
  if (newBiasTensor == nullptr) {
    MS_LOG(ERROR) << "new BiasTensor failed";
    return RET_ERROR;
  }
  newBiasTensor->dataType = 0;  // same (unset) dataType value as the original — presumably float32; confirm
  newBiasTensor->format = schema::Format::Format_NUM_OF_FORMAT;
  newBiasTensor->refCount = schema::NodeType::NodeType_ValueNode;
  newBiasTensor->dims = biasShape;
  newBiasTensor->data.resize(channelOut * sizeof(float));

  // View the raw byte buffers as float arrays.
  auto asFloats = [](void *raw) { return static_cast<float *>(raw); };
  auto *bias = asFloats(newBiasTensor->data.data());
  MS_ASSERT(betaTensor->dataType == DataType_DT_FLOAT);
  auto *beta = asFloats(betaTensor->data.data());
  MS_ASSERT(gammaTensor->dataType == DataType_DT_FLOAT);
  auto *gamma = asFloats(gammaTensor->data.data());
  MS_ASSERT(muTensor->dataType == DataType_DT_FLOAT);
  auto *mu = asFloats(muTensor->data.data());
  MS_ASSERT(sigmaTensor->dataType == DataType_DT_FLOAT);
  auto *sigma = asFloats(sigmaTensor->data.data());

  for (int c = 0; c < channelOut; ++c) {
    if (fabs(sigma[c]) <= 0.0f) {
      MS_LOG(ERROR) << "divisor 'castedSigmaData' cannot be 0";
      return RET_ERROR;
    }
    bias[c] = beta[c] - gamma[c] * mu[c] / sigma[c];
  }
  return RET_OK;
}

// Detaches the folded nodes from the graph: removes preConv, deletes the
// now-dead tensors around bnFold, then isolates bnFold, mulFold and addFold.
// fakeNode and convNode stay in the graph (convNode is patched later by
// UpdateConvWeights).
STATUS BatchNormFoldFusionPass::IsolateNodes(
  MetaGraphT *graph, const std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
  MS_ASSERT(graph != nullptr);
  auto preConvPath = matchedPath.at(convPatternOpName1);
  auto bnFoldPath = matchedPath.at(bnFoldOpName);
  auto mulFoldPath = matchedPath.at(mulFoldOpName);
  auto fakeQuantPath = matchedPath.at(fakeQuantOpName);
  auto convPath = matchedPath.at(convPatternOpName2);
  auto addFoldPath = matchedPath.at(addFoldOpName);
  MS_ASSERT(preConvPath != nullptr);
  MS_ASSERT(bnFoldPath != nullptr);
  MS_ASSERT(mulFoldPath != nullptr);
  MS_ASSERT(fakeQuantPath != nullptr);
  MS_ASSERT(convPath != nullptr);
  MS_ASSERT(addFoldPath != nullptr);
  auto status = IsolateOneWayNode(graph, preConvPath->nodeIdx);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "IsolateOneWayNode " << preConv->name.c_str() << " failed, error: " << status;
    return status;
  }
  // bnFold's 4th input and all of its outputs become dead once it is folded.
  std::vector<uint32_t> toDeleteTensorIdxes;
  toDeleteTensorIdxes.emplace_back(bnFold->inputIndex.at(3));
  toDeleteTensorIdxes.insert(toDeleteTensorIdxes.end(), bnFold->outputIndex.begin(), bnFold->outputIndex.end());
  status = RemoveTensor(graph, toDeleteTensorIdxes, true);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Remove Tensors of BnFold " << bnFold->name.c_str() << " failed, error: " << status;
    // CONSISTENCY FIX: propagate the concrete error code like every other
    // failure path here (the original returned a generic RET_ERROR).
    return status;
  }
  status = IsolateOneWayNode(graph, bnFoldPath->nodeIdx);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "IsolateOneWayNode " << bnFold->name.c_str() << " failed, error: " << status;
    return status;
  }
  status = IsolateOneWayNode(graph, mulFoldPath->nodeIdx);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "IsolateOneWayNode " << mulFold->name.c_str() << " failed, error: " << status;
    return status;
  }
  status = IsolateOneWayNode(graph, addFoldPath->nodeIdx);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "IsolateOneWayNode " << addFold->name.c_str() << " failed, error: " << status;
    return status;
  }
  return RET_OK;
}

// Installs the folded tensors into the surviving convolution: the weight
// tensor consumed by fakeNode (input 0) is replaced by newWeightTensor, the
// new bias tensor is appended to the graph and wired into convNode, and the
// conv primitive is marked as having a bias.
void BatchNormFoldFusionPass::UpdateConvWeights() {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(convNode != nullptr);
  MS_ASSERT(newWeightTensor != nullptr);
  MS_ASSERT(newBiasTensor != nullptr);
  MS_ASSERT(graph->allTensors.size() > fakeNode->inputIndex.at(0));
  // Replace fakeNode's weight input in place so its tensor index is reused.
  graph->allTensors.at(fakeNode->inputIndex.at(0)).reset();
  graph->allTensors.at(fakeNode->inputIndex.at(0)) = std::move(this->newWeightTensor);
  // Append the bias as a new graph tensor and add it as convNode's extra input.
  graph->allTensors.emplace_back(std::move(this->newBiasTensor));
  convNode->inputIndex.emplace_back(graph->allTensors.size() - 1);
  if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) {
    convNode->primitive->value.AsConv2D()->hasBias = true;
  } else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) {
    // Pattern matching only admits Conv2D/DepthwiseConv2D, so any other type
    // here is a programming error (hence the assert below).
    convNode->primitive->value.AsDepthwiseConv2D()->hasBias = true;
  } else {
    MS_ASSERT(false);
  }

  // newWeightTensor/newBiasTensor were moved into the graph above; clear the
  // raw observer too so stale pointers are not reused on the next match.
  this->oldWeightTensor = nullptr;
  this->newWeightTensor = nullptr;
  this->newBiasTensor = nullptr;
}

// Removes the four BN constant tensors (mu, sigma, gamma, beta) from the
// graph once their values have been folded into the conv weight/bias.
// Fails if any of them can no longer be located.
STATUS BatchNormFoldFusionPass::DeleteConstTensors() {
  MS_ASSERT(graph != nullptr);
  bool muFind = false;
  bool sigmaFind = false;
  bool betaFind = false;
  bool gammaFind = false;
  std::vector<uint32_t> toDeleteTensorIdxes;
  for (size_t idx = 0; idx < graph->allTensors.size(); idx++) {
    TensorT *candidate = graph->allTensors.at(idx).get();
    // Marks the member as consumed (nulled) when the candidate matches it.
    // Same match order and nulling behavior as the original loop.
    auto markIfMatch = [&](TensorT *&member, bool &found) {
      if (candidate == member) {
        toDeleteTensorIdxes.emplace_back(idx);
        found = true;
        member = nullptr;
      }
    };
    markIfMatch(this->muTensor, muFind);
    markIfMatch(this->sigmaTensor, sigmaFind);
    markIfMatch(this->gammaTensor, gammaFind);
    markIfMatch(this->betaTensor, betaFind);
  }
  if (!muFind || !sigmaFind || !betaFind || !gammaFind) {
    MS_LOG(ERROR) << "Can not find muTensor or sigmaTensor or betaTensor or gammaTensor in graph";
    return RET_ERROR;
  }
  auto status = RemoveTensor(graph, toDeleteTensorIdxes);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Remove ConstTensors failed" << bnFold->name.c_str();
    return RET_ERROR;
  }
  return RET_OK;
}

// BUGFIX: the original destructor guarded `reset()` behind inverted
// `== nullptr` checks, making both branches no-ops. The guards — and the
// manual resets themselves — are unnecessary: newWeightTensor and
// newBiasTensor are std::unique_ptr members and release their storage
// automatically (Rule of Zero).
BatchNormFoldFusionPass::~BatchNormFoldFusionPass() = default;
} // namespace lite
} // namespace mindspore

+ 0
- 86
mindspore/lite/tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.h View File

@@ -1,86 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_PREDICT_BATCHNORM_FOLD_FUSION_PASS_H
#define MINDSPORE_PREDICT_BATCHNORM_FOLD_FUSION_PASS_H

#include <unordered_map>
#include <memory>
#include <string>
#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h"

namespace mindspore {
namespace lite {
// input = input
// weight = SimQuantPerChannel(weight * gamma / sigma)
// bias = beta - gamma * mi / sigma
// MulFold: gamma sigma
// BatchNormFold: mi sigma
// AddFold: gamma beta mi sigma
// Fusion pass that folds the QAT BatchNormFold/MulFold/AddFold node cluster
// into the weights and bias of the following convolution.
class BatchNormFoldFusionPass : public FusionPass {
 public:
  BatchNormFoldFusionPass() = default;

  ~BatchNormFoldFusionPass() override;

  // Registers the two fusion patterns (with / without a preceding input node).
  STATUS DefinePattern() override;

  // Rewrites one matched occurrence: resolves nodes and tensors, generates
  // the folded weight/bias, detaches the *Fold nodes, and patches the conv.
  STATUS DoFusion(MetaGraphT *graph, const std::string &patternName,
                  std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override;

  STATUS Run(MetaGraphT *graph) override;

 protected:
  // Resolves the matched Path entries to the CNodeT members below.
  STATUS FindNodes(MetaGraphT *graph, const std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath);
  // Debug-only (MS_ASSERT) topology sanity checks on the matched nodes.
  STATUS CheckPath(MetaGraphT *graph, const std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath);
  // Resolves mu/sigma/beta/gamma and the old weight tensor; sets channelOut.
  STATUS FindTensors();
  // Builds the per-channel folded weight tensor.
  STATUS GenNewWeightTensor();
  // Builds the per-channel folded bias tensor.
  STATUS GenNewBiasTensor();
  // Detaches preConv, bnFold, mulFold and addFold from the graph.
  STATUS IsolateNodes(MetaGraphT *graph, const std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath);
  // Installs the folded weight/bias into the surviving conv; sets hasBias.
  void UpdateConvWeights();
  // Removes the now-unused mu/sigma/beta/gamma const tensors.
  STATUS DeleteConstTensors();

 protected:
  // NOTE(review): `graph` is read by the steps above but no assignment is
  // visible in this pass — presumably populated by the FusionPass framework;
  // confirm.
  MetaGraphT *graph = nullptr;
  CNodeT *preConv = nullptr;   // first convolution of the pattern
  CNodeT *bnFold = nullptr;    // BatchNormFold node
  CNodeT *mulFold = nullptr;   // MulFold node
  CNodeT *fakeNode = nullptr;  // FakeQuantWithMinMax between MulFold and conv2
  CNodeT *convNode = nullptr;  // second convolution (receives folded weights)
  CNodeT *addFold = nullptr;   // AddFold node
  TensorT *muTensor = nullptr;     // bnFold input 1
  TensorT *sigmaTensor = nullptr;  // bnFold input 2
  TensorT *gammaTensor = nullptr;  // addFold input 2
  TensorT *betaTensor = nullptr;   // addFold input 1
  TensorT *oldWeightTensor = nullptr;  // MulFold input 0 (conv weight pre-fold)
  int32_t channelOut = 0;  // length of the 1-D BN const tensors

  // Built by GenNew*Tensor(), moved into the graph by UpdateConvWeights().
  std::unique_ptr<TensorT> newWeightTensor = nullptr;
  std::unique_ptr<TensorT> newBiasTensor = nullptr;

  // Pattern-op identifiers shared between DefinePattern() and DoFusion().
  std::string inputOpName = "Input";
  std::string convPatternOpName1 = "Convolution1";
  std::string bnFoldOpName = "BatchNormFold";
  std::string mulFoldOpName = "MulFold";
  std::string fakeQuantOpName = "FakeQuant";
  std::string convPatternOpName2 = "Convolution2";
  std::string addFoldOpName = "AddFold";
  std::string withPrePatternName = "BNFoldFusionWithPre";
  std::string noPrePatternName = "BNFoldFusionNoPre";
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_PREDICT_BATCHNORM_FOLD_FUSION_PASS_H

+ 63
- 0
mindspore/lite/tools/converter/parser/tf/tf_logical_parser.cc View File

@@ -0,0 +1,63 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_logical_parser.h"
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"

namespace mindspore {
namespace lite {
// Parses a TensorFlow logical element-wise node into a MindSpore Lite
// primitive. Currently only "LogicalAnd" is supported (and registered).
// On success: *primitiveC receives the created primitive, *inputs receives
// the op's input tensor names, and *output_size is set to 1.
STATUS TFLogicalParser::Parse(const tensorflow::NodeDef &tf_op,
                              const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
                              std::vector<std::string> *inputs, int *output_size) {
  MS_LOG(INFO) << "TF LogicalParser";
  if (primitiveC == nullptr || output_size == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_NULL_PTR;
  }
  // BUGFIX: initialize the output up front. The original left *primitiveC
  // untouched when tf_op.op() was not "LogicalAnd", so the nullptr check
  // below read an uninitialized pointer.
  *primitiveC = nullptr;

  auto primitive = std::make_unique<schema::PrimitiveT>();
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "primitive is nullptr";
    return RET_NULL_PTR;
  }
  if (tf_op.op() == "LogicalAnd") {
    auto attr = std::make_unique<schema::LogicalAndT>();
    if (attr == nullptr) {
      MS_LOG(ERROR) << "new op failed";
      return RET_NULL_PTR;
    }
    primitive->value.type = schema::PrimitiveType_LogicalAnd;
    primitive->value.value = attr.release();
    *primitiveC = PrimitiveC::Create(primitive.release());
  }
  // Fails both for unsupported ops (still nullptr) and for Create() failures.
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_ERROR;
  }

  *output_size = 1;
  for (int i = 0; i < tf_op.input_size(); i++) {
    inputs->emplace_back(tf_op.input(i));
  }

  return RET_OK;
}
TFNodeRegistrar g_tfLogicalAndParser("LogicalAnd", new TFLogicalParser());
} // namespace lite
} // namespace mindspore

+ 37
- 0
mindspore/lite/tools/converter/parser/tf/tf_logical_parser.h View File

@@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LOGICAL_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LOGICAL_PARSER_H_

#include <map>
#include <memory>
#include <string>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
// Parser for TensorFlow logical element-wise ops; registered for
// "LogicalAnd" in tf_logical_parser.cc.
class TFLogicalParser : public TFNodeParser {
 public:
  TFLogicalParser() = default;
  ~TFLogicalParser() override = default;

  // Converts tf_op into a PrimitiveC (*primitiveC), records the op's input
  // names in *inputs and sets *output_size (always 1 for logical ops).
  STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
               PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LOGICAL_PARSER_H_

+ 301
- 107
mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc View File

@@ -17,37 +17,57 @@


#include "tools/converter/parser/tf/tf_model_parser.h" #include "tools/converter/parser/tf/tf_model_parser.h"
#include <functional> #include <functional>
#include <regex>
#include <set> #include <set>
#include "src/common/utils.h"
#include "src/common/log_adapter.h" #include "src/common/log_adapter.h"
#include "tools/common/graph_util.h"
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "src/common/utils.h"
#include "src/param_value_lite.h" #include "src/param_value_lite.h"
#include "tools/common/graph_util.h"
#include "tools/common/protobuf_utils.h" #include "tools/common/protobuf_utils.h"
#include "tools/converter/parser/tf/tf_node_parser_registry.h"


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {

AnfNodePtr TFModelParser::GetAnfNode(const std::string &name) {
namespace {
// subgraph node input may be a:output:0/a:z:0
std::string GetFlattenNodeName(std::string input_name) {
std::regex re("\\:+");
std::vector<std::string> input_splits(std::sregex_token_iterator(input_name.begin(), input_name.end(), re, -1),
std::sregex_token_iterator());
if (input_splits.size() == 3) {
if (input_splits[2] == "0") {
input_name = input_splits[0];
} else {
input_name = input_splits[0] + input_splits[2]; // multi output node
}
}
return input_name;
}
AnfNodePtr GetAnfNode(const std::string &name, const std::unordered_map<std::string, AnfNodePtr> &anf_node_map) {
AnfNodePtr ret = nullptr; AnfNodePtr ret = nullptr;
if (anf_node_map.find(name) != anf_node_map.end()) { if (anf_node_map.find(name) != anf_node_map.end()) {
ret = anf_node_map[name];
ret = anf_node_map.at(name);
} else if (anf_node_map.find(name + ":0") != anf_node_map.end()) { } else if (anf_node_map.find(name + ":0") != anf_node_map.end()) {
ret = anf_node_map[name + ":0"];
ret = anf_node_map.at(name + ":0");
} }
return ret; return ret;
} }


std::string TFModelParser::GetOriginInputName(const tensorflow::NodeDef &node) {
std::string GetOriginInputName(const tensorflow::NodeDef &node,
const std::map<std::string, const tensorflow::NodeDef *> &tf_graph_nodes) {
if (node.op() != "Identity" && node.op() != "StopGradient") { if (node.op() != "Identity" && node.op() != "StopGradient") {
return node.name(); return node.name();
} }
auto tmp_node = &node; auto tmp_node = &node;
while (tmp_node->op() == "Identity" || tmp_node->op() == "StopGradient") { while (tmp_node->op() == "Identity" || tmp_node->op() == "StopGradient") {
tmp_node = tf_node_map[tmp_node->input(0)];
if (tf_graph_nodes.find(tmp_node->input(0)) == tf_graph_nodes.end()) {
return tmp_node->input(0);
}
tmp_node = tf_graph_nodes.at(tmp_node->input(0));
} }
return tmp_node->name(); return tmp_node->name();
} }
} // namespace


STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value, const TypeId &type, STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value, const TypeId &type,
const ParameterPtr &parameter, std::vector<int64_t> *shape_vector) { const ParameterPtr &parameter, std::vector<int64_t> *shape_vector) {
@@ -126,11 +146,11 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value
param_value->set_tensor_type(type); param_value->set_tensor_type(type);
param_value->set_format(schema::Format::Format_NHWC); param_value->set_format(schema::Format::Format_NHWC);
parameter->set_default_param(param_value); parameter->set_default_param(param_value);
parameter->set_name("const_" + std::to_string(anf_node_map.size()) + "_parameter");
return RET_OK; return RET_OK;
} }


STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr &parameter) {
STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr &parameter,
std::unordered_map<std::string, AnfNodePtr> *anf_node_map) {
MS_ASSERT(node != nullptr); MS_ASSERT(node != nullptr);
MS_ASSERT(parameter != nullptr); MS_ASSERT(parameter != nullptr);


@@ -157,8 +177,7 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa
return status; return status;
} }
} else { } else {
parameter->set_name("placeholder_" + std::to_string(anf_node_map.size()));
graph_input_names.emplace_back(parameter->name());
graph_input_names_.emplace_back(node.name()); // only root graph need set graph input names
} }


auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
@@ -166,14 +185,19 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa
MS_LOG(ERROR) << "abstract_tensor is nullptr"; MS_LOG(ERROR) << "abstract_tensor is nullptr";
return RET_ERROR; return RET_ERROR;
} }
parameter->set_name(node.name());
parameter->set_abstract(abstract_tensor); parameter->set_abstract(abstract_tensor);


anf_node_map[node.name()] = parameter;
(*anf_node_map)[node.name()] = parameter;
(*anf_node_map)[node.name() + ":0"] = parameter;

return RET_OK; return RET_OK;
} }


STATUS TFModelParser::ConvertGraphInputsAndConsts() {
for (auto &pair : tf_node_map) {
STATUS TFModelParser::ConvertGraphInputsAndConsts(
const std::map<std::string, const tensorflow::NodeDef *> &tf_graph_nodes, const FuncGraphPtr &anf_graph,
std::unordered_map<std::string, AnfNodePtr> *anf_node_map) {
for (auto &pair : tf_graph_nodes) {
bool have_data_depend = false; bool have_data_depend = false;
for (int i = 0; i < pair.second->input_size(); ++i) { for (int i = 0; i < pair.second->input_size(); ++i) {
auto name = pair.second->input(i); auto name = pair.second->input(i);
@@ -183,8 +207,8 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts() {
} }
} }
if (!have_data_depend) { if (!have_data_depend) {
auto parameter = funcGraphPtr->add_parameter();
if (ConvertParameter(*pair.second, parameter) != RET_OK) {
auto parameter = anf_graph->add_parameter();
if (ConvertParameter(*pair.second, parameter, anf_node_map) != RET_OK) {
MS_LOG(ERROR) << "convert Parameter Node failed"; MS_LOG(ERROR) << "convert Parameter Node failed";
return RET_ERROR; return RET_ERROR;
} }
@@ -192,7 +216,7 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts() {
} }
return RET_OK; return RET_OK;
} }
FuncGraphPtr paserTfFuction() { return nullptr; }
FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::string &weightFile, FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType) { const QuantType &quantType) {
auto status = ValidateFileStr(modelFile, ".pb"); auto status = ValidateFileStr(modelFile, ".pb");
@@ -201,51 +225,189 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr; return nullptr;
} }
tf_graph_def = std::make_unique<tensorflow::GraphDef>();
if (tf_graph_def == nullptr) {
MS_LOG(ERROR) << "tf_graph_def is nullptr";
tf_root_graph_ = std::make_unique<tensorflow::GraphDef>();
if (tf_root_graph_ == nullptr) {
MS_LOG(ERROR) << "tf_root_graph_ is nullptr";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr; return nullptr;
} }
status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_graph_def.get());
status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_root_graph_.get());
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Open modelFile for TF converter failed!"; MS_LOG(ERROR) << "Open modelFile for TF converter failed!";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr; return nullptr;
} }
funcGraphPtr = std::make_shared<FuncGraph>();
if (funcGraphPtr == nullptr) {
anf_root_graph_ = std::make_shared<FuncGraph>();
if (anf_root_graph_ == nullptr) {
MS_LOG(ERROR) << "funGraphPtr is nullptr"; MS_LOG(ERROR) << "funGraphPtr is nullptr";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr; return nullptr;
} }


for (int i = 0; i < tf_graph_def->node_size(); i++) {
auto &node_def = tf_graph_def->node(i);
tf_node_map[node_def.name()] = &node_def;
for (int i = 0; i < tf_root_graph_->node_size(); i++) {
auto &node_def = tf_root_graph_->node(i);
tf_root_graph_nodes_[node_def.name()] = &node_def;
} }


status = ConvertGraphInputsAndConsts();
status = ConvertGraphInputsAndConsts(tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_);
if (status != RET_OK) { if (status != RET_OK) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr; return nullptr;
} }

status = ConvertOps();
if (status != RET_OK) {
bool success_flag = true;
for (int i = 0; i < tf_root_graph_->node_size(); i++) {
auto &node_def = tf_root_graph_->node(i);
status = ConvertOps(node_def, tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_);
if (status != RET_OK) {
success_flag = false;
}
}
if (!success_flag) {
MS_LOG(ERROR) << "Convert ops failed."; MS_LOG(ERROR) << "Convert ops failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr; return nullptr;
} }

status = ConvertGraphOutputs();
status = ConvertRootGraphOutputs();
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Convert graph outputs failed."; MS_LOG(ERROR) << "Convert graph outputs failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr; return nullptr;
} }
return funcGraphPtr;

status = ConvertSubgraph();
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert subgraph failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}

return anf_root_graph_;
}
STATUS TFModelParser::ConvertSubgraph() {
auto graph_def_liarary = tf_root_graph_->library();
auto subgraph_size = graph_def_liarary.function_size();
std::map<CNodePtr, FuncGraphPtr> while_cond_map;
std::map<CNodePtr, FuncGraphPtr> while_body_map;
std::vector<ParameterPtr> sub_graph_inputs;
for (int i = 0; i < subgraph_size; i++) {
auto &tf_sub_fuction = graph_def_liarary.function(i);
auto &tf_sub_signature = tf_sub_fuction.signature();
auto input_arg_size = tf_sub_signature.input_arg_size();

auto &sub_graph_name = tf_sub_signature.name();
if (!function_while_map_.count(sub_graph_name)) {
MS_LOG(ERROR) << "function map not contains sub graph name." << sub_graph_name;
return RET_ERROR;
}
auto while_cnode = function_while_map_[sub_graph_name]->cast<CNodePtr>();
if (while_cnode == nullptr || static_cast<int>(while_cnode->inputs().size()) != input_arg_size + 1) {
MS_LOG(ERROR) << "while cnode not equal input arg size";
return RET_ERROR;
}

FuncGraphPtr sub_func_graph = std::make_shared<FuncGraph>();
std::unordered_map<std::string, AnfNodePtr> anf_sub_node_map;
// convert sub graph inputs
for (int j = 0; j < input_arg_size; j++) {
auto &input_arg = tf_sub_signature.input_arg(j);
auto paramter = sub_func_graph->add_parameter();
paramter->set_name(input_arg.name());
anf_sub_node_map[input_arg.name()] = paramter;
sub_graph_inputs.emplace_back(paramter);
}
std::map<std::string, const tensorflow::NodeDef *> tf_sub_node_map;
for (int j = 0; j < tf_sub_fuction.node_def_size(); j++) {
auto &node_def = tf_sub_fuction.node_def(j);
tf_sub_node_map[node_def.name()] = &node_def;
}
STATUS status = RET_OK;
status = ConvertGraphInputsAndConsts(tf_sub_node_map, sub_func_graph, &anf_sub_node_map);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert subgraph consts failed";
return status;
}
// convert sub graph ops
for (int j = 0; j < tf_sub_fuction.node_def_size(); j++) {
auto &node_def = tf_sub_fuction.node_def(j);
status = ConvertOps(node_def, tf_sub_node_map, sub_func_graph, &anf_sub_node_map);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert subgraph ops failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return RET_ERROR;
}
}

// convert subgraph outputs
std::vector<AnfNodePtr> sub_output_nodes;
auto &subgraph_ret = tf_sub_fuction.ret();
for (auto &t : subgraph_ret) {
MS_LOG(INFO) << "subret " << t.first << " " << t.second;
auto tf_output_name = GetFlattenNodeName(t.second);
AnfNodePtr anf_node = nullptr;
if (tf_sub_node_map.find(tf_output_name) == tf_sub_node_map.end()) {
anf_node = GetAnfNode(tf_output_name, anf_sub_node_map);
} else {
auto tf_real_name = GetOriginInputName(*tf_sub_node_map[tf_output_name], tf_sub_node_map);
anf_node = GetAnfNode(tf_real_name, anf_sub_node_map);
}
if (anf_node == nullptr) {
MS_LOG(ERROR) << "can't find anf node,tf node flatten name" << tf_output_name;
return RET_ERROR;
}
sub_output_nodes.push_back(anf_node);
}
status = MakeAnfGraphOutputs(&sub_output_nodes, sub_func_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "cmake anf graph outputs node error";
return status;
}

// add while cond body function to while node input
if (sub_graph_name.find("cond") != std::string::npos) {
while_cond_map[while_cnode] = sub_func_graph;
} else {
while_body_map[while_cnode] = sub_func_graph;
}
// hardcode subgraph inputs name
for (size_t j = 0; j < sub_graph_inputs.size(); j++) {
sub_graph_inputs[j]->set_name("graph" + std::to_string(i) + "_input_" + std::to_string(j) + "parameter");
}
MS_LOG(INFO) << "parse subgraph end:" << sub_graph_name;
}
auto status = WhileNodePostProcess(while_cond_map, while_body_map);
if (status != RET_OK) {
MS_LOG(ERROR) << "while node post process failed";
return status;
}
return RET_OK;
} }
STATUS TFModelParser::WhileNodePostProcess(const std::map<CNodePtr, FuncGraphPtr> &while_cond_map,
const std::map<CNodePtr, FuncGraphPtr> &while_body_map) {
if (while_cond_map.size() != while_body_map.size()) {
MS_LOG(ERROR) << "while cond body size error";
return RET_ERROR;
}
std::vector<FuncGraphPtr> roots = {anf_root_graph_};
auto root_func_manager = std::make_shared<FuncGraphManager>(roots);
anf_root_graph_->set_manager(root_func_manager);
for (auto &kv : while_cond_map) {
auto while_node = kv.first;
auto &cond_sub_graph = kv.second;
auto &body_sub_graph = while_body_map.at(while_node);
cond_sub_graph->set_manager(root_func_manager);
body_sub_graph->set_manager(root_func_manager);
auto cond_value_node = NewValueNode(cond_sub_graph);
auto body_value_node = NewValueNode(body_sub_graph);
auto new_while_inputs = while_node->cast<CNodePtr>()->inputs();
new_while_inputs[0] = cond_value_node;
new_while_inputs.insert(new_while_inputs.begin() + 1, body_value_node);
auto new_while_node = anf_root_graph_->NewCNode(new_while_inputs);
new_while_node->set_abstract(while_node->abstract());
root_func_manager->Replace(while_node, new_while_node);
}
return RET_OK;
}

schema::MetaGraphT *TFModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile, schema::MetaGraphT *TFModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType) { const QuantType &quantType) {
MS_LOG(ERROR) << "TF Model Parser not return MetaGraph, use TFModelParser::Parse instead"; MS_LOG(ERROR) << "TF Model Parser not return MetaGraph, use TFModelParser::Parse instead";
@@ -253,15 +415,21 @@ schema::MetaGraphT *TFModelParser::ParseToFb(const std::string &modelFile, const
} }


STATUS TFModelParser::ConvertInputNodes(const tensorflow::NodeDef &node_def, STATUS TFModelParser::ConvertInputNodes(const tensorflow::NodeDef &node_def,
const std::vector<std::string> &input_names, std::vector<AnfNodePtr> *inputs) {
const std::vector<std::string> &input_names,
const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map,
const std::unordered_map<std::string, AnfNodePtr> &anf_node_map,
std::vector<AnfNodePtr> *inputs) {
MS_ASSERT(node_def != nullptr);
// parse inputs // parse inputs
for (size_t j = 0; j < input_names.size(); j++) { for (size_t j = 0; j < input_names.size(); j++) {
std::string input_name = input_names[j]; // input may be produced by multi-outputs node std::string input_name = input_names[j]; // input may be produced by multi-outputs node
if (tf_node_map.find(input_name) != tf_node_map.end()) {
auto input_node = tf_node_map[input_name];
input_name = GetOriginInputName(*input_node);
// subgraph input name x:output:index,need flatten
auto flatten_input_name = GetFlattenNodeName(input_name);
if (tf_node_map.find(flatten_input_name) != tf_node_map.end()) {
auto input_node = tf_node_map.at(flatten_input_name);
flatten_input_name = GetOriginInputName(*input_node, tf_node_map);
} }
auto input = GetAnfNode(input_name);
auto input = GetAnfNode(flatten_input_name, anf_node_map);
if (input == nullptr) { if (input == nullptr) {
MS_LOG(ERROR) << node_def.name() << " input " << j << ": " << input_name << " can't find parsed in_nodes"; MS_LOG(ERROR) << node_def.name() << " input " << j << ": " << input_name << " can't find parsed in_nodes";
return RET_ERROR; return RET_ERROR;
@@ -271,11 +439,16 @@ STATUS TFModelParser::ConvertInputNodes(const tensorflow::NodeDef &node_def,
return RET_OK; return RET_OK;
} }


STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const CNodePtr &anf_node, int output_size) {
STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const CNodePtr &anf_node,
std::unordered_map<std::string, AnfNodePtr> *anf_node_map,
const FuncGraphPtr &anf_graph, int output_size) {
MS_ASSERT(op != nullptr);
MS_ASSERT(anf_node != nullptr);
MS_ASSERT(anf_graph != nullptr);
if (output_size == 1) { if (output_size == 1) {
std::vector<int64_t> shape_vector; std::vector<int64_t> shape_vector;
anf_node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector)); anf_node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector));
anf_node_map.insert(std::pair(op.name(), anf_node));
anf_node_map->insert(std::pair(op.name(), anf_node));
} else { } else {
AbstractBasePtrList abstractList; AbstractBasePtrList abstractList;
for (int output_idx = 0; output_idx < output_size; output_idx++) { for (int output_idx = 0; output_idx < output_size; output_idx++) {
@@ -289,104 +462,125 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C
auto tupleGetItemPrim = NewValueNode(tupleGetItemPrimPtr); auto tupleGetItemPrim = NewValueNode(tupleGetItemPrimPtr);
auto getItemValue = NewValueNode(MakeValue<int>(output_idx)); auto getItemValue = NewValueNode(MakeValue<int>(output_idx));
std::vector<AnfNodePtr> inputs{tupleGetItemPrim, anf_node, getItemValue}; std::vector<AnfNodePtr> inputs{tupleGetItemPrim, anf_node, getItemValue};
CNodePtr getItemCNode = funcGraphPtr->NewCNode(inputs);
CNodePtr getItemCNode = anf_graph->NewCNode(inputs);
std::string output_item_name = anf_node->fullname_with_scope() + "_getitem_" + std::to_string(output_idx); std::string output_item_name = anf_node->fullname_with_scope() + "_getitem_" + std::to_string(output_idx);
getItemCNode->set_fullname_with_scope(output_item_name); getItemCNode->set_fullname_with_scope(output_item_name);
anf_node_map.insert(std::pair(op.name() + ":" + std::to_string(output_idx), getItemCNode));
anf_node_map->insert(std::pair(op.name() + ":" + std::to_string(output_idx), getItemCNode));
} }
anf_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstractList)); anf_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstractList));
} }
return RET_OK; return RET_OK;
} }


STATUS TFModelParser::ConvertOps() {
STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map,
const FuncGraphPtr &func_graph_ptr,
std::unordered_map<std::string, AnfNodePtr> *anf_node_map) {
MS_ASSERT(node_def != nullptr);
MS_ASSERT(func_graph_ptr != nullptr);
NoSupportOp::GetInstance()->SetFmkType("TF"); NoSupportOp::GetInstance()->SetFmkType("TF");
STATUS status = RET_OK; STATUS status = RET_OK;
int op_idx = 0;
for (int i = 0; i < tf_graph_def->node_size(); i++) {
auto &node_def = tf_graph_def->node(i);
const auto &op_type = node_def.op();
if (op_type == "Placeholder" || op_type == "Const" || op_type == "Identity" || op_type == "StopGradient") {
continue;
}
auto node_parser = TFNodeParserRegistry::GetInstance()->GetNodeParser(op_type);
if (node_parser == nullptr) {
NoSupportOp::GetInstance()->InsertOp(op_type);
status = (status == RET_OK ? RET_NOT_FIND_OP : status);
MS_LOG(ERROR) << "cannot find node parser:" << op_type;
continue;
}
if (status != RET_OK) {
continue;
}
PrimitiveC *primitiveC = nullptr;
int output_size;
std::vector<std::string> input_names;
status = node_parser->Parse(node_def, tf_node_map, &primitiveC, &input_names, &output_size);
if (status != RET_OK) {
MS_LOG(ERROR) << "node " << op_type << " parser failed";
continue;
}
const auto &op_type = node_def.op();
if (op_type == "Placeholder" || op_type == "Const" || op_type == "Identity" || op_type == "StopGradient") {
return RET_OK;
}


auto value_node = NewValueNode(std::shared_ptr<PrimitiveC>(primitiveC));
if (value_node == nullptr) {
MS_LOG(ERROR) << "value_node is nullptr";
status = RET_ERROR;
continue;
}
std::vector<AnfNodePtr> inputs = {value_node};
status = ConvertInputNodes(node_def, input_names, &inputs);
if (status != RET_OK) {
continue;
auto node_parser = TFNodeParserRegistry::GetInstance()->GetNodeParser(op_type);
if (node_parser == nullptr) {
NoSupportOp::GetInstance()->InsertOp(op_type);
MS_LOG(ERROR) << "cannot find node parser:" << op_type;
return RET_NOT_FIND_OP;
}
PrimitiveC *primitiveC = nullptr;
int output_size;
std::vector<std::string> input_names;
status = node_parser->Parse(node_def, tf_node_map, &primitiveC, &input_names, &output_size);
if (status != RET_OK) {
MS_LOG(ERROR) << "node " << op_type << " parser failed";
return RET_ERROR;
}
auto value_node = NewValueNode(std::shared_ptr<PrimitiveC>(primitiveC));
if (value_node == nullptr) {
MS_LOG(ERROR) << "value_node is nullptr";
return RET_ERROR;
}
std::vector<AnfNodePtr> inputs = {value_node};
status = ConvertInputNodes(node_def, input_names, tf_node_map, *anf_node_map, &inputs);
if (status != RET_OK) {
return status;
}
// control_depends are not processed currently
auto anf_node = func_graph_ptr->NewCNode(inputs);
anf_node->set_fullname_with_scope(node_def.name());
if (op_type == "StatelessWhile" || op_type == "while") {
MS_LOG(INFO) << "find while node:" << node_def.name();
tensorflow::AttrValue attr_value;
if (TensorFlowUtils::FindAttrValue(node_def, "body", &attr_value)) {
auto body_name = attr_value.func().name();
function_while_map_[body_name] = anf_node;
MS_LOG(DEBUG) << "parse body name:" << body_name;
} }
// control_depends are not processed currently
auto anf_node = funcGraphPtr->NewCNode(inputs);
anf_node->set_fullname_with_scope(op_type + "-" + std::to_string(op_idx++));

status = ConvertOutputTensor(node_def, anf_node, output_size);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
continue;
if (TensorFlowUtils::FindAttrValue(node_def, "cond", &attr_value)) {
auto cond_name = attr_value.func().name();
function_while_map_[cond_name] = anf_node;
MS_LOG(DEBUG) << "parse cond name:" << cond_name;
} }
} }

status = ConvertOutputTensor(node_def, anf_node, anf_node_map, func_graph_ptr, output_size);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return RET_ERROR;
}
return status; return status;
} }


STATUS TFModelParser::ConvertGraphOutputs() {
STATUS TFModelParser::ConvertRootGraphOutputs() {
// because output of intermediate node in anf graph may also be output tensors, we search output tensors in // because output of intermediate node in anf graph may also be output tensors, we search output tensors in
// tf_node_map but not anf_node_map
// tf_root_graph_nodes_ but not anf_root_node_map_
std::set<std::string> all_node_inputs; std::set<std::string> all_node_inputs;
std::vector<AnfNodePtr> output_nodes; std::vector<AnfNodePtr> output_nodes;
for (auto &pair : tf_node_map) {
for (auto &pair : tf_root_graph_nodes_) {
for (int i = 0; i < pair.second->input_size(); ++i) { for (int i = 0; i < pair.second->input_size(); ++i) {
all_node_inputs.insert(pair.second->input(i)); all_node_inputs.insert(pair.second->input(i));
} }
} }
for (auto &pair : tf_node_map) {
for (auto &pair : tf_root_graph_nodes_) {
auto it = all_node_inputs.find(pair.first); auto it = all_node_inputs.find(pair.first);
if (it == all_node_inputs.end() && pair.second->input_size() > 0) { // output node not constraint to Identity if (it == all_node_inputs.end() && pair.second->input_size() > 0) { // output node not constraint to Identity
auto origin_name = GetOriginInputName(*(pair.second));
auto anf_node = GetAnfNode(origin_name);
auto origin_name = GetOriginInputName(*(pair.second), tf_root_graph_nodes_);
auto anf_node = GetAnfNode(origin_name, anf_root_node_map_);
if (anf_node == nullptr) { if (anf_node == nullptr) {
MS_LOG(ERROR) << "can't find anf node"; MS_LOG(ERROR) << "can't find anf node";
return RET_ERROR; return RET_ERROR;
} }
output_nodes.push_back(anf_node); output_nodes.push_back(anf_node);
graph_output_names.push_back(anf_node->fullname_with_scope());
graph_output_names_.push_back(anf_node->fullname_with_scope());
} }
} }

if (output_nodes.size() > 1) {
std::vector<AnfNodePtr> &make_tuple_inputs = output_nodes;
auto status = MakeAnfGraphOutputs(&output_nodes, anf_root_graph_);
if (status != RET_OK) {
MS_LOG(ERROR) << "make anf graph outputs node error";
return status;
}
return RET_OK;
}
STATUS TFModelParser::MakeAnfGraphOutputs(std::vector<AnfNodePtr> *output_nodes, const FuncGraphPtr &anf_graph) {
if (output_nodes->empty() || anf_graph == nullptr) {
MS_LOG(ERROR) << "anf output nodes empty or null anf graph";
return RET_ERROR;
}
if (output_nodes->size() > 1) {
std::vector<AnfNodePtr> *make_tuple_inputs = output_nodes;
auto make_tuple_prim_ptr = GetMakeTuplePrim(); auto make_tuple_prim_ptr = GetMakeTuplePrim();
if (make_tuple_prim_ptr == nullptr) { if (make_tuple_prim_ptr == nullptr) {
MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr"; MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr); auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr);
make_tuple_inputs.insert(output_nodes.begin(), make_tuple_prim);
auto make_tuple_cnode = funcGraphPtr->NewCNode(make_tuple_inputs);
make_tuple_inputs->insert(make_tuple_inputs->begin(), make_tuple_prim);
auto make_tuple_cnode = anf_graph->NewCNode(*make_tuple_inputs);
make_tuple_cnode->set_fullname_with_scope("return tuple"); make_tuple_cnode->set_fullname_with_scope("return tuple");


auto return_prim_ptr = GetReturnPrim(); auto return_prim_ptr = GetReturnPrim();
@@ -396,20 +590,20 @@ STATUS TFModelParser::ConvertGraphOutputs() {
} }
auto value_node = NewValueNode(return_prim_ptr); auto value_node = NewValueNode(return_prim_ptr);
std::vector<AnfNodePtr> op_inputs = {value_node, make_tuple_cnode}; std::vector<AnfNodePtr> op_inputs = {value_node, make_tuple_cnode};
auto cnode = funcGraphPtr->NewCNode(op_inputs);
auto cnode = anf_graph->NewCNode(op_inputs);
cnode->set_fullname_with_scope("return"); cnode->set_fullname_with_scope("return");
funcGraphPtr->set_return(cnode);
} else if (output_nodes.size() == 1) {
anf_graph->set_return(cnode);
} else {
auto return_prim_ptr = GetReturnPrim(); auto return_prim_ptr = GetReturnPrim();
if (return_prim_ptr == nullptr) { if (return_prim_ptr == nullptr) {
MS_LOG(ERROR) << "GetReturnPrim return nullptr"; MS_LOG(ERROR) << "GetReturnPrim return nullptr";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
auto value_node = NewValueNode(return_prim_ptr); auto value_node = NewValueNode(return_prim_ptr);
std::vector<AnfNodePtr> op_inputs{value_node, output_nodes.front()};
auto return_cnode = funcGraphPtr->NewCNode(op_inputs);
std::vector<AnfNodePtr> op_inputs{value_node, output_nodes->front()};
auto return_cnode = anf_graph->NewCNode(op_inputs);
return_cnode->set_fullname_with_scope("return"); return_cnode->set_fullname_with_scope("return");
funcGraphPtr->set_return(return_cnode);
anf_graph->set_return(return_cnode);
} }
return RET_OK; return RET_OK;
} }


+ 34
- 19
mindspore/lite/tools/converter/parser/tf/tf_model_parser.h View File

@@ -17,17 +17,17 @@
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_MODEL_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_MODEL_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_MODEL_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_MODEL_PARSER_H


#include <string>
#include <vector>
#include <memory>
#include <map> #include <map>
#include <memory>
#include <string>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "proto/graph.pb.h"
#include "proto/node_def.pb.h"
#include "schema/inner/model_generated.h"
#include "securec/include/securec.h" #include "securec/include/securec.h"
#include "tools/common/tensor_util.h" #include "tools/common/tensor_util.h"
#include "tools/converter/model_parser.h" #include "tools/converter/model_parser.h"
#include "schema/inner/model_generated.h"
#include "proto/node_def.pb.h"
#include "proto/graph.pb.h"


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@@ -43,24 +43,39 @@ class TFModelParser : public ModelParser {
const QuantType &quantType = QuantType_QUANT_NONE) override; const QuantType &quantType = QuantType_QUANT_NONE) override;


private: private:
AnfNodePtr GetAnfNode(const std::string &name);
std::string GetOriginInputName(const tensorflow::NodeDef &node);
STATUS ConvertConstTensor(const tensorflow::AttrValue &attr_value, const TypeId &type, const ParameterPtr &parameter, STATUS ConvertConstTensor(const tensorflow::AttrValue &attr_value, const TypeId &type, const ParameterPtr &parameter,
std::vector<int64_t> *shape_vector); std::vector<int64_t> *shape_vector);
STATUS ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr &parameter);
STATUS ConvertGraphInputsAndConsts();
STATUS ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr &parameter,
std::unordered_map<std::string, AnfNodePtr> *anf_node_map);
STATUS ConvertGraphInputsAndConsts(const std::map<std::string, const tensorflow::NodeDef *> &tf_graph_nodes,
const FuncGraphPtr &anf_graph,
std::unordered_map<std::string, AnfNodePtr> *anf_node_map);
STATUS ConvertInputNodes(const tensorflow::NodeDef &node_def, const std::vector<std::string> &input_names, STATUS ConvertInputNodes(const tensorflow::NodeDef &node_def, const std::vector<std::string> &input_names,
const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map,
const std::unordered_map<std::string, AnfNodePtr> &anf_node_map,
std::vector<AnfNodePtr> *inputs); std::vector<AnfNodePtr> *inputs);
STATUS ConvertOutputTensor(const tensorflow::NodeDef &op, const CNodePtr &anf_node, int output_size);
STATUS ConvertOps();
STATUS ConvertGraphOutputs();
STATUS ConvertOutputTensor(const tensorflow::NodeDef &op, const CNodePtr &anf_node,
std::unordered_map<std::string, AnfNodePtr> *anf_node_map, const FuncGraphPtr &anf_graph,
int output_size);
STATUS ConvertOps(const tensorflow::NodeDef &node_def,
const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map,
const FuncGraphPtr &func_graph_ptr, std::unordered_map<std::string, AnfNodePtr> *anf_node_map);
STATUS ConvertRootGraphOutputs();

STATUS ConvertSubgraph();

STATUS WhileNodePostProcess(const std::map<CNodePtr, FuncGraphPtr> &while_cond_map,
const std::map<CNodePtr, FuncGraphPtr> &while_body_map);

STATUS MakeAnfGraphOutputs(std::vector<AnfNodePtr> *output_nodes, const FuncGraphPtr &anf_graph);


FuncGraphPtr funcGraphPtr;
std::unique_ptr<tensorflow::GraphDef> tf_graph_def;
std::map<std::string, const tensorflow::NodeDef *> tf_node_map;
std::unordered_map<std::string, AnfNodePtr> anf_node_map;
std::vector<std::string> graph_input_names;
std::vector<std::string> graph_output_names;
FuncGraphPtr anf_root_graph_;
std::unique_ptr<tensorflow::GraphDef> tf_root_graph_; // tf root graph def
std::map<std::string, const tensorflow::NodeDef *> tf_root_graph_nodes_; // tf root graph node map
std::unordered_map<std::string, AnfNodePtr> anf_root_node_map_;
std::vector<std::string> graph_input_names_;
std::vector<std::string> graph_output_names_;
std::map<std::string, AnfNodePtr> function_while_map_; // tf function name->while_node_name
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


+ 62
- 0
mindspore/lite/tools/converter/parser/tf/tf_while_parser.cc View File

@@ -0,0 +1,62 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_while_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"

namespace mindspore {
namespace lite {
STATUS TFWhileParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(INFO) << "TF WhileParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}

auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is nullptr";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::WhileT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}

primitive->value.type = schema::PrimitiveType_While;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}

*output_size = tf_op.input_size();
for (int i = 0; i < tf_op.input_size(); i++) {
inputs->emplace_back(tf_op.input(i));
}
return RET_OK;
}
TFNodeRegistrar g_tfStatelessWhileParser("StatelessWhile", new TFWhileParser());
TFNodeRegistrar g_tfWhileParser("While", new TFWhileParser());
} // namespace lite
} // namespace mindspore

+ 37
- 0
mindspore/lite/tools/converter/parser/tf/tf_while_parser.h View File

@@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_WHILE_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_WHILE_PARSER_H_

#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
class TFWhileParser : public TFNodeParser {
public:
TFWhileParser() = default;
~TFWhileParser() override = default;

STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_WHILE_PARSER_H_

Loading…
Cancel
Save