Browse Source

set root tensor for merge

tags/v1.1.0
mengyuanli 5 years ago
parent
commit
95211c8fce
6 changed files with 202 additions and 121 deletions
  1. +0
    -1
      mindspore/lite/test/models_tf.cfg
  2. +58
    -19
      mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.cc
  3. +6
    -3
      mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.h
  4. +132
    -88
      mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.cc
  5. +6
    -6
      mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.h
  6. +0
    -4
      mindspore/lite/tools/optimizer/graph/infershape_pass.cc

+ 0
- 1
mindspore/lite/test/models_tf.cfg View File

@@ -1 +0,0 @@
decoder_step_201217.pb 5

+ 58
- 19
mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.cc View File

@@ -28,30 +28,39 @@
namespace mindspore {
namespace lite {

std::set<uint32_t> SubgraphNodePass::GetSubgraphAllTensorIndices(const std::unique_ptr<SubGraphT> &subgraph,
schema::MetaGraphT *graph) {
std::set<uint32_t> tensors_indices{};
STATUS SubgraphNodePass::GetSubgraphAllTensorIndices(const std::unique_ptr<SubGraphT> &subgraph,
schema::MetaGraphT *graph, std::set<uint32_t> *tensors_indices) {
for (auto &node_idx : subgraph->nodeIndices) {
if (node_idx >= graph->nodes.size()) {
MS_LOG(ERROR) << "node_idx: " << node_idx << " bigger than graph->nodes.size(): " << graph->nodes.size();
for (auto &subgraph : graph->subGraph) {
MS_LOG(ERROR) << subgraph->name << " : " << subgraph->nodeIndices;
}
return RET_ERROR;
}
auto &node = graph->nodes.at(node_idx);
for (auto &input_idx : node->inputIndex) {
tensors_indices.insert(input_idx);
tensors_indices->insert(input_idx);
}
for (auto &output_idx : node->outputIndex) {
tensors_indices.insert(output_idx);
tensors_indices->insert(output_idx);
}
}
return tensors_indices;
return RET_OK;
}

bool SubgraphNodePass::IsNodeInputInSubgraph(const std::set<uint32_t> &tensors_indices,
const std::unique_ptr<CNodeT> &node,
const std::unique_ptr<SubGraphT> &subgraph) {
return std::any_of(node->inputIndex.begin(), node->inputIndex.end(),
[&tensors_indices, &subgraph](uint32_t idx) { return tensors_indices.count(idx) > 0; });
}

bool SubgraphNodePass::IsNodeInSubgraph(const std::set<uint32_t> &tensors_indices, const std::unique_ptr<CNodeT> &node,
const std::unique_ptr<SubGraphT> &subgraph) {
return (std::any_of(node->inputIndex.begin(), node->inputIndex.end(),
[&tensors_indices, &subgraph](uint32_t idx) {
return tensors_indices.count(idx) > 0 || IsContain(subgraph->inputIndices, idx);
})) &&
(std::any_of(node->outputIndex.begin(), node->outputIndex.end(), [&tensors_indices, &subgraph](uint32_t idx) {
return tensors_indices.count(idx) > 0 || IsContain(subgraph->outputIndices, idx);
}));
bool SubgraphNodePass::IsNodeOutputInSubgraph(const std::set<uint32_t> &tensors_indices,
const std::unique_ptr<CNodeT> &node,
const std::unique_ptr<SubGraphT> &subgraph) {
return std::any_of(node->outputIndex.begin(), node->outputIndex.end(),
[&tensors_indices, &subgraph](uint32_t idx) { return tensors_indices.count(idx) > 0; });
}

void SubgraphNodePass::DecreaseSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph) {
@@ -104,12 +113,42 @@ STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) {
for (uint32_t i = 0; i < new_nodes.size(); i++) {
if (!IsContain(old_nodes_, new_nodes[i])) {
auto &node = graph->nodes.at(i);
std::vector<SubGraphT *> contain_node_input_subgraphs{};
std::vector<SubGraphT *> contain_node_output_subgraphs{};
for (auto &subgraph : graph->subGraph) {
auto tensors_indices = GetSubgraphAllTensorIndices(subgraph, graph);
if (IsNodeInSubgraph(tensors_indices, node, subgraph)) {
IncreaseSubgraphNodeIndices(i, graph);
subgraph->nodeIndices.push_back(i);
std::set<uint32_t> tensors_indices{};
int ret = GetSubgraphAllTensorIndices(subgraph, graph, &tensors_indices);
if (ret != RET_OK) {
MS_LOG(ERROR) << "GetSubgraphAllTensorIndices failed.";
return ret;
}
if (IsNodeInputInSubgraph(tensors_indices, node, subgraph)) {
contain_node_input_subgraphs.push_back(subgraph.get());
}
if (IsNodeOutputInSubgraph(tensors_indices, node, subgraph)) {
contain_node_output_subgraphs.push_back(subgraph.get());
}
}
if (contain_node_input_subgraphs.size() == 1 && contain_node_output_subgraphs.size() == 1 &&
contain_node_output_subgraphs[0] != contain_node_input_subgraphs[0]) {
MS_LOG(ERROR) << "not support single node index insert.";
return RET_ERROR;
}
if (contain_node_input_subgraphs.size() == 1 && contain_node_output_subgraphs.size() == 1 &&
contain_node_output_subgraphs[0] == contain_node_input_subgraphs[0]) {
IncreaseSubgraphNodeIndices(i, graph);
contain_node_input_subgraphs[0]->nodeIndices.push_back(i);
continue;
}
if (contain_node_input_subgraphs.size() == 1) {
IncreaseSubgraphNodeIndices(i, graph);
contain_node_input_subgraphs[0]->nodeIndices.push_back(i);
continue;
}
if (contain_node_output_subgraphs.size() == 1) {
IncreaseSubgraphNodeIndices(i, graph);
contain_node_output_subgraphs[0]->nodeIndices.push_back(i);
continue;
}
}
}


+ 6
- 3
mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.h View File

@@ -36,9 +36,12 @@ class SubgraphNodePass : public GraphPass {
private:
void DecreaseSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph);
void IncreaseSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph);
std::set<uint32_t> GetSubgraphAllTensorIndices(const std::unique_ptr<SubGraphT> &subgraph, schema::MetaGraphT *graph);
bool IsNodeInSubgraph(const std::set<uint32_t> &tensors_indices, const std::unique_ptr<CNodeT> &node,
const std::unique_ptr<SubGraphT> &subgraph);
STATUS GetSubgraphAllTensorIndices(const std::unique_ptr<SubGraphT> &subgraph, schema::MetaGraphT *graph,
std::set<uint32_t> *tensors_indices);
bool IsNodeInputInSubgraph(const std::set<uint32_t> &tensors_indices, const std::unique_ptr<CNodeT> &node,
const std::unique_ptr<SubGraphT> &subgraph);
bool IsNodeOutputInSubgraph(const std::set<uint32_t> &tensors_indices, const std::unique_ptr<CNodeT> &node,
const std::unique_ptr<SubGraphT> &subgraph);
std::vector<schema::CNodeT *> old_nodes_;
};
} // namespace lite


+ 132
- 88
mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.cc View File

@@ -16,6 +16,7 @@

#include <vector>
#include <map>
#include <set>
#include <algorithm>
#include "tools/converter/legacy_optimizer/graph/switch_pass.h"
#include "src/common/log_adapter.h"
@@ -47,7 +48,10 @@ STATUS SwitchPass::Run(mindspore::schema::MetaGraphT *graph) {

STATUS SingleSwitchPass::DoubleSwitchOutput() {
origin_switch_output_tensor_indices_ = switch_node_->outputIndex;
MS_ASSERT(origin_switch_output_tensor_indices_.size() == cond_partial_node_->inputIndex.szie());
if (origin_switch_output_tensor_indices_.size() != cond_partial_node_->inputIndex.size()) {
MS_LOG(ERROR) << "switch node: " << switch_node_->name << " input or output number is not right.";
return RET_ERROR;
}
for (size_t i = 0; i < origin_switch_output_tensor_indices_.size(); i++) {
auto &switch_out_tensor = graph_->allTensors.at(origin_switch_output_tensor_indices_[i]);
const auto &cond_partial_input_tensor = graph_->allTensors.at(cond_partial_node_->inputIndex[i]);
@@ -60,7 +64,7 @@ STATUS SingleSwitchPass::DoubleSwitchOutput() {
return RET_OK;
}

void SingleSwitchPass::DoubleIdx(uint32_t *idx) {
void SingleSwitchPass::UpdateSwitchOutputIndices(uint32_t *idx) {
auto iter = std::find(switch_node_->outputIndex.begin(), switch_node_->outputIndex.end(), *idx);
if (iter != switch_node_->outputIndex.end()) {
int pos = iter - switch_node_->outputIndex.begin();
@@ -69,25 +73,21 @@ void SingleSwitchPass::DoubleIdx(uint32_t *idx) {
}

STATUS SingleSwitchPass::UpdateSwitchUser() {
std::vector<CNodeT *> switch_users;
for (auto &node_idx : graph_->subGraph.at(this_subgraph_index_)->nodeIndices) {
auto &node = graph_->nodes.at(node_idx);
for (auto &idx : node->inputIndex) {
if (IsContain(switch_node_->outputIndex, idx)) {
switch_users.push_back(node.get());
}
DoubleIdx(&idx);
UpdateSwitchOutputIndices(&idx);
}
}
// update graph switch user
for (auto &subgraph : graph_->subGraph) {
for (auto &idx : subgraph->outputIndices) {
DoubleIdx(&idx);
UpdateSwitchOutputIndices(&idx);
}
}

for (auto &idx : graph_->outputIndex) {
DoubleIdx(&idx);
UpdateSwitchOutputIndices(&idx);
}

return RET_OK;
@@ -104,20 +104,71 @@ bool SingleSwitchPass::IsLoop() {
return false;
}

std::unique_ptr<schema::TensorT> SingleSwitchPass::NewTensor(const std::unique_ptr<schema::TensorT> &in_tensor) {
std::unique_ptr<schema::TensorT> SingleSwitchPass::NewTensor(const std::unique_ptr<schema::TensorT> &in_tensor,
bool with_data) {
auto out_tensor = std::make_unique<schema::TensorT>();
out_tensor->nodeType = in_tensor->nodeType;
out_tensor->dims = in_tensor->dims;
out_tensor->dataType = in_tensor->dataType;
out_tensor->data = in_tensor->data;
out_tensor->format = in_tensor->format;
if (with_data) {
out_tensor->data = in_tensor->data;
}
return out_tensor;
}

STATUS SingleSwitchPass::BodyGraphVariableInput(std::vector<size_t> *variable_input) {
auto &body_fg = graph_->subGraph.at(body_subgraph_index_);
auto body_fg_output = body_fg->outputIndices;
for (auto &subgraph_output : body_fg_output) {
for (auto &node : body_graph_nodes_) {
if (node != nullptr && IsContain(node->outputIndex, subgraph_output)) {
int partial_idx = GetSubgraphOutputTensorIndex(body_fg, node);
if (partial_idx == -1) {
MS_LOG(ERROR) << "get input index failed.";
return RET_ERROR;
}
(*variable_input).emplace_back(partial_idx);
}
}
}
return RET_OK;
}

STATUS SingleSwitchPass::InsertMerge() {
int ret = RET_OK;
// update body graph output
auto &body_fg = graph_->subGraph.at(body_subgraph_index_);
body_fg->outputIndices.assign(body_to_cond_partial_node_->inputIndex.begin(),
body_to_cond_partial_node_->inputIndex.end());

// remove body_to_cond_partial_node_ from body_graph_nodes_
for (auto it = body_graph_nodes_.begin(); it != body_graph_nodes_.end();) {
if (*it == body_to_cond_partial_node_) {
it = body_graph_nodes_.erase(it);
} else {
it++;
}
}

// isolate body_to_cond_partial_node_
IsolateUselessNode(body_to_cond_partial_node_, graph_);

std::vector<size_t> variable_input{};
int ret = BodyGraphVariableInput(&variable_input);
if (ret != RET_OK) {
MS_LOG(ERROR) << "get body graph variable input failed, ret: " << ret;
return ret;
}

std::vector<size_t> const_input{};
for (size_t i = 0; i < body_partial_node_->inputIndex.size(); i++) {
if (IsContain(variable_input, i)) {
continue;
}
const_input.push_back(i);
}

auto merge_node = std::unique_ptr<CNodeT>(new (std::nothrow) CNodeT);
MS_ASSERT(merge_node != nullptr);
auto primitiveT = std::unique_ptr<PrimitiveT>(new (std::nothrow) PrimitiveT);
MS_ASSERT(primitiveT != nullptr);
merge_node->primitive = std::move(primitiveT);
@@ -129,8 +180,6 @@ STATUS SingleSwitchPass::InsertMerge() {
MS_ASSERT(merge_param != nullptr);
merge_node->primitive->value.value = merge_param.release();

merge_node->inputIndex.assign(cond_partial_node_->inputIndex.begin(), cond_partial_node_->inputIndex.end());

// merge node output is same as switch
for (auto &out_index : origin_switch_output_tensor_indices_) {
auto &switch_out_tensor = graph_->allTensors.at(out_index);
@@ -139,12 +188,30 @@ STATUS SingleSwitchPass::InsertMerge() {
merge_node->outputIndex.push_back(graph_->allTensors.size() - 1);
}

// double merge inputs to contain the outputs of body node
for (auto &index : cond_partial_node_->inputIndex) {
auto &in_tensor = graph_->allTensors.at(index);
auto tensor = NewTensor(in_tensor);
graph_->allTensors.push_back(std::move(tensor));
merge_node->inputIndex.push_back(graph_->allTensors.size() - 1);
merge_node->inputIndex.assign(cond_partial_node_->inputIndex.begin(), cond_partial_node_->inputIndex.end());

std::set<uint32_t> input_set{};
for (auto &iter : merge_node->inputIndex) {
if (input_set.find(iter) != input_set.end()) {
auto &in_tensor = graph_->allTensors.at(iter);
auto tensor = NewTensor(in_tensor, true);
graph_->allTensors.push_back(std::move(tensor));
iter = graph_->allTensors.size() - 1;
}
input_set.insert(iter);
}

// double merge inputs to contain the outputs of body node
auto old_merge_input = merge_node->inputIndex;
for (size_t i = 0; i < old_merge_input.size(); i++) {
auto &in_tensor = graph_->allTensors.at(old_merge_input[i]);
if (IsContain(const_input, i)) {
merge_node->inputIndex.push_back(old_merge_input[i]);
} else {
auto tensor = NewTensor(in_tensor);
graph_->allTensors.push_back(std::move(tensor));
merge_node->inputIndex.push_back(graph_->allTensors.size() - 1);
}
}

// insert merge node before the cond graph
@@ -182,46 +249,12 @@ STATUS SingleSwitchPass::InsertMerge() {
graph_->nodes.push_back(std::move(merge_node));
graph_->subGraph.at(this_subgraph_index_)->nodeIndices.push_back(graph_->nodes.size() - 1);

// update bodu graph output
graph_->subGraph.at(body_subgraph_index_)
->outputIndices.assign(body_to_cond_partial_node_->inputIndex.begin(),
body_to_cond_partial_node_->inputIndex.end());

// erase body_to_cond_partial_node_
RemoveUselessNode(body_to_cond_partial_node_, graph_);
return ret;
return RET_OK;
}

void SingleSwitchPass::RemoveUselessNode(schema::CNodeT *partial_node, schema::MetaGraphT *graph) {
void SingleSwitchPass::IsolateUselessNode(schema::CNodeT *partial_node, schema::MetaGraphT *graph) {
partial_node->inputIndex.clear();
partial_node->outputIndex.clear();

int pos = -1;
for (size_t i = 0; i < graph->nodes.size(); ++i) {
if (graph->nodes.at(i).get() == partial_node) {
pos = i;
break;
}
}

if (pos == -1) {
return;
}

graph->nodes.erase(graph->nodes.begin() + pos);

for (auto &subgraph : graph->subGraph) {
for (auto it = subgraph->nodeIndices.begin(); it != subgraph->nodeIndices.end();) {
if (*it == static_cast<uint32_t>(pos)) {
it = subgraph->nodeIndices.erase(it);
} else {
if (*it > static_cast<uint32_t>(pos)) {
(*it)--;
}
it++;
}
}
}
}

size_t SingleSwitchPass::InitThisGraphIndex() {
@@ -265,12 +298,10 @@ STATUS SingleSwitchPass::Init() {
for (auto &out_index : iter->get()->outputIndex) {
if (out_index == switch_node_->inputIndex[kSwitchCondIndex]) {
cond_partial_node_ = iter->get();
cond_node_index_ = iter - graph_->nodes.begin();
find_cond_node = true;
}
if (out_index == switch_node_->inputIndex[kSwitchBodyIndex]) {
body_partial_node_ = iter->get();
body_node_index_ = iter - graph_->nodes.begin();
find_body_node = true;
}
}
@@ -301,6 +332,41 @@ STATUS SingleSwitchPass::Init() {
return RET_OK;
}

int SingleSwitchPass::GetSubgraphInputTensorIndex(const std::unique_ptr<SubGraphT> &subgraph,
const std::unique_ptr<TensorT> &tensor) {
int partial_idx = -1;
if (tensor->name.find("_input_") != std::string::npos) {
// get parameter input index k. subgraph name + “_input_" + "k"
auto pos = subgraph->name.size() + sizeof("_input_");
auto pos2 = tensor->name.find('_', pos);
auto idx_str = tensor->name.substr(pos - 1, pos2 - pos + 1);
partial_idx = std::stoi(idx_str);
}

if (tensor->name.find("_output_") != std::string::npos) {
// get parameter input index k. subgraph name + “_output_" + "k"
auto pos = subgraph->name.size() + sizeof("_output_");
auto pos2 = tensor->name.find('_', pos);
auto idx_str = tensor->name.substr(pos - 1, pos2 - pos + 1);
partial_idx = std::stoi(idx_str);
}
return partial_idx;
}

int SingleSwitchPass::GetSubgraphOutputTensorIndex(const std::unique_ptr<SubGraphT> &subgraph, CNodeT *node) {
int partial_idx = -1;
if (node->name == "LogicalAnd") {
partial_idx = 0;
} else {
// get parameter input index k. subgraph name + “_output_" + "k"
auto pos = subgraph->name.size() + sizeof("_output_");
auto pos2 = node->name.find('_', pos);
auto idx_str = node->name.substr(pos - 1, pos2 - pos + 1);
partial_idx = std::stoi(idx_str);
}
return partial_idx;
}

STATUS SingleSwitchPass::UpdateSubgraphInput(const size_t &subgraph_index, schema::CNodeT *partial_node,
const std::vector<schema::CNodeT *> &subgraph_nodes) {
if (partial_node == nullptr || subgraph_nodes.empty()) {
@@ -315,27 +381,11 @@ STATUS SingleSwitchPass::UpdateSubgraphInput(const size_t &subgraph_index, schem
std::vector<std::pair<int, int>> tmp_inputs_order{};
for (unsigned int &subgraph_input : subgraph_inputs) {
auto &tensor = graph_->allTensors.at(subgraph_input);
if (tensor->name.size() < subgraph->name.size() + 8) {
MS_LOG(ERROR) << "tensor name: " << tensor->name << " not right.";
int partial_idx = GetSubgraphInputTensorIndex(subgraph, tensor);
if (partial_idx == -1) {
MS_LOG(ERROR) << "get input index failed.";
return RET_ERROR;
}
int partial_idx = -1;
if (tensor->name.find("_input_") != std::string::npos) {
// get parameter input index k. subgraph name + “_input_" + "k"
auto pos = subgraph->name.size() + sizeof("_input_");
auto pos2 = tensor->name.find('_', pos);
auto idx_str = tensor->name.substr(pos - 1, pos2 - pos + 1);
partial_idx = std::stoi(idx_str);
}

if (tensor->name.find("_output_") != std::string::npos) {
// get parameter input index k. subgraph name + “_output_" + "k"
auto pos = subgraph->name.size() + sizeof("_output_");
auto pos2 = tensor->name.find('_', pos);
auto idx_str = tensor->name.substr(pos - 1, pos2 - pos + 1);
partial_idx = std::stoi(idx_str);
}

subgraph_input_map.insert(std::pair<int, int>{subgraph_input, partial_inputs[partial_idx]});
tmp_inputs_order.emplace_back(partial_idx, partial_inputs[partial_idx]);
}
@@ -374,15 +424,10 @@ STATUS SingleSwitchPass::UpdateSubgraphOutput(const size_t &subgraph_index, sche
for (unsigned int &subgraph_output : subgraph_outputs) {
for (auto &node : subgraph_nodes) {
if (IsContain(node->outputIndex, subgraph_output)) {
int partial_idx = -1;
if (node->name == "LogicalAnd") {
partial_idx = 0;
} else {
// get parameter input index k. subgraph name + “_output_" + "k"
auto pos = subgraph->name.size() + sizeof("_output_");
auto pos2 = node->name.find('_', pos);
auto idx_str = node->name.substr(pos - 1, pos2);
partial_idx = std::stoi(idx_str);
int partial_idx = GetSubgraphOutputTensorIndex(subgraph, node);
if (partial_idx == -1) {
MS_LOG(ERROR) << "get input index failed.";
return RET_ERROR;
}
subgraph_output_map.insert(std::pair<int, int>{subgraph_output, partial_outputs[partial_idx]});
tmp_outputs_order.emplace_back(partial_idx, partial_outputs[partial_idx]);
@@ -473,7 +518,6 @@ STATUS SingleSwitchPass::Run() {
MS_LOG(ERROR) << "ConcatBodySubgraphInputAndOutput failed, ret: " << ret;
return ret;
}

return RET_OK;
}
} // namespace mindspore::lite

+ 6
- 6
mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.h View File

@@ -50,13 +50,16 @@ class SingleSwitchPass {
STATUS ConcatBodySubgraphInputAndOutput();
bool IsLoop();
STATUS InsertMerge();
int GetSubgraphInputTensorIndex(const std::unique_ptr<SubGraphT> &subgraph, const std::unique_ptr<TensorT> &tensor);
int GetSubgraphOutputTensorIndex(const std::unique_ptr<SubGraphT> &subgraph, CNodeT *node);
STATUS UpdateSubgraphInput(const size_t &subgraph_index, schema::CNodeT *partial_node,
const std::vector<schema::CNodeT *> &subgraph_nodes);
STATUS UpdateSubgraphOutput(const size_t &subgraph_index, schema::CNodeT *partial_node,
const std::vector<schema::CNodeT *> &subgraph_nodes);
std::unique_ptr<schema::TensorT> NewTensor(const std::unique_ptr<schema::TensorT> &in_tensor);
void RemoveUselessNode(schema::CNodeT *partial_node, schema::MetaGraphT *graph);
void DoubleIdx(uint32_t *idx);
std::unique_ptr<schema::TensorT> NewTensor(const std::unique_ptr<schema::TensorT> &in_tensor, bool with_data = false);
void IsolateUselessNode(schema::CNodeT *partial_node, schema::MetaGraphT *graph);
void UpdateSwitchOutputIndices(uint32_t *idx);
STATUS BodyGraphVariableInput(std::vector<size_t> *variable_input);

const size_t kSwitchCondIndex = 0;
const size_t kSwitchBodyIndex = 1;
@@ -70,10 +73,7 @@ class SingleSwitchPass {
std::vector<schema::CNodeT *> this_graph_nodes_;
std::vector<schema::CNodeT *> body_graph_nodes_;
std::vector<schema::CNodeT *> cond_graph_nodes_;
std::vector<schema::CNodeT *> switch_users_;
size_t switch_node_index_ = -1;
size_t cond_node_index_ = -1;
size_t body_node_index_ = -1;
int32_t this_subgraph_index_ = -1;
int32_t cond_subgraph_index_ = -1;
int32_t body_subgraph_index_ = -1;


+ 0
- 4
mindspore/lite/tools/optimizer/graph/infershape_pass.cc View File

@@ -346,10 +346,6 @@ STATUS InferShapePass::SwitchCNodeInferShape(const CNodePtr &switch_cnode) {
}

bool InferShapePass::Run(const FuncGraphPtr &func_graph) {
if (func_graph->has_flag("HasInferShaped")) {
return true;
}

if (fmk_type != lite::converter::FmkType_TF && fmk_type != lite::converter::FmkType_TFLITE) {
MS_LOG(INFO) << "The framework type of model should be tf/tflite.";
return false;


Loading…
Cancel
Save