|
|
|
@@ -46,9 +46,6 @@ using uint64 = uint64_t; |
|
|
|
namespace mindspore::lite { |
|
|
|
|
|
|
|
static constexpr char kConstantValueNode[] = "Constant"; |
|
|
|
static constexpr char kCNodeShapeAttr[] = "shape"; |
|
|
|
static constexpr char kCNodeShape1Attr[] = "shape1"; |
|
|
|
static constexpr char kCNodeShape2Attr[] = "shape2"; |
|
|
|
|
|
|
|
enum ParseForm : int { |
|
|
|
FORM_PARSE_TYPE = 0, |
|
|
|
@@ -57,32 +54,143 @@ enum ParseForm : int { |
|
|
|
}; |
|
|
|
|
|
|
|
static std::map<std::string, ParseForm> kParseTypeSwitchMap{ |
|
|
|
{"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR}}; |
|
|
|
{"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR}}; |
|
|
|
|
|
|
|
static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{ |
|
|
|
{onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, {onnx::TensorProto_DataType_INT8, kNumberTypeInt8}, |
|
|
|
{onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, {onnx::TensorProto_DataType_INT32, kNumberTypeInt32}, |
|
|
|
{onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8}, |
|
|
|
{onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32}, |
|
|
|
{onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16}, |
|
|
|
{onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64}, |
|
|
|
{onnx::TensorProto_DataType_STRING, kObjectTypeString}, |
|
|
|
{onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, {onnx::TensorProto_DataType_INT8, kNumberTypeInt8}, |
|
|
|
{onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, {onnx::TensorProto_DataType_INT32, kNumberTypeInt32}, |
|
|
|
{onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8}, |
|
|
|
{onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32}, |
|
|
|
{onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16}, |
|
|
|
{onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64}, |
|
|
|
{onnx::TensorProto_DataType_STRING, kObjectTypeString}, |
|
|
|
}; |
|
|
|
|
|
|
|
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ |
|
|
|
void ParseAttrInScalar_##type##_##valuetype(const PrimitivePtr &prim, const std::string &attr_name, \ |
|
|
|
const onnx::TensorProto &attr_tensor) { \ |
|
|
|
MS_EXCEPTION_IF_NULL(prim); \ |
|
|
|
std::vector<ValuePtr> attr_value_vec; \ |
|
|
|
for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \ |
|
|
|
auto value = static_cast<valuetype>(attr_tensor.type##_data(i)); \ |
|
|
|
attr_value_vec.push_back(MakeValue<valuetype>(value)); \ |
|
|
|
} \ |
|
|
|
if (attr_value_vec.size() == 1) { \ |
|
|
|
prim->AddAttr(attr_name, attr_value_vec[0]); \ |
|
|
|
} else { \ |
|
|
|
prim->AddAttr(attr_name, std::make_shared<ValueList>(attr_value_vec)); \ |
|
|
|
} \ |
|
|
|
std::shared_ptr<ValueTuple> ParserScalarAttrValue(const std::string &attr_name, |
|
|
|
const std::unordered_map<string, ValuePtr> &kv) { |
|
|
|
std::string str = attr_name; |
|
|
|
auto replace = [&](const string &orgStr, const string &newStr) { |
|
|
|
std::string::size_type pos(0); |
|
|
|
while ((pos = str.find(orgStr)) != std::string::npos) { |
|
|
|
str.replace(pos, orgStr.length(), newStr); |
|
|
|
} |
|
|
|
return str; |
|
|
|
}; |
|
|
|
// remove "scalar:" |
|
|
|
str = replace("scalar:", ""); |
|
|
|
// remove "Tuple" |
|
|
|
str = replace("Tuple", ""); |
|
|
|
// remove "List" |
|
|
|
str = replace("List", ""); |
|
|
|
std::stack<std::string> rules; |
|
|
|
std::stack<ValuePtr> value; |
|
|
|
int num = 0, count = 0; |
|
|
|
for (size_t i = 0; i < str.length(); i++) { |
|
|
|
if (str[i] == '[') { |
|
|
|
rules.push("["); |
|
|
|
} else if (str[i] == ']') { |
|
|
|
// rules |
|
|
|
std::vector<ValuePtr> vec; |
|
|
|
while (rules.top() != "[") { |
|
|
|
rules.pop(); |
|
|
|
vec.push_back(value.top()); |
|
|
|
value.pop(); |
|
|
|
} |
|
|
|
// pop "[" |
|
|
|
rules.pop(); |
|
|
|
// make tuple for names |
|
|
|
std::string res = "dummy"; |
|
|
|
// make tuple for values |
|
|
|
reverse(vec.begin(), vec.end()); |
|
|
|
auto vt = std::make_shared<ValueTuple>(vec); |
|
|
|
if (rules.empty() && value.empty()) { |
|
|
|
return vt; |
|
|
|
} |
|
|
|
rules.push(res); |
|
|
|
value.push(vt); |
|
|
|
} else if (str[i] == ',') { |
|
|
|
continue; |
|
|
|
} else { |
|
|
|
count++; |
|
|
|
if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') { |
|
|
|
auto value_name = str.substr(i - count + 1, count); |
|
|
|
value.push(kv.at(value_name)); |
|
|
|
rules.push(value_name); |
|
|
|
count = 0; |
|
|
|
num++; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return {}; |
|
|
|
} |
|
|
|
|
|
|
|
std::shared_ptr<abstract::AbstractTuple> |
|
|
|
ParserAttrShape(const std::string &attr_name, const std::unordered_map<string, abstract::AbstractTensorPtr> &kv) { |
|
|
|
std::string str = attr_name; |
|
|
|
auto replace = [&](const string &orgStr, const string &newStr) { |
|
|
|
std::string::size_type pos(0); |
|
|
|
while ((pos = str.find(orgStr)) != std::string::npos) { |
|
|
|
str.replace(pos, orgStr.length(), newStr); |
|
|
|
} |
|
|
|
return str; |
|
|
|
}; |
|
|
|
// remove "scalar:" |
|
|
|
str = replace("shape:", ""); |
|
|
|
// remove "Tuple" |
|
|
|
str = replace("Tuple", ""); |
|
|
|
// remove "List" |
|
|
|
str = replace("List", ""); |
|
|
|
std::stack<std::string> rules; |
|
|
|
std::stack<abstract::AbstractBasePtr> value; |
|
|
|
int num = 0, count = 0; |
|
|
|
for (size_t i = 0; i < str.length(); i++) { |
|
|
|
if (str[i] == '[') { |
|
|
|
rules.push("["); |
|
|
|
} else if (str[i] == ']') { |
|
|
|
// rules |
|
|
|
std::vector<abstract::AbstractBasePtr> vec; |
|
|
|
while (rules.top() != "[") { |
|
|
|
rules.pop(); |
|
|
|
vec.push_back(value.top()); |
|
|
|
value.pop(); |
|
|
|
} |
|
|
|
// pop "[" |
|
|
|
rules.pop(); |
|
|
|
// make tuple for names |
|
|
|
std::string res = "dummy"; |
|
|
|
// make tuple for values |
|
|
|
reverse(vec.begin(), vec.end()); |
|
|
|
auto vt = std::make_shared<abstract::AbstractTuple>(vec); |
|
|
|
if (rules.empty() && value.empty()) { |
|
|
|
return vt; |
|
|
|
} |
|
|
|
rules.push(res); |
|
|
|
value.push(vt); |
|
|
|
} else if (str[i] == ',') { |
|
|
|
continue; |
|
|
|
} else { |
|
|
|
count++; |
|
|
|
if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') { |
|
|
|
auto value_name = str.substr(i - count + 1, count); |
|
|
|
value.push(kv.at(value_name)); |
|
|
|
rules.push(value_name); |
|
|
|
count = 0; |
|
|
|
num++; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return {}; |
|
|
|
} |
|
|
|
|
|
|
|
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ |
|
|
|
ValuePtr ParseAttrInScalar_##type##_##valuetype(const onnx::TensorProto &attr_tensor) { \ |
|
|
|
if (attr_tensor.type##_data_size() == 1) { \ |
|
|
|
auto value = static_cast<valuetype>(attr_tensor.type##_data(0)); \ |
|
|
|
return MakeValue<valuetype>(value); \ |
|
|
|
} else { \ |
|
|
|
MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \ |
|
|
|
} \ |
|
|
|
return {}; \ |
|
|
|
} |
|
|
|
|
|
|
|
PARSE_ONNXATTR_IN_SCALAR_FORM(double, double) |
|
|
|
@@ -193,45 +301,34 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name, |
|
|
|
const onnx::TensorProto &attr_tensor) { |
|
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
|
ValuePtr AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) { |
|
|
|
const int attr_tensor_type = attr_tensor.data_type(); |
|
|
|
switch (attr_tensor_type) { |
|
|
|
case onnx::TensorProto_DataType_STRING: { |
|
|
|
ParseAttrInScalar_string_string(prim, attr_name, attr_tensor); |
|
|
|
break; |
|
|
|
return ParseAttrInScalar_string_string(attr_tensor); |
|
|
|
} |
|
|
|
case onnx::TensorProto_DataType_INT32: { |
|
|
|
ParseAttrInScalar_int32_int32(prim, attr_name, attr_tensor); |
|
|
|
break; |
|
|
|
return ParseAttrInScalar_int32_int32(attr_tensor); |
|
|
|
} |
|
|
|
case onnx::TensorProto_DataType_INT64: { |
|
|
|
ParseAttrInScalar_int64_int64(prim, attr_name, attr_tensor); |
|
|
|
break; |
|
|
|
return ParseAttrInScalar_int64_int64(attr_tensor); |
|
|
|
} |
|
|
|
case onnx::TensorProto_DataType_UINT64: { |
|
|
|
ParseAttrInScalar_uint64_uint64(prim, attr_name, attr_tensor); |
|
|
|
break; |
|
|
|
return ParseAttrInScalar_uint64_uint64(attr_tensor); |
|
|
|
} |
|
|
|
case onnx::TensorProto_DataType_FLOAT: { |
|
|
|
ParseAttrInScalar_float_float(prim, attr_name, attr_tensor); |
|
|
|
break; |
|
|
|
return ParseAttrInScalar_float_float(attr_tensor); |
|
|
|
} |
|
|
|
case onnx::TensorProto_DataType_DOUBLE: { |
|
|
|
ParseAttrInScalar_double_double(prim, attr_name, attr_tensor); |
|
|
|
break; |
|
|
|
return ParseAttrInScalar_double_double(attr_tensor); |
|
|
|
} |
|
|
|
case onnx::TensorProto_DataType_BOOL: { |
|
|
|
ParseAttrInScalar_int32_bool(prim, attr_name, attr_tensor); |
|
|
|
auto value = prim->GetAttr(attr_name); |
|
|
|
break; |
|
|
|
return ParseAttrInScalar_int32_bool(attr_tensor); |
|
|
|
} |
|
|
|
default: |
|
|
|
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; |
|
|
|
return false; |
|
|
|
default:MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; |
|
|
|
return {}; |
|
|
|
} |
|
|
|
return true; |
|
|
|
return {}; |
|
|
|
} |
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, |
|
|
|
@@ -268,7 +365,6 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &pr |
|
|
|
prim->set_attr(attr_name, MakeValue<bool>(attr_value)); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return ret == EOK; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -280,22 +376,46 @@ bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, con |
|
|
|
return false; |
|
|
|
} |
|
|
|
const std::string &ref_attr_name = attr_proto.ref_attr_name(); |
|
|
|
const onnx::TensorProto &attr_tensor = attr_proto.t(); |
|
|
|
switch (kParseTypeSwitchMap[ref_attr_name]) { |
|
|
|
case FORM_PARSE_TYPE: { |
|
|
|
return ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor); |
|
|
|
} |
|
|
|
case FORM_PARSE_SCALAR: { |
|
|
|
return ObtainCNodeAttrInScalarForm(prim, attr_name, attr_tensor); |
|
|
|
string type; |
|
|
|
std::size_t pos(0); |
|
|
|
if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) { |
|
|
|
type = ref_attr_name.substr(pos, string("scalar:").length() - 1); |
|
|
|
} else if ((pos = ref_attr_name.find("type:")) != std::string::npos) { |
|
|
|
type = ref_attr_name.substr(pos, string("type:").length() - 1); |
|
|
|
} else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) { |
|
|
|
type = ref_attr_name.substr(pos, string("tensor:").length() - 1); |
|
|
|
} |
|
|
|
std::unordered_map<std::string, ValuePtr> kv; |
|
|
|
for (int i = 0; i < attr_proto.tensors_size(); i++) { |
|
|
|
const onnx::TensorProto &attr_tensor = attr_proto.tensors(i); |
|
|
|
switch (kParseTypeSwitchMap[type]) { |
|
|
|
case FORM_PARSE_TYPE: { |
|
|
|
return ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor); |
|
|
|
} |
|
|
|
case FORM_PARSE_SCALAR: { |
|
|
|
auto res = ObtainCNodeAttrInScalarForm(attr_tensor); |
|
|
|
kv.insert(std::pair<string, ValuePtr>(attr_tensor.name(), res)); |
|
|
|
break; |
|
|
|
} |
|
|
|
case FORM_PARSE_TENSOR: { |
|
|
|
return ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor); |
|
|
|
} |
|
|
|
default:MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
case FORM_PARSE_TENSOR: { |
|
|
|
return ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor); |
|
|
|
} |
|
|
|
if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) { |
|
|
|
if (kv.size() == 1) { |
|
|
|
std::unordered_map<std::string, ValuePtr>::iterator iter = kv.begin(); |
|
|
|
prim->AddAttr(attr_name, iter->second); |
|
|
|
} else { |
|
|
|
auto res = ParserScalarAttrValue(ref_attr_name, kv); |
|
|
|
prim->AddAttr(attr_name, res); |
|
|
|
} |
|
|
|
default: |
|
|
|
MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &value_node_name, |
|
|
|
const onnx::TensorProto &attr_tensor) { |
|
|
|
const int attr_tensor_type = attr_tensor.data_type(); |
|
|
|
@@ -321,53 +441,6 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &val |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &value_node_name, |
|
|
|
const onnx::TensorProto &attr_tensor) { |
|
|
|
const int attr_tensor_type = attr_tensor.data_type(); |
|
|
|
ValuePtr value_ptr = nullptr; |
|
|
|
switch (attr_tensor_type) { |
|
|
|
case onnx::TensorProto_DataType_INT32: { |
|
|
|
std::vector<int32> add_data; |
|
|
|
for (int i = 0; i < attr_tensor.int32_data_size(); ++i) { |
|
|
|
add_data.push_back(attr_tensor.int32_data(i)); |
|
|
|
} |
|
|
|
if (add_data.size() == 1) { |
|
|
|
value_ptr = MakeValue(add_data[0]); |
|
|
|
} else if (!add_data.empty()) { |
|
|
|
value_ptr = MakeValue<std::vector<int32> >(add_data); |
|
|
|
} |
|
|
|
break; |
|
|
|
} |
|
|
|
case onnx::TensorProto_DataType_FLOAT: { |
|
|
|
std::vector<float> add_data; |
|
|
|
for (int i = 0; i < attr_tensor.float_data_size(); ++i) { |
|
|
|
add_data.push_back(attr_tensor.float_data(i)); |
|
|
|
} |
|
|
|
|
|
|
|
if (add_data.size() == 1) { |
|
|
|
value_ptr = MakeValue(add_data[0]); |
|
|
|
} else if (!add_data.empty()) { |
|
|
|
value_ptr = MakeValue<std::vector<float> >(add_data); |
|
|
|
} |
|
|
|
break; |
|
|
|
} |
|
|
|
case onnx::TensorProto_DataType_UNDEFINED: { |
|
|
|
std::vector<ValuePtr> elems; |
|
|
|
value_ptr = std::make_shared<ValueTuple>(elems); |
|
|
|
break; |
|
|
|
} |
|
|
|
default: |
|
|
|
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto new_value_node = NewValueNode(value_ptr); |
|
|
|
MS_EXCEPTION_IF_NULL(new_value_node); |
|
|
|
new_value_node->set_abstract(value_ptr->ToAbstract()); |
|
|
|
anfnode_build_map_[value_node_name] = new_value_node; |
|
|
|
|
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value_node_name, |
|
|
|
const onnx::TensorProto &attr_tensor) { |
|
|
|
const int attr_tensor_type = attr_tensor.data_type(); |
|
|
|
@@ -382,23 +455,56 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &ref_attr_name, |
|
|
|
const std::string &value_node_name, |
|
|
|
const onnx::TensorProto &attr_tensor) { |
|
|
|
switch (kParseTypeSwitchMap[ref_attr_name]) { |
|
|
|
case FORM_PARSE_SCALAR: { |
|
|
|
return ObtainValueNodeInScalarForm(value_node_name, attr_tensor); |
|
|
|
} |
|
|
|
case FORM_PARSE_TENSOR: { |
|
|
|
return ObtainValueNodeInTensorForm(value_node_name, attr_tensor); |
|
|
|
bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &value_node_name, |
|
|
|
const onnx::AttributeProto &attr_proto) { |
|
|
|
if (!attr_proto.has_ref_attr_name()) { |
|
|
|
MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
const std::string &ref_attr_name = attr_proto.ref_attr_name(); |
|
|
|
string type; |
|
|
|
std::size_t pos(0); |
|
|
|
if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) { |
|
|
|
type = ref_attr_name.substr(pos, string("scalar:").length() - 1); |
|
|
|
} else if ((pos = ref_attr_name.find("type:")) != std::string::npos) { |
|
|
|
type = ref_attr_name.substr(pos, string("type:").length() - 1); |
|
|
|
} else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) { |
|
|
|
type = ref_attr_name.substr(pos, string("tensor:").length() - 1); |
|
|
|
} |
|
|
|
std::unordered_map<std::string, ValuePtr> kv; |
|
|
|
for (int i = 0; i < attr_proto.tensors_size(); i++) { |
|
|
|
const onnx::TensorProto &attr_tensor = attr_proto.tensors(i); |
|
|
|
switch (kParseTypeSwitchMap[type]) { |
|
|
|
case FORM_PARSE_TYPE: { |
|
|
|
return ObtainValueNodeInTypeForm(value_node_name, attr_tensor); |
|
|
|
} |
|
|
|
case FORM_PARSE_SCALAR: { |
|
|
|
auto res = ObtainCNodeAttrInScalarForm(attr_tensor); |
|
|
|
kv.insert(std::pair<string, ValuePtr>(attr_tensor.name(), res)); |
|
|
|
break; |
|
|
|
} |
|
|
|
case FORM_PARSE_TENSOR: { |
|
|
|
return ObtainValueNodeInTensorForm(value_node_name, attr_tensor); |
|
|
|
} |
|
|
|
default:MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
case FORM_PARSE_TYPE: { |
|
|
|
return ObtainValueNodeInTypeForm(value_node_name, attr_tensor); |
|
|
|
} |
|
|
|
|
|
|
|
ValueNodePtr new_value_node; |
|
|
|
if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) { |
|
|
|
if (kv.size() == 1) { |
|
|
|
auto iter = kv.begin(); |
|
|
|
new_value_node = NewValueNode(iter->second); |
|
|
|
new_value_node->set_abstract(iter->second->ToAbstract()); |
|
|
|
} else { |
|
|
|
auto value_ptr = ParserScalarAttrValue(ref_attr_name, kv); |
|
|
|
new_value_node = NewValueNode(value_ptr); |
|
|
|
new_value_node->set_abstract(value_ptr->ToAbstract()); |
|
|
|
} |
|
|
|
default: |
|
|
|
MS_LOG(ERROR) << "parse ValueNode value don't support input of ref_attr_name"; |
|
|
|
return false; |
|
|
|
anfnode_build_map_[value_node_name] = new_value_node; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) { |
|
|
|
@@ -408,22 +514,23 @@ bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto & |
|
|
|
MS_LOG(ERROR) << "parse ValueNode don't have ref_attr_name"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
const std::string &ref_attr_name = attr_proto.ref_attr_name(); |
|
|
|
const onnx::TensorProto &attr_tensor = attr_proto.t(); |
|
|
|
|
|
|
|
return GetAttrValueForValueNode(ref_attr_name, value_node_name, attr_tensor); |
|
|
|
return GetAttrValueForValueNode(value_node_name, attr_proto); |
|
|
|
} |
|
|
|
|
|
|
|
abstract::AbstractTensorPtr AnfImporterFromProtobuf::GetAbstractForCNode(const onnx::AttributeProto &attr_proto) { |
|
|
|
std::vector<int> shape_vec; |
|
|
|
const onnx::TensorProto &attr_tensor = attr_proto.t(); |
|
|
|
for (int i = 0; i < attr_tensor.dims_size(); ++i) { |
|
|
|
shape_vec.push_back(attr_tensor.dims(i)); |
|
|
|
std::unordered_map<std::string, abstract::AbstractTensorPtr> |
|
|
|
AnfImporterFromProtobuf::GetAbstractForCNode(const onnx::AttributeProto &attr_proto) { |
|
|
|
std::unordered_map<std::string, abstract::AbstractTensorPtr> kv; |
|
|
|
for (int i = 0; i < attr_proto.tensors_size(); i++) { |
|
|
|
std::vector<int> shape_vec; |
|
|
|
const onnx::TensorProto &attr_tensor = attr_proto.tensors(i); |
|
|
|
for (int j = 0; j < attr_tensor.dims_size(); ++j) { |
|
|
|
shape_vec.push_back(attr_tensor.dims(j)); |
|
|
|
} |
|
|
|
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]); |
|
|
|
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vec); |
|
|
|
kv.insert(std::pair<string, abstract::AbstractTensorPtr>(attr_tensor.name(), abstract_tensor)); |
|
|
|
} |
|
|
|
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]); |
|
|
|
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vec); |
|
|
|
MS_EXCEPTION_IF_NULL(abstract_tensor); |
|
|
|
return abstract_tensor; |
|
|
|
return kv; |
|
|
|
} |
|
|
|
|
|
|
|
CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, |
|
|
|
@@ -437,25 +544,16 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out |
|
|
|
const std::string &node_name = node_proto.output(0); |
|
|
|
const std::string &fullname_with_scope = node_proto.domain(); |
|
|
|
const std::string &node_type = node_proto.op_type(); |
|
|
|
PrimitivePtr prim = std::make_shared<Primitive>(node_type); |
|
|
|
PrimitivePtr prim = std::make_shared<mindspore::Primitive>(node_type); |
|
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
|
prim->set_instance_name(node_type); |
|
|
|
|
|
|
|
abstract::AbstractTensorPtr abstract = nullptr; |
|
|
|
abstract::AbstractTensorPtr abstract_first = nullptr; |
|
|
|
abstract::AbstractTensorPtr abstract_second = nullptr; |
|
|
|
std::unordered_map<std::string, abstract::AbstractTensorPtr> kv; |
|
|
|
string shape_ref_attr_name; |
|
|
|
for (int i = 0; i < node_proto.attribute_size(); ++i) { |
|
|
|
const onnx::AttributeProto &attr_proto = node_proto.attribute(i); |
|
|
|
if (attr_proto.name() == kCNodeShapeAttr) { |
|
|
|
abstract = GetAbstractForCNode(attr_proto); |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (attr_proto.name() == kCNodeShape1Attr) { |
|
|
|
abstract_first = GetAbstractForCNode(attr_proto); |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (attr_proto.name() == kCNodeShape2Attr) { |
|
|
|
abstract_second = GetAbstractForCNode(attr_proto); |
|
|
|
if (attr_proto.ref_attr_name().find("shape:") != string::npos) { |
|
|
|
shape_ref_attr_name = attr_proto.ref_attr_name(); |
|
|
|
kv = GetAbstractForCNode(attr_proto); |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (!GetAttrValueForCNode(prim, attr_proto)) { |
|
|
|
@@ -463,6 +561,7 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> inputs; |
|
|
|
inputs.clear(); |
|
|
|
for (int i = 0; i < node_proto.input_size(); ++i) { |
|
|
|
@@ -481,26 +580,20 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out |
|
|
|
inputs.insert(inputs.begin(), NewValueNode(primitivec_ptr)); |
|
|
|
CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode_ptr); |
|
|
|
if (node_type == "LayerNorm") { |
|
|
|
AbstractBasePtrList elem; |
|
|
|
elem.push_back(abstract); |
|
|
|
elem.push_back(abstract_first); |
|
|
|
elem.push_back(abstract_second); |
|
|
|
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem)); |
|
|
|
} else if (node_type == "ArgMaxWithValue") { |
|
|
|
AbstractBasePtrList elem; |
|
|
|
elem.push_back(abstract); |
|
|
|
elem.push_back(abstract_first); |
|
|
|
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem)); |
|
|
|
} else if (nullptr == abstract) { |
|
|
|
if (0 == kv.size()) { |
|
|
|
AbstractBasePtrList elem; |
|
|
|
for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) { |
|
|
|
elem.push_back(cnode_ptr->input(index)->abstract()); |
|
|
|
} |
|
|
|
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem)); |
|
|
|
} else if (1 == kv.size()) { |
|
|
|
std::unordered_map<std::string, abstract::AbstractTensorPtr>::iterator iter = kv.begin(); |
|
|
|
cnode_ptr->set_abstract(iter->second); |
|
|
|
} else { |
|
|
|
auto abstract = ParserAttrShape(shape_ref_attr_name, kv); |
|
|
|
cnode_ptr->set_abstract(abstract); |
|
|
|
} |
|
|
|
|
|
|
|
cnode_ptr->set_fullname_with_scope(fullname_with_scope); |
|
|
|
anfnode_build_map_[node_name] = cnode_ptr; |
|
|
|
return cnode_ptr; |
|
|
|
@@ -652,7 +745,7 @@ int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) { |
|
|
|
|
|
|
|
onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) { |
|
|
|
auto onnx_model = new onnx::ModelProto; |
|
|
|
if (ReadProtoFromBinaryFile((const char *)model_path.c_str(), onnx_model) != RET_OK) { |
|
|
|
if (ReadProtoFromBinaryFile((const char *) model_path.c_str(), onnx_model) != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_path; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|