|
|
|
@@ -139,6 +139,47 @@ AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) { |
|
|
|
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c}); |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr ConvertDictSetItemToTupleSetItem(const CNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
MS_EXCEPTION_IF_NULL(node->func_graph()); |
|
|
|
|
|
|
|
// Inputs should be [dict_setitem, dict, item, value] |
|
|
|
const auto &inputs = node->inputs(); |
|
|
|
MS_ASSERT(inputs.size() == 4 && "DictSetItem should have three inputs."); |
|
|
|
|
|
|
|
AnfNodePtr data = inputs[1]; |
|
|
|
AnfNodePtr cons = inputs[2]; |
|
|
|
AnfNodePtr item_value = inputs[3]; |
|
|
|
MS_EXCEPTION_IF_NULL(data); |
|
|
|
MS_EXCEPTION_IF_NULL(cons); |
|
|
|
|
|
|
|
auto dt = data->abstract(); |
|
|
|
MS_EXCEPTION_IF_NULL(dt); |
|
|
|
if (!dt->isa<abstract::AbstractDictionary>()) { |
|
|
|
MS_LOG(EXCEPTION) << "first parameter of dict_setitem is not AbstractDictionary, but " << dt->type_name(); |
|
|
|
} |
|
|
|
auto cons_is_str = IsValueNode<StringImm>(cons); |
|
|
|
auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : ""; |
|
|
|
|
|
|
|
auto ct = dyn_cast<abstract::AbstractDictionary>(dt); |
|
|
|
const auto &cmap = ct->elements(); |
|
|
|
int count = 0; |
|
|
|
for (auto &item : cmap) { |
|
|
|
if (cons_is_str && item.first == cons_str) { |
|
|
|
break; |
|
|
|
} |
|
|
|
count++; |
|
|
|
} |
|
|
|
if (IntToSize(count) >= cmap.size()) { |
|
|
|
MS_LOG(EXCEPTION) << "dictionary assignment key " << cons_str |
|
|
|
<< " does not exist, can not create new dictionary item for now."; |
|
|
|
} |
|
|
|
auto idx_c = NewValueNode(count); |
|
|
|
AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(count)); |
|
|
|
idx_c->set_abstract(aptr); |
|
|
|
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, idx_c, item_value}); |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
MS_EXCEPTION_IF_NULL(node->func_graph()); |
|
|
|
@@ -300,6 +341,8 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr |
|
|
|
new_node = ErasePartialNode(cnode); |
|
|
|
} else if (IsPrimitiveCNode(node, prim::kPrimDictGetItem)) { |
|
|
|
new_node = ConvertDictGetItemToTupleGetItem(cnode); |
|
|
|
} else if (IsPrimitiveCNode(node, prim::kPrimDictSetItem)) { |
|
|
|
new_node = ConvertDictSetItemToTupleSetItem(cnode); |
|
|
|
} else if (IsPrimitiveCNode(node, prim::kPrimMakeDict)) { |
|
|
|
new_node = EraseMakeDictNode(cnode); |
|
|
|
} else if (IsPrimitiveCNode(node, prim::kPrimMakeKeywordArg)) { |
|
|
|
|