|
|
|
@@ -365,7 +365,7 @@ void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto |
|
|
|
// 1. prim ArgMaxWithValue need to get shape from tuple element |
|
|
|
// 2. some cnode doesn't has shape, such as LayerNorm |
|
|
|
// 3. other cnodes have shape |
|
|
|
if (node->IsApply(prim::kPrimArgMaxWithValue)) { |
|
|
|
if (node->IsApply(prim::kPrimArgMaxWithValue) || node->IsApply(prim::kPrimLayerNorm)) { |
|
|
|
auto type = node->Type(); |
|
|
|
auto shape = node->Shape(); |
|
|
|
if (!type->isa<Tuple>()) { |
|
|
|
@@ -407,6 +407,7 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const g |
|
|
|
std::string output_name = GetUniqueNodeName(node); |
|
|
|
node_proto->add_output(output_name); |
|
|
|
node_proto->set_name(output_name); |
|
|
|
node_proto->set_domain(node->fullname_with_scope()); |
|
|
|
AnfNodePtr op = node->input(0); |
|
|
|
std::string type_name = GetOpTypeName(op); |
|
|
|
node_proto->set_op_type(type_name); |
|
|
|
|