|
|
|
@@ -91,9 +91,9 @@ class IrExportBuilder { |
|
|
|
void SetParamToTensorProto(const ParameterPtr ¶m, onnx::TensorProto *const tensor_proto); |
|
|
|
void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, onnx::TensorProto *const tensor_proto); |
|
|
|
void SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto); |
|
|
|
void SetShapeToNodeProto(const CNodePtr &node, const std::vector<AnfNodePtr> &inputs, |
|
|
|
onnx::NodeProto *const node_proto); |
|
|
|
void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::NodeProto *const node_proto); |
|
|
|
void SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto); |
|
|
|
void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::NodeProto *const node_proto, |
|
|
|
std::string suffix = "0"); |
|
|
|
void SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); |
|
|
|
void SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); |
|
|
|
void SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); |
|
|
|
@@ -112,10 +112,10 @@ class IrExportBuilder { |
|
|
|
|
|
|
|
private: |
|
|
|
onnx::ModelProto model_; |
|
|
|
onnx::NodeProto *last_node_; |
|
|
|
onnx::NodeProto *last_node_{nullptr}; |
|
|
|
std::list<FuncGraphPtr> todo_; |
|
|
|
std::map<AnfNodePtr, size_t> node_index_map_; |
|
|
|
size_t node_index_ = 0; |
|
|
|
size_t node_index_{0}; |
|
|
|
}; |
|
|
|
|
|
|
|
using IrExporterPtr = std::shared_ptr<IrExporter>; |
|
|
|
@@ -349,44 +349,34 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) { |
|
|
|
} |
|
|
|
|
|
|
|
void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, |
|
|
|
onnx::NodeProto *const node_proto) { |
|
|
|
onnx::NodeProto *const node_proto, std::string suffix) { |
|
|
|
onnx::AttributeProto *attr_proto = node_proto->add_attribute(); |
|
|
|
attr_proto->set_ref_attr_name("shape"); |
|
|
|
attr_proto->set_name("shape"); |
|
|
|
if (suffix.compare("0") != 0) { |
|
|
|
attr_proto->set_name("shape" + suffix); |
|
|
|
} else { |
|
|
|
attr_proto->set_name("shape"); |
|
|
|
} |
|
|
|
onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); |
|
|
|
SetTensorProto(type, shape, tensor_proto); |
|
|
|
} |
|
|
|
|
|
|
|
void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, const std::vector<AnfNodePtr> &inputs, |
|
|
|
onnx::NodeProto *const node_proto) { |
|
|
|
void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) { |
|
|
|
// Get shape of cnode |
|
|
|
// 1. prim kPrimTupleGetItem need to get shape of input node according to the index |
|
|
|
// 1. prim ArgMaxWithValue need to get shape from tuple element |
|
|
|
// 2. some cnode doesn't has shape, such as LayerNorm |
|
|
|
// 3. other cnodes have shape |
|
|
|
if (node->IsApply(prim::kPrimTupleGetItem)) { |
|
|
|
// Get index of tuple get_item |
|
|
|
int index_pos = inputs.size() - 1; |
|
|
|
if (!inputs[index_pos]->isa<ValueNode>()) { |
|
|
|
MS_LOG(EXCEPTION) << "Index is not ValueNode: " << index_pos; |
|
|
|
} |
|
|
|
auto value = inputs[index_pos]->cast<ValueNodePtr>()->value(); |
|
|
|
if (!value->isa<IntergerImm>()) { |
|
|
|
MS_LOG(EXCEPTION) << "Index type is not supported: " << value->type_name(); |
|
|
|
} |
|
|
|
size_t index = GetValue<int>(value); |
|
|
|
|
|
|
|
// Get type and shape of input node |
|
|
|
auto tup_type = inputs[0]->Type(); |
|
|
|
if (!tup_type->isa<Tuple>()) { |
|
|
|
MS_LOG(EXCEPTION) << "Input data of kPrimTupleGetItem cnode must be tuple: " << tup_type->type_name(); |
|
|
|
if (node->IsApply(prim::kPrimArgMaxWithValue)) { |
|
|
|
auto type = node->Type(); |
|
|
|
auto shape = node->Shape(); |
|
|
|
if (!type->isa<Tuple>()) { |
|
|
|
MS_LOG(EXCEPTION) << "Output data of ArgMaxWithValue cnode must be tuple: " << type->type_name(); |
|
|
|
} |
|
|
|
auto type = tup_type->cast<TuplePtr>()->elements()[index]; |
|
|
|
auto tup_shape = inputs[0]->Shape()->cast<abstract::TupleShapePtr>(); |
|
|
|
if (index >= tup_shape->shape().size()) { |
|
|
|
MS_LOG(EXCEPTION) << "Index exceed upper limit: " << tup_shape->shape().size(); |
|
|
|
auto elements = type->cast<TuplePtr>()->elements(); |
|
|
|
auto tuple_shape = shape->cast<abstract::TupleShapePtr>()->shape(); |
|
|
|
for (size_t i = 0; i < elements.size(); i++) { |
|
|
|
SetShapeToNodeProto(elements[i], tuple_shape[i], node_proto, std::to_string(i)); |
|
|
|
} |
|
|
|
auto shape = tup_shape->shape()[index]; |
|
|
|
SetShapeToNodeProto(type, shape, node_proto); |
|
|
|
} else { |
|
|
|
auto type = node->Type(); |
|
|
|
auto shape = node->Shape(); |
|
|
|
@@ -422,7 +412,7 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const g |
|
|
|
std::string type_name = GetOpTypeName(op); |
|
|
|
node_proto->set_op_type(type_name); |
|
|
|
last_node_ = node_proto; |
|
|
|
SetShapeToNodeProto(node, op_inputs, node_proto); |
|
|
|
SetShapeToNodeProto(node, node_proto); |
|
|
|
(void)std::for_each(input_names.begin(), input_names.end(), |
|
|
|
[&node_proto](const string &name) { node_proto->add_input(name); }); |
|
|
|
|
|
|
|
|