|
|
|
@@ -107,10 +107,11 @@ class IrExportBuilder { |
|
|
|
bool SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto); |
|
|
|
bool SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::TensorProto *const tensor_proto); |
|
|
|
bool SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::TensorProto *const tensor_proto); |
|
|
|
bool SetTensorProtoForRef(const TypePtr &type, const AbstractBasePtr &abs, mind_ir::TensorProto *const tensor_proto); |
|
|
|
bool SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto); |
|
|
|
bool SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto); |
|
|
|
bool SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::AttributeProto *const attr_proto, |
|
|
|
std::string *const seq_string); |
|
|
|
bool SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, const abstract::AbstractBasePtr &abstract, |
|
|
|
mind_ir::AttributeProto *const attr_proto, std::string *const seq_string); |
|
|
|
bool SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); |
|
|
|
bool SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); |
|
|
|
bool SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); |
|
|
|
@@ -349,6 +350,9 @@ bool IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueIn |
|
|
|
tensor_proto->add_dims(dim); |
|
|
|
} |
|
|
|
} |
|
|
|
if (!SetTensorProtoForRef(type, node->abstract(), tensor_proto)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
} else if (type->isa<Tuple>()) { |
|
|
|
auto tup_shape = shape->cast<abstract::TupleShapePtr>(); |
|
|
|
value_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size())); |
|
|
|
@@ -402,6 +406,25 @@ bool IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &sh |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool IrExportBuilder::SetTensorProtoForRef(const TypePtr &type, const AbstractBasePtr &abs, |
|
|
|
mind_ir::TensorProto *const tensor_proto) { |
|
|
|
if (!type->isa<RefType>()) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
auto abs_ref = abs->cast<abstract::AbstractRefPtr>(); |
|
|
|
if (abs_ref == nullptr) { |
|
|
|
MS_LOG(ERROR) << "The abstract " << abs->ToString() << " should be AbstractRef."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto ref_key_value = abs_ref->ref_key_value(); |
|
|
|
if (ref_key_value == nullptr) { |
|
|
|
MS_LOG(ERROR) << "The ref_key_value of abstract ref " << abs->ToString() << " is nullptr"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
tensor_proto->set_ref_key(ref_key_value->name()); |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::TensorProto *const tensor_proto) { |
|
|
|
if (param == nullptr || tensor_proto == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "Parameter or TensorProto is null!"; |
|
|
|
@@ -479,7 +502,7 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) { |
|
|
|
return type_name; |
|
|
|
} |
|
|
|
|
|
|
|
bool IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, |
|
|
|
bool IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, const AbstractBasePtr &abs, |
|
|
|
mind_ir::AttributeProto *const attr_proto, std::string *const seq_string) { |
|
|
|
MS_EXCEPTION_IF_NULL(type); |
|
|
|
MS_EXCEPTION_IF_NULL(shape); |
|
|
|
@@ -488,8 +511,9 @@ bool IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt |
|
|
|
*seq_string += "Tuple["; |
|
|
|
auto elements = type->cast<TuplePtr>()->elements(); |
|
|
|
auto tuple_shape = shape->cast<abstract::TupleShapePtr>()->shape(); |
|
|
|
auto tuple_abs = abs->cast<abstract::AbstractTuplePtr>()->elements(); |
|
|
|
for (size_t i = 0; i < elements.size(); i++) { |
|
|
|
if (!SetShapeToNodeProto(elements[i], tuple_shape[i], attr_proto, seq_string)) { |
|
|
|
if (!SetShapeToNodeProto(elements[i], tuple_shape[i], tuple_abs[i], attr_proto, seq_string)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -499,7 +523,7 @@ bool IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt |
|
|
|
*seq_string += shape_name + ","; |
|
|
|
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors(); |
|
|
|
tensor_proto->set_name(shape_name); |
|
|
|
return SetTensorProto(type, shape, tensor_proto); |
|
|
|
return SetTensorProto(type, shape, tensor_proto) && SetTensorProtoForRef(type, abs, tensor_proto); |
|
|
|
} else if (type->isa<Number>()) { |
|
|
|
if (type->isa<Bool>()) { |
|
|
|
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL); |
|
|
|
@@ -531,6 +555,7 @@ bool IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodePro |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto type = node->Type(); |
|
|
|
auto shape = node->Shape(); |
|
|
|
auto abs = node->abstract(); |
|
|
|
// For the bprop fg which has not been renormalized. |
|
|
|
if (type == nullptr || shape == nullptr) { |
|
|
|
return true; |
|
|
|
@@ -538,7 +563,7 @@ bool IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodePro |
|
|
|
ResetTupleIndex(); |
|
|
|
std::string seq_string = "shape:"; |
|
|
|
mind_ir::AttributeProto *attr_proto = node_proto->add_attribute(); |
|
|
|
if (!SetShapeToNodeProto(type, shape, attr_proto, &seq_string)) { |
|
|
|
if (!SetShapeToNodeProto(type, shape, abs, attr_proto, &seq_string)) { |
|
|
|
MS_LOG(ERROR) << "Set shape to NodeProto for " << node->DebugString() << " failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
|