|
|
|
@@ -111,8 +111,8 @@ class OpInfoExtractor { |
|
|
|
op_attr->set_type("bool"); |
|
|
|
} else if (v->isa<StringImm>()) { |
|
|
|
op_attr->set_type("str"); |
|
|
|
} else if (v->isa<ValueList>() || v->isa<ValueTuple>()) { |
|
|
|
auto vec = v->isa<ValueList>() ? v->cast<ValueListPtr>()->value() : v->cast<ValueTuplePtr>()->value(); |
|
|
|
} else if (v->isa<ValueSequeue>()) { |
|
|
|
const auto &vec = v->cast<ValueSequeuePtr>()->value(); |
|
|
|
if (vec.empty()) { |
|
|
|
op_attr->set_type("listInt"); |
|
|
|
} else if (vec[0]->isa<Int32Imm>() || vec[0]->isa<Int64Imm>()) { |
|
|
|
@@ -262,10 +262,14 @@ void AkgKernelJsonGenerator::GetAttrJson(const AnfNodePtr &anf_node, const std:: |
|
|
|
MS_EXCEPTION_IF_NULL(anf_node); |
|
|
|
MS_EXCEPTION_IF_NULL(op_attr); |
|
|
|
MS_EXCEPTION_IF_NULL(attr_json); |
|
|
|
|
|
|
|
auto get_int_value = [](const ValuePtr &value) -> int { |
|
|
|
return value->isa<Int64Imm>() ? static_cast<int>(GetValue<int64_t>(value)) : GetValue<int>(value); |
|
|
|
}; |
|
|
|
std::string type = op_attr->type(); |
|
|
|
(*attr_json)[kJsonKeyDataType] = type; |
|
|
|
if (type == "int") { |
|
|
|
(*attr_json)[kJsonKeyValue] = static_cast<int>(GetValue<int64_t>(attr_value)); |
|
|
|
(*attr_json)[kJsonKeyValue] = get_int_value(attr_value); |
|
|
|
} else if (type == "str") { |
|
|
|
(*attr_json)[kJsonKeyValue] = GetValue<std::string>(attr_value); |
|
|
|
} else if (type == "bool") { |
|
|
|
@@ -274,9 +278,8 @@ void AkgKernelJsonGenerator::GetAttrJson(const AnfNodePtr &anf_node, const std:: |
|
|
|
(*attr_json)[kJsonKeyValue] = GetValue<float>(attr_value); |
|
|
|
} else if (type == "listInt") { |
|
|
|
std::vector<int> list_int; |
|
|
|
std::vector<int64_t> list_int_me = GetValue<std::vector<int64_t>>(attr_value); |
|
|
|
(void)std::transform(list_int_me.begin(), list_int_me.end(), std::back_inserter(list_int), |
|
|
|
[](const int64_t &value) { return static_cast<int>(value); }); |
|
|
|
const auto &vals = attr_value->cast<ValueSequeuePtr>()->value(); |
|
|
|
(void)std::transform(vals.begin(), vals.end(), std::back_inserter(list_int), get_int_value); |
|
|
|
(*attr_json)[kJsonKeyValue] = list_int; |
|
|
|
} else if (type == "listStr") { |
|
|
|
std::vector<std::string> data_format; |
|
|
|
|