|
|
|
@@ -1,6 +1,4 @@ |
|
|
|
/** |
|
|
|
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). |
|
|
|
* |
|
|
|
* Copyright 2020 Huawei Technologies Co., Ltd |
|
|
|
* |
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
@@ -17,52 +15,114 @@ |
|
|
|
*/ |
|
|
|
|
|
|
|
#include "src/common/anf_exporter/anf_exporter.h" |
|
|
|
|
|
|
|
#include <memory> |
|
|
|
#include <set> |
|
|
|
#include <string> |
|
|
|
#include <utility> |
|
|
|
#include <vector> |
|
|
|
#include <string> |
|
|
|
|
|
|
|
#include "abstract/abstract_value.h" |
|
|
|
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" |
|
|
|
#include "src/param_value_lite.h" |
|
|
|
#include "base/core_ops.h" |
|
|
|
#include "mindspore/core/ir/primitive.h" |
|
|
|
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" |
|
|
|
#include "src/ir/primitive_t_value.h" |
|
|
|
#include "base/core_ops.h" |
|
|
|
#include "src/ir/tensor.h" |
|
|
|
#include "src/param_value_lite.h" |
|
|
|
|
|
|
|
namespace mindspore::lite { |
|
|
|
std::set<std::string> RemoveNodeInAnfExporter{"tuple_getitem", "make_tuple"}; |
|
|
|
|
|
|
|
void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) { |
|
|
|
bool hasMakeTuple = false; |
|
|
|
std::vector<AnfNodePtr> inputs; |
|
|
|
inputs.clear(); |
|
|
|
|
|
|
|
inputs.emplace_back(cnode->input(0)); |
|
|
|
for (size_t i = 1; i < cnode->inputs().size(); ++i) { |
|
|
|
AnfNodePtr inputNode = cnode->input(i); |
|
|
|
if (!inputNode->isa<CNode>()) { |
|
|
|
inputs.emplace_back(cnode->input(i)); |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto makeTupleNode = utils::cast<CNodePtr>(inputNode); |
|
|
|
if (IsPrimitiveCNode(makeTupleNode, prim::kPrimMakeTuple)) { |
|
|
|
hasMakeTuple = true; |
|
|
|
for (size_t j = 1; j < makeTupleNode->inputs().size(); ++j) { |
|
|
|
inputs.emplace_back(makeTupleNode->input(j)); |
|
|
|
} |
|
|
|
} else { |
|
|
|
inputs.emplace_back(cnode->input(i)); |
|
|
|
} |
|
|
|
} |
|
|
|
if (hasMakeTuple) { |
|
|
|
cnode->set_inputs(inputs); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
bool AnfExporter::RemoveIfTupleGetItem(const CNodePtr &cnode) { |
|
|
|
bool hasTupleGetItem = false; |
|
|
|
std::vector<AnfNodePtr> inputs; |
|
|
|
inputs.clear(); |
|
|
|
inputs.emplace_back(cnode->input(0)); |
|
|
|
for (size_t i = 1; i < cnode->inputs().size(); ++i) { |
|
|
|
AnfNodePtr inputNode = cnode->input(i); |
|
|
|
if (!inputNode->isa<CNode>()) { |
|
|
|
inputs.emplace_back(cnode->input(i)); |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto tupleGetItemNode = utils::cast<CNodePtr>(inputNode); |
|
|
|
if (IsPrimitiveCNode(tupleGetItemNode, prim::kPrimTupleGetItem)) { |
|
|
|
hasTupleGetItem = true; |
|
|
|
inputs.emplace_back(tupleGetItemNode->input(1)); |
|
|
|
AnfNodePtr indexNode = tupleGetItemNode->input(2); |
|
|
|
if (utils::isa<ValueNodePtr>(indexNode)) { |
|
|
|
MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
ValueNodePtr valueNode = utils::cast<ValueNodePtr>(indexNode); |
|
|
|
mapRemoveGetItem_[tupleGetItemNode->input(1)->fullname_with_scope()] = |
|
|
|
GetValue<int>(valueNode->value()); |
|
|
|
} else { |
|
|
|
inputs.emplace_back(cnode->input(i)); |
|
|
|
} |
|
|
|
} |
|
|
|
if (hasTupleGetItem) { |
|
|
|
cnode->set_inputs(inputs); |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool AnfExporter::AddOutPutIfReturn(const std::unique_ptr<schema::MetaGraphT> &metaGraphT, const CNodePtr &cnode) { |
|
|
|
for (size_t i = 1; i < cnode->inputs().size(); ++i) { |
|
|
|
auto inputNode = cnode->input(i); |
|
|
|
if (!inputNode->isa<CNode>()) { |
|
|
|
MS_LOG(ERROR) << "Node of Return's input is not CNode"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto inputCNode = utils::cast<CNodePtr>(inputNode); |
|
|
|
auto inputPrimitive = GetValueNode<PrimitivePtr>(inputCNode->input(0)); |
|
|
|
std::string inputName = inputNode->fullname_with_scope(); |
|
|
|
auto graphOutput = nodeIdMap[inputName]; |
|
|
|
metaGraphT->outputIndex.emplace_back(graphOutput); |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { |
|
|
|
auto cnodes = funcGraph->GetOrderedCnodes(); |
|
|
|
auto metaGraphT = std::make_unique<schema::MetaGraphT>(); |
|
|
|
for (const auto &cnode : cnodes) { |
|
|
|
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0)); |
|
|
|
if (primitive != nullptr && primitive == prim::kPrimReturn) { |
|
|
|
// set graph outputs tensors |
|
|
|
auto inputNode = cnode->input(1); |
|
|
|
if (!inputNode->isa<CNode>()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto inputCNode = utils::cast<CNodePtr>(inputNode); |
|
|
|
auto inputPrimitive = GetValueNode<PrimitivePtr>(inputCNode->input(0)); |
|
|
|
if (inputPrimitive == prim::kPrimMakeTuple) { |
|
|
|
continue; |
|
|
|
} else { |
|
|
|
std::string inputName = inputNode->fullname_with_scope(); |
|
|
|
auto graphOutput = nodeIdMap[inputName]; |
|
|
|
metaGraphT->outputIndex.emplace_back(graphOutput); |
|
|
|
} |
|
|
|
if (primitive != nullptr && |
|
|
|
RemoveNodeInAnfExporter.count(primitive->name()) != 0) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (primitive != nullptr && primitive == prim::kPrimMakeTuple) { |
|
|
|
for (size_t i = 1; i < cnode->inputs().size(); i++) { |
|
|
|
auto graphOutNode = cnode->input(i); |
|
|
|
if (!graphOutNode->isa<CNode>()) { |
|
|
|
MS_LOG(ERROR) << "Inputs of MakeTuple should be cNode"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
std::string graphOutNodeName = graphOutNode->fullname_with_scope(); |
|
|
|
auto graphOutIndex = nodeIdMap[graphOutNodeName]; |
|
|
|
metaGraphT->outputIndex.emplace_back(graphOutIndex); |
|
|
|
} |
|
|
|
mapRemoveGetItem_.clear(); |
|
|
|
RemoveIfMakeTuple(cnode); |
|
|
|
RemoveIfTupleGetItem(cnode); |
|
|
|
if (primitive != nullptr && primitive->name() == prim::kPrimReturn->name()) { |
|
|
|
AddOutPutIfReturn(metaGraphT, cnode); |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -74,19 +134,27 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { |
|
|
|
primitive = GetValueNode<PrimitivePtr>(cnode->input(0)); |
|
|
|
MS_ASSERT(primitive != nullptr); |
|
|
|
std::string opType = primitive->name(); |
|
|
|
auto nodeParser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType); |
|
|
|
auto nodeParser = |
|
|
|
AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType); |
|
|
|
if (nodeParser == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Find op parser failed, opType: " << opType; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
std::vector<schema::TensorT *> outputs; |
|
|
|
if (utils::isa<abstract::AbstractSequeue>(cnode->abstract())) { |
|
|
|
auto abstract_cnode = |
|
|
|
utils::cast<abstract::AbstractSequeuePtr>(cnode->abstract()); |
|
|
|
outputs.resize(abstract_cnode->size()); |
|
|
|
} |
|
|
|
|
|
|
|
nodeParser->Parse(cnode, node.get(), &outputs); |
|
|
|
SetOpInputNode(cnode, metaGraphT.get(), node.get()); |
|
|
|
SetOpOutputNode(outputs, metaGraphT.get(), node.get()); |
|
|
|
metaGraphT->nodes.emplace_back(std::move(node)); |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0)); |
|
|
|
auto primitiveT_value = |
|
|
|
GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0)); |
|
|
|
if (primitiveT_value == nullptr) { |
|
|
|
MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; |
|
|
|
return nullptr; |
|
|
|
@@ -98,7 +166,8 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
node->primitive = std::unique_ptr<schema::PrimitiveT>(primitiveT_value->GetPrimitiveT()); |
|
|
|
node->primitive = |
|
|
|
std::unique_ptr<schema::PrimitiveT>(primitiveT_value->GetPrimitiveT()); |
|
|
|
std::vector<schema::TensorT *> outputs; |
|
|
|
SetOpInputNode(cnode, metaGraphT.get(), node.get()); |
|
|
|
SetOpOutputNode(outputs, metaGraphT.get(), node.get()); |
|
|
|
@@ -112,10 +181,11 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { |
|
|
|
auto tensor_input = metaGraphT->allTensors[activate_index].get(); |
|
|
|
auto input_quant_params = primitiveT_value->GetInputQuantParams(); |
|
|
|
if (input_quant_params.empty()) { |
|
|
|
MS_LOG(WARNING) << "node: " << node->name << " input quant params is empty"; |
|
|
|
MS_LOG(WARNING) << "node: " << node->name |
|
|
|
<< " input quant params is empty"; |
|
|
|
} else { |
|
|
|
std::unique_ptr<schema::QuantParamT> input_quant_param = |
|
|
|
std::make_unique<schema::QuantParamT>(input_quant_params[0]); |
|
|
|
std::make_unique<schema::QuantParamT>(input_quant_params[0]); |
|
|
|
tensor_input->quantParams.emplace_back(std::move(input_quant_param)); |
|
|
|
} |
|
|
|
tensor_input->dataType = kNumberTypeInt8; |
|
|
|
@@ -124,18 +194,20 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { |
|
|
|
auto tensor_output = metaGraphT->allTensors[output_index].get(); |
|
|
|
auto output_quant_params = primitiveT_value->GetOutputQuantParams(); |
|
|
|
if (output_quant_params.empty()) { |
|
|
|
MS_LOG(WARNING) << "node: " << node->name << " output quant params is empty"; |
|
|
|
MS_LOG(WARNING) << "node: " << node->name |
|
|
|
<< " output quant params is empty"; |
|
|
|
} else { |
|
|
|
std::unique_ptr<schema::QuantParamT> output_quant_param = |
|
|
|
std::make_unique<schema::QuantParamT>(output_quant_params[0]); |
|
|
|
std::make_unique<schema::QuantParamT>(output_quant_params[0]); |
|
|
|
tensor_output->quantParams.emplace_back(std::move(output_quant_param)); |
|
|
|
} |
|
|
|
tensor_output->dataType = kNumberTypeInt8; |
|
|
|
// // TensorType |
|
|
|
// valuePtr = primitive->GetAttr(kInputTensorDataType); |
|
|
|
// if (valuePtr != nullptr) { |
|
|
|
// MS_LOG(INFO) << "node: " << node->name << " input tensor data type: " << GetValue<int>(valuePtr); |
|
|
|
// for (auto input : node->inputIndex) { |
|
|
|
// MS_LOG(INFO) << "node: " << node->name << " input tensor data |
|
|
|
// type: " << GetValue<int>(valuePtr); for (auto input : |
|
|
|
// node->inputIndex) { |
|
|
|
// auto tensor = subGraph->allTensors[input].get(); |
|
|
|
// tensor->dataType = kNumberTypeUInt8; |
|
|
|
// } |
|
|
|
@@ -159,7 +231,9 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { |
|
|
|
return metaGraphT.release(); |
|
|
|
} |
|
|
|
|
|
|
|
void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta_graph, schema::CNodeT *fbNode) { |
|
|
|
void AnfExporter::SetOpInputNode(const CNodePtr &cnode, |
|
|
|
schema::MetaGraphT *meta_graph, |
|
|
|
schema::CNodeT *fbNode) { |
|
|
|
MS_ASSERT(nullptr != meta_graph); |
|
|
|
MS_ASSERT(nullptr != fbNode); |
|
|
|
if (cnode->inputs().size() <= 1) { |
|
|
|
@@ -172,6 +246,13 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta |
|
|
|
if (inputNode->isa<CNode>()) { |
|
|
|
isGraphInput = false; |
|
|
|
std::string inputName = inputNode->fullname_with_scope(); |
|
|
|
if (!mapRemoveGetItem_.empty()) { |
|
|
|
for (auto name : mapRemoveGetItem_) { |
|
|
|
if (name.first == inputName) { |
|
|
|
inputName = inputName + "_o:" + std::to_string(name.second); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
if (nodeIdMap.find(inputName) != nodeIdMap.end()) { |
|
|
|
fbNode->inputIndex.emplace_back(nodeIdMap[inputName]); |
|
|
|
} |
|
|
|
@@ -187,30 +268,38 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta |
|
|
|
auto paramTensor = std::make_unique<schema::TensorT>(); |
|
|
|
auto abstractBase = paramNode->abstract(); |
|
|
|
if (abstractBase == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << paramNode->name(); |
|
|
|
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " |
|
|
|
<< paramNode->name(); |
|
|
|
MS_ASSERT(false); |
|
|
|
return; |
|
|
|
} |
|
|
|
if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) { |
|
|
|
MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << paramNode->name(); |
|
|
|
MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " |
|
|
|
<< paramNode->name(); |
|
|
|
MS_ASSERT(false); |
|
|
|
return; |
|
|
|
} |
|
|
|
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase); |
|
|
|
auto abstractTensor = |
|
|
|
utils::cast<abstract::AbstractTensorPtr>(abstractBase); |
|
|
|
auto typePtr = abstractTensor->element()->GetTypeTrack(); |
|
|
|
MS_ASSERT(typePtr != nullptr); |
|
|
|
paramTensor->dataType = typePtr->type_id(); |
|
|
|
if (!utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) { |
|
|
|
MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << paramNode->name(); |
|
|
|
MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " |
|
|
|
<< paramNode->name(); |
|
|
|
MS_ASSERT(false); |
|
|
|
return; |
|
|
|
} |
|
|
|
paramTensor->dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape(); |
|
|
|
auto paramValue = std::dynamic_pointer_cast<ParamValueLite>(paramNode->default_param()); |
|
|
|
paramTensor->dims = |
|
|
|
utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape()) |
|
|
|
->shape(); |
|
|
|
auto paramValue = |
|
|
|
std::dynamic_pointer_cast<ParamValueLite>(paramNode->default_param()); |
|
|
|
if (paramValue != nullptr) { |
|
|
|
paramTensor->nodeType = schema::NodeType_ValueNode; |
|
|
|
paramTensor->data.resize(paramValue->tensor_size()); |
|
|
|
memcpy(paramTensor->data.data(), paramValue->tensor_addr(), paramValue->tensor_size()); |
|
|
|
memcpy(paramTensor->data.data(), paramValue->tensor_addr(), |
|
|
|
paramValue->tensor_size()); |
|
|
|
} |
|
|
|
for (auto &ite : paramValue->quant_param()) { |
|
|
|
auto quantPar = std::make_unique<schema::QuantParamT>(); |
|
|
|
@@ -224,7 +313,8 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta |
|
|
|
paramTensor->quantParams.emplace_back(std::move(quantPar)); |
|
|
|
paramTensor->dataType = paramValue->tensor_type(); |
|
|
|
} |
|
|
|
nodeIdMap[paramNode->fullname_with_scope()] = meta_graph->allTensors.size(); |
|
|
|
nodeIdMap[paramNode->fullname_with_scope()] = |
|
|
|
meta_graph->allTensors.size(); |
|
|
|
fbNode->inputIndex.emplace_back(meta_graph->allTensors.size()); |
|
|
|
meta_graph->allTensors.emplace_back(std::move(paramTensor)); |
|
|
|
} else if (inputNode->isa<ValueNode>()) { |
|
|
|
@@ -233,15 +323,19 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta |
|
|
|
auto value = valueNode->value(); |
|
|
|
if (value->isa<lite::tensor::Tensor>()) { |
|
|
|
auto valueAbstract = valueNode->abstract(); |
|
|
|
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(valueAbstract); |
|
|
|
auto abstractTensor = |
|
|
|
utils::cast<abstract::AbstractTensorPtr>(valueAbstract); |
|
|
|
auto typePtr = abstractTensor->element()->GetTypeTrack(); |
|
|
|
paramTensor->dataType = typePtr->type_id(); |
|
|
|
paramTensor->dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape(); |
|
|
|
paramTensor->dims = |
|
|
|
utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape()) |
|
|
|
->shape(); |
|
|
|
paramTensor->nodeType = schema::NodeType_ValueNode; |
|
|
|
auto data = value->cast<lite::tensor::TensorPtr>(); |
|
|
|
paramTensor->data.resize(data->Size()); |
|
|
|
memcpy(paramTensor->data.data(), data->Data(), data->Size()); |
|
|
|
nodeIdMap[valueNode->fullname_with_scope()] = meta_graph->allTensors.size(); |
|
|
|
nodeIdMap[valueNode->fullname_with_scope()] = |
|
|
|
meta_graph->allTensors.size(); |
|
|
|
fbNode->inputIndex.emplace_back(meta_graph->allTensors.size()); |
|
|
|
meta_graph->allTensors.emplace_back(std::move(paramTensor)); |
|
|
|
} else if (value->isa<mindspore::ValueSequeue>()) { |
|
|
|
@@ -257,8 +351,9 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void AnfExporter::SetOpOutputNode(const std::vector<schema::TensorT *> &outputTensors, schema::MetaGraphT *graph, |
|
|
|
schema::CNodeT *cnode) { |
|
|
|
void AnfExporter::SetOpOutputNode( |
|
|
|
const std::vector<schema::TensorT *> &outputTensors, |
|
|
|
schema::MetaGraphT *graph, schema::CNodeT *cnode) { |
|
|
|
MS_ASSERT(nullptr != graph); |
|
|
|
MS_ASSERT(nullptr != cnode); |
|
|
|
std::string cnodeName = cnode->name; |
|
|
|
|