/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/optimizer/clean.h"
#include <algorithm>
#include <iterator>
#include <string>
#include <vector>
#include "debug/trace.h"
#include "frontend/operator/composite/composite.h"
#include "pipeline/jit/parse/resolve.h"

namespace mindspore {
/* namespace to support opt */
namespace opt {
using mindspore::abstract::AbstractAttribute;
using mindspore::abstract::AbstractClass;
using mindspore::abstract::AbstractDictionary;
using mindspore::abstract::AbstractJTagged;
using mindspore::abstract::AbstractList;
using mindspore::abstract::AbstractRowTensor;
using mindspore::abstract::AbstractScalar;
using mindspore::abstract::AbstractSparseTensor;
using mindspore::abstract::AbstractTuple;
using mindspore::abstract::AbstractUndetermined;

// Raises an exception naming `op_name` when a node does not have exactly `expect_size` inputs
// (the count includes input 0, the operator itself).
inline void CheckInputsSize(size_t actual_size, size_t expect_size, const std::string &op_name) {
  if (actual_size != expect_size) {
    MS_LOG(EXCEPTION) << op_name << " should have " << expect_size << " inputs, but got " << actual_size;
  }
}

// Rewrites the abstract of a class instance or of a dictionary into an AbstractTuple that holds only the
// attribute/element abstracts, in their original order (the names/keys are dropped).
// Returns nullptr when `t` is null or is neither an AbstractClass nor an AbstractDictionary.
static AbstractBasePtr Reabs(const AbstractBasePtr &t) {
  if (t == nullptr) {
    return nullptr;
  }

  if (t->isa<AbstractClass>()) {
    auto abs_class = dyn_cast<AbstractClass>(t);
    AbstractBasePtrList baselist;
    const auto &attributes = abs_class->attributes();
    (void)std::transform(attributes.begin(), attributes.end(), std::back_inserter(baselist),
                         [](const AbstractAttribute &item) { return item.second; });
    return std::make_shared<AbstractTuple>(baselist);
  }

  if (t->isa<AbstractDictionary>()) {
    auto abs_dict = dyn_cast<AbstractDictionary>(t);
    AbstractBasePtrList baselist;
    const auto &elements = abs_dict->elements();
    (void)std::transform(elements.begin(), elements.end(), std::back_inserter(baselist),
                         [](const AbstractAttribute &item) { return item.second; });
    return std::make_shared<AbstractTuple>(baselist);
  }

  return nullptr;
}

// Rewrites list / sparse-tensor / row-tensor abstracts into AbstractTuple.
// Sparse and row tensors become the 3-tuple (indices, values, dense_shape).
// Returns nullptr when `t` is null or is none of those kinds.
static AbstractBasePtr AdaptAbs(const AbstractBasePtr &t) {
  if (t == nullptr) {
    return nullptr;
  }

  if (t->isa<AbstractList>()) {
    auto abs_list = dyn_cast<AbstractList>(t);
    return std::make_shared<AbstractTuple>(abs_list->elements());
  }

  if (t->isa<AbstractSparseTensor>()) {
    auto abs_sparse = dyn_cast<AbstractSparseTensor>(t);
    std::vector<AbstractBasePtr> abstract_list{abs_sparse->indices(), abs_sparse->values(), abs_sparse->dense_shape()};
    return std::make_shared<AbstractTuple>(abstract_list);
  }

  if (t->isa<AbstractRowTensor>()) {
    auto abs_row_tensor = dyn_cast<AbstractRowTensor>(t);
    std::vector<AbstractBasePtr> abstract_list{abs_row_tensor->indices(), abs_row_tensor->values(),
                                               abs_row_tensor->dense_shape()};
    return std::make_shared<AbstractTuple>(abstract_list);
  }

  return nullptr;
}

// Returns the position of the entry whose name equals the string constant held by `key_node`, searching a
// list of (name, abstract) pairs. When `key_node` is not a string constant, or no entry matches, the result
// is attrs.size(), i.e. one past the last valid position; callers decide what that means.
template <typename AttrList>
static int64_t FindAttributeIndex(const AttrList &attrs, const AnfNodePtr &key_node) {
  const bool key_is_str = IsValueNode<StringImm>(key_node);
  const std::string key = key_is_str ? GetValue<std::string>(GetValueNode(key_node)) : "";
  int64_t index = 0;
  for (const auto &item : attrs) {
    if (key_is_str && item.first == key) {
      break;
    }
    ++index;
  }
  return index;
}

// Creates an int64 constant node carrying a matching scalar abstract, used as a tuple_getitem /
// tuple_setitem index so every index constant produced by this file has the same value type.
static ValueNodePtr NewInt64IndexNode(int64_t index) {
  auto index_node = NewValueNode(index);
  index_node->set_abstract(std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(index)));
  return index_node;
}

// Lowers `getattr(obj, "name")` on a class instance to `tuple_getitem(obj, index)`, where index is the
// position of "name" in the class's attribute list.
// Returns nullptr (no rewrite) when the object has no abstract yet or its type is still undetermined.
// Raises when the object's abstract is not an AbstractClass.
// NOTE(review): a non-string or unknown attribute yields index == attribute count (out of range), exactly as
// before; presumably type inference has already rejected such programs - confirm.
AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(node->func_graph());

  const auto &inputs = node->inputs();
  // Inputs should be [getattr, data, attribute]
  const size_t expect_inputs_size = 3;
  CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));

  constexpr size_t data_index = 1;
  constexpr size_t attribute_index = 2;
  AnfNodePtr data = inputs[data_index];
  AnfNodePtr cons = inputs[attribute_index];
  MS_EXCEPTION_IF_NULL(data);
  MS_EXCEPTION_IF_NULL(cons);

  auto dt = data->abstract();
  if (dt == nullptr || dt->BuildType()->type_id() == kObjectTypeUndeterminedType) {
    return nullptr;
  }

  if (!dt->isa<AbstractClass>()) {
    MS_LOG(EXCEPTION) << "First parameter of getattr is not AbstractClass, but " << dt->type_name() << ".";
  }

  auto ct = dyn_cast<AbstractClass>(dt);
  const int64_t index = FindAttributeIndex(ct->attributes(), cons);
  return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, NewInt64IndexNode(index)});
}

// Lowers `dict_getitem(dict, "key")` to `tuple_getitem(dict, index)`, where index is the position of "key"
// in the dictionary's element list. Raises when the dict input has no abstract or is not an AbstractDictionary.
AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(node->func_graph());

  // Inputs should be [dict_getitem, dict, item]
  const auto &inputs = node->inputs();
  const size_t expect_inputs_size = 3;
  CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));

  constexpr size_t data_index = 1;
  constexpr size_t cons_index = 2;
  AnfNodePtr data = inputs[data_index];
  AnfNodePtr cons = inputs[cons_index];
  MS_EXCEPTION_IF_NULL(data);
  MS_EXCEPTION_IF_NULL(cons);

  auto dt = data->abstract();
  MS_EXCEPTION_IF_NULL(dt);
  if (!dt->isa<AbstractDictionary>()) {
    MS_LOG(EXCEPTION) << "first parameter of dict_getitem is not AbstractDictionary, but " << dt->type_name();
  }

  auto ct = dyn_cast<AbstractDictionary>(dt);
  const int64_t index = FindAttributeIndex(ct->elements(), cons);
  return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, NewInt64IndexNode(index)});
}

// Lowers `dict_setitem(dict, "key", value)`:
//   - existing key  -> `tuple_setitem(dict, index, value)`;
//   - missing key   -> `tuple_add(dict, make_tuple(value))`, appending a new element.
// Raises when the dict input has no abstract or is not an AbstractDictionary.
AnfNodePtr ConvertDictSetItemToTupleSetItem(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(node->func_graph());

  // Inputs should be [dict_setitem, dict, item, value]
  const auto &inputs = node->inputs();
  const size_t expect_inputs_size = 4;
  CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));

  const size_t data_index = 1;
  const size_t cons_index = 2;
  const size_t item_value_index = 3;
  AnfNodePtr data = inputs[data_index];
  AnfNodePtr cons = inputs[cons_index];
  AnfNodePtr item_value = inputs[item_value_index];
  MS_EXCEPTION_IF_NULL(data);
  MS_EXCEPTION_IF_NULL(cons);
  MS_EXCEPTION_IF_NULL(item_value);

  auto dt = data->abstract();
  MS_EXCEPTION_IF_NULL(dt);
  if (!dt->isa<AbstractDictionary>()) {
    MS_LOG(EXCEPTION) << "first parameter of dict_setitem is not AbstractDictionary, but " << dt->type_name();
  }

  auto ct = dyn_cast<AbstractDictionary>(dt);
  const auto &cmap = ct->elements();
  const int64_t index = FindAttributeIndex(cmap, cons);
  if (LongToSize(index) >= cmap.size()) {
    // for dictionary set, if the key does not exist, we should create a new item
    auto tuple_add_op = std::make_shared<prim::TupleAdd>("tuple_add");
    auto tuple_new_item = node->func_graph()->NewCNode({NewValueNode(prim::kPrimMakeTuple), item_value});
    return node->func_graph()->NewCNode({NewValueNode(tuple_add_op), data, tuple_new_item});
  }

  return node->func_graph()->NewCNode(
    {NewValueNode(prim::kPrimTupleSetItem), data, NewInt64IndexNode(index), item_value});
}

// Lowers `make_record(klass, attr1, attr2, ...)` to `make_tuple(attr1, attr2, ...)` (the class input is dropped).
AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(node->func_graph());

  const auto &node_inputs = node->inputs();
  // Inputs of node should be [make_record, klass, attr1, attr2, ...], so offset by 2 to get attr;
  // guard the offset so a malformed node raises instead of advancing past end().
  const size_t min_inputs_size = 2;
  if (node_inputs.size() < min_inputs_size) {
    MS_LOG(EXCEPTION) << "make_record should have at least 2 inputs, but got " << node_inputs.size();
  }

  std::vector<AnfNodePtr> inputs;
  inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  (void)inputs.insert(inputs.end(), node_inputs.begin() + 2, node_inputs.end());
  return node->func_graph()->NewCNode(inputs);
}

// Rewrites `partial(make_record, klass, args...)`:
//   - no extra args -> the bare make_tuple primitive;
//   - extra args    -> `partial(make_tuple, args...)` (klass is dropped).
// Returns nullptr for any other partial. Raises when the partial has fewer than 2 inputs.
AnfNodePtr ErasePartialNode(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(node->func_graph());

  const auto &inputs = node->inputs();
  // Inputs should be [partial, fn, arg1, ...], so offset by 2 to get arg;
  const size_t min_inputs_size = 2;
  if (inputs.size() < min_inputs_size) {
    MS_LOG(EXCEPTION) << "Partial should have at least 2 inputs, but got " << inputs.size();
  }

  std::vector<AnfNodePtr> args(inputs.begin() + 2, inputs.end());
  auto oper = inputs[1];
  if (IsPrimitive(oper, prim::kPrimMakeRecord)) {
    // args[0] is the class object; the remaining args are the record fields.
    if (args.size() == 1) {
      return NewValueNode(prim::kPrimMakeTuple);
    }

    if (args.size() > 1) {
      std::vector<AnfNodePtr> new_inputs;
      new_inputs.emplace_back(NewValueNode(prim::kPrimPartial));
      new_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
      (void)new_inputs.insert(new_inputs.end(), args.begin() + 1, args.end());
      return node->func_graph()->NewCNode(new_inputs);
    }
  }
  return nullptr;
}

// Lowers `make_list(item1, item2, ...)` to `make_tuple(item1, item2, ...)`.
AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(node->func_graph());

  std::vector<AnfNodePtr> inputs;
  inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  // Inputs of node should be [make_list, item1, item2, ...], so offset by 1 to get items;
  (void)inputs.insert(inputs.end(), node->inputs().begin() + 1, node->inputs().end());
  return node->func_graph()->NewCNode(inputs);
}

// Lowers `list_getitem(list, index)` to `tuple_getitem(list, index)`.
AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(node->func_graph());

  const auto &inputs = node->inputs();
  // Inputs should be [list_getitem, list, item]
  constexpr size_t expect_input_size = 3;
  CheckInputsSize(inputs.size(), expect_input_size, GetCNodeFuncName(node));
  constexpr size_t real_input_index = 1;
  constexpr size_t index_input_index = 2;
  AnfNodePtr data = inputs[real_input_index];
  AnfNodePtr cons = inputs[index_input_index];
  MS_EXCEPTION_IF_NULL(data);
  MS_EXCEPTION_IF_NULL(cons);

  // Forward the index node as-is: down-casting it to ValueNodePtr would insert a null input whenever
  // the index is not a constant.
  return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, cons});
}

// Lowers `list_setitem(list, index, value)` to `tuple_setitem(list, index, value)`.
AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(node->func_graph());

  const auto &inputs = node->inputs();
  // Inputs should be [list_setitem, list, index, item]
  const size_t expect_inputs_size = 4;
  CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));

  const size_t data_index = 1;
  const size_t cons_index = 2;
  const size_t value_index = 3;
  AnfNodePtr data = inputs[data_index];
  AnfNodePtr cons = inputs[cons_index];
  AnfNodePtr value = inputs[value_index];
  MS_EXCEPTION_IF_NULL(data);
  MS_EXCEPTION_IF_NULL(cons);
  MS_EXCEPTION_IF_NULL(value);

  return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, cons, value});
}

// Replaces `make_dict(keys, values)` by its values input (inputs[2]).
AnfNodePtr EraseMakeDictNode(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  const auto &inputs = node->inputs();
  const size_t expect_inputs_size = 3;
  CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));
  constexpr size_t values_index = 2;
  return inputs[values_index];
}

// Replaces `dict_getvalues(dict)` by its dict input.
AnfNodePtr EraseDictGetValues(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  const auto &inputs = node->inputs();
  const size_t expect_inputs_size = 2;
  CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));
  return inputs[1];
}

// Lowers `dict_items(dict)` to `make_list(make_tuple(key_i, tuple_getitem(dict, i)), ...)`, one pair per key.
// Raises when the key tuple cannot be found as a constant ValueTuple.
AnfNodePtr EraseDictItems(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  const auto &inputs = node->inputs();
  const size_t expect_inputs_size = 2;
  CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));

  // NOTE(review): the keys are read from inputs[0] as a constant ValueTuple, but for an ordinary CNode
  // inputs[0] is the operator itself - confirm which node actually carries the keys here.
  auto keys_node = inputs[0]->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(keys_node);
  MS_EXCEPTION_IF_NULL(keys_node->value());
  auto keys_tuple = keys_node->value()->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(keys_tuple);
  const ValuePtrList &keys = keys_tuple->value();

  std::vector<AnfNodePtr> outer_node{NewValueNode(prim::kPrimMakeList)};
  for (size_t i = 0; i < keys.size(); ++i) {
    std::vector<AnfNodePtr> inner_node;
    inner_node.push_back(NewValueNode(prim::kPrimMakeTuple));
    inner_node.push_back(NewValueNode(keys[i]));
    // Use an int64 index like every other tuple_getitem built in this file (a raw size_t would become UInt64).
    inner_node.push_back(NewCNode(std::vector<AnfNodePtr>{NewValueNode(prim::kPrimTupleGetItem), inputs[1],
                                                          NewInt64IndexNode(static_cast<int64_t>(i))},
                                  node->func_graph()));
    outer_node.push_back(NewCNode(inner_node, node->func_graph()));
  }
  return NewCNode(outer_node, node->func_graph());
}

// Replaces `make_keyword_arg(key, value)` by its value input.
AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  const auto &inputs = node->inputs();
  // Inputs should be [make_keyword_arg, key, value]
  constexpr size_t expect_input_size = 3;
  constexpr size_t value_inputs_index = 2;
  CheckInputsSize(inputs.size(), expect_input_size, GetCNodeFuncName(node));
  return inputs[value_inputs_index];
}

// Replaces `extract_keyword_arg(key, keyword_arg)` by its keyword-arg operand (inputs[2]); the key string
// (inputs[1]) is discarded.
AnfNodePtr EraseExtractKeywordArg(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  const auto &inputs = node->inputs();
  // Inputs should be [extract_keyword_arg, key, keyword_arg]
  const size_t expect_inputs_size = 3;
  CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));
  constexpr size_t keyword_arg_index = 2;
  return inputs[keyword_arg_index];
}

// Recursively converts a ValueList (and any nested ValueLists) into a ValueTuple.
// `depth` is the current nesting level (0 for the outermost list); levels 0..5 are accepted and deeper
// nesting raises.
ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, int64_t depth) {
  MS_EXCEPTION_IF_NULL(value_list);
  const int64_t DEPTH_MAX = 5;
  if (depth > DEPTH_MAX) {
    MS_LOG(EXCEPTION) << "List nesting is not allowed more than 6 levels.";
  }
  std::vector<ValuePtr> elements;
  for (const auto &it : value_list->value()) {
    MS_EXCEPTION_IF_NULL(it);
    ValuePtr value = nullptr;
    if (it->isa<ValueList>()) {
      value = ConvertValueListToValueTuple(it->cast<ValueListPtr>(), depth + 1);
    } else {
      value = it;
    }
    elements.push_back(value);
  }
  return std::make_shared<ValueTuple>(elements);
}

// Builds a new ValueNode whose value is the ValueTuple equivalent of `node`'s ValueList.
// Raises when `node` does not hold a ValueList.
AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  ValuePtr value = node->value();
  MS_EXCEPTION_IF_NULL(value);
  auto value_list = value->cast<ValueListPtr>();
  MS_EXCEPTION_IF_NULL(value_list);
  int64_t depth = 0;
  return std::make_shared<ValueNode>(ConvertValueListToValueTuple(value_list, depth));
}

// Convert class to Tuple
// Convert getattr to getitem
// Convert make_record to make_tuple
// Also erases dict / keyword-arg helper nodes, then rewrites class and dictionary abstracts to tuples.
// Returns true if any node (or a non-empty abstract) was changed.
bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
  MS_EXCEPTION_IF_NULL(manager);
  manager->AddFuncGraph(root);

  bool changed = false;

  // Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
  AnfNodeSet all_node = manager->all_nodes();
  for (auto &node : all_node) {
    MS_EXCEPTION_IF_NULL(node);
    auto cnode = node->cast<CNodePtr>();
    AnfNodePtr new_node = nullptr;
    if (IsValueNode<parse::ClassObject>(node)) {
      new_node = NewValueNode(prim::kPrimMakeTuple);
    } else if (IsPrimitiveCNode(node, prim::kPrimGetAttr)) {
      new_node = ConvertGetAttrToTupleGetItem(cnode);
    } else if (IsPrimitiveCNode(node, prim::kPrimMakeRecord)) {
      new_node = ConvertMakeRecordToMakeTuple(cnode);
    } else if (IsPrimitiveCNode(node, prim::kPrimPartial)) {
      new_node = ErasePartialNode(cnode);
    } else if (IsPrimitiveCNode(node, prim::kPrimDictGetItem)) {
      new_node = ConvertDictGetItemToTupleGetItem(cnode);
    } else if (IsPrimitiveCNode(node, prim::kPrimDictSetItem)) {
      new_node = ConvertDictSetItemToTupleSetItem(cnode);
    } else if (IsPrimitiveCNode(node, prim::kPrimDictGetValues)) {
      new_node = EraseDictGetValues(cnode);
    } else if (IsPrimitiveCNode(node, prim::kPrimMakeDict)) {
      new_node = EraseMakeDictNode(cnode);
    } else if (IsPrimitiveCNode(node, prim::kPrimMakeKeywordArg)) {
      new_node = EraseMakeKeywordArgNode(cnode);
    } else if (IsPrimitiveCNode(node, prim::kPrimExtractKeywordArg)) {
      new_node = EraseExtractKeywordArg(cnode);
    } else if (IsPrimitiveCNode(node, prim::kPrimDictItems)) {
      new_node = EraseDictItems(cnode);
    }

    if (new_node != nullptr) {
      new_node->set_abstract(node->abstract());
      MS_LOG(DEBUG) << "Replace node: " << node->DebugString() << " with new_node: " << new_node->DebugString();
      (void)manager->Replace(node, new_node);
      changed = true;
    }
  }

  for (auto &node : manager->all_nodes()) {
    MS_EXCEPTION_IF_NULL(node);
    auto ret = Reabs(node->abstract());
    if (ret) {
      MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with "
                    << ret->ToString();
      node->set_abstract(ret);
      // Reabs always yields an AbstractTuple; an empty one is not counted as a change.
      if (ret->cast<AbstractTuplePtr>()->size() > 0) {
        changed = true;
      }
    }
  }
  return changed;
}

// Lowers `make_sparse_tensor` / `make_row_tensor(indices, values, dense_shape)` to
// `make_tuple(indices, values, dense_shape)`.
AnfNodePtr ConvertMakeSparseToMakeTuple(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(node->func_graph());

  std::vector<AnfNodePtr> inputs;
  inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  // Inputs of node should be [make_sparse, indices, values, dense_shape], so offset by 1 to get items;
  (void)inputs.insert(inputs.end(), node->inputs().begin() + 1, node->inputs().end());
  return node->func_graph()->NewCNode(inputs);
}

// Lowers a sparse/row-tensor attribute getter `getter(sparse)` to `tuple_getitem(sparse, index)`.
AnfNodePtr ConvertSparseGetAttrToTupleGetItem(const CNodePtr &node, const int64_t &index) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(node->func_graph());

  const auto &inputs = node->inputs();
  // Inputs should be [sparse_getattr, sparse]
  constexpr size_t expect_input_index = 2;
  CheckInputsSize(inputs.size(), expect_input_index, GetCNodeFuncName(node));
  constexpr size_t sparse_index = 1;
  AnfNodePtr sparse = inputs[sparse_index];
  MS_EXCEPTION_IF_NULL(sparse);

  return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), sparse, NewInt64IndexNode(index)});
}

// Lowers lists and sparse/row tensors to tuples after the first optimization stage, then rewrites the matching
// abstracts to AbstractTuple. Returns true if any node or abstract was changed.
bool CleanAfterOptA(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
  MS_EXCEPTION_IF_NULL(manager);
  manager->AddFuncGraph(root);

  bool changed = false;

  // Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
  auto all_node = manager->all_nodes();
  for (auto &node : all_node) {
    MS_EXCEPTION_IF_NULL(node);
    auto cnode = node->cast<CNodePtr>();
    AnfNodePtr new_node = nullptr;
    if (IsPrimitiveCNode(node, prim::kPrimMakeList)) {
      new_node = ConvertMakeListToMakeTuple(cnode);
    } else if (IsPrimitiveCNode(node, prim::kPrimListGetItem)) {
      new_node = ConvertListGetItemToTupleGetItem(cnode);
    } else if (IsPrimitiveCNode(node, prim::kPrimListSetItem)) {
      new_node = ConvertListSetItemToTupleSetItem(cnode);
    } else if (IsValueNode<ValueList>(node)) {
      new_node = ConvertValueListNodeToValueTupleNode(node->cast<ValueNodePtr>());
    } else if (IsPrimitiveCNode(node, prim::kPrimMakeSparseTensor) ||
               IsPrimitiveCNode(node, prim::kPrimMakeRowTensor)) {
      new_node = ConvertMakeSparseToMakeTuple(cnode);
    } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetIndices) ||
               IsPrimitiveCNode(node, prim::kPrimRowTensorGetIndices)) {
      constexpr int64_t indices_index = 0;
      new_node = ConvertSparseGetAttrToTupleGetItem(cnode, indices_index);
    } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetValues) ||
               IsPrimitiveCNode(node, prim::kPrimRowTensorGetValues)) {
      constexpr int64_t value_index = 1;
      new_node = ConvertSparseGetAttrToTupleGetItem(cnode, value_index);
    } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetDenseShape) ||
               IsPrimitiveCNode(node, prim::kPrimRowTensorGetDenseShape)) {
      constexpr int64_t shape_index = 2;
      new_node = ConvertSparseGetAttrToTupleGetItem(cnode, shape_index);
    }

    if (new_node != nullptr) {
      new_node->set_abstract(node->abstract());
      MS_LOG(DEBUG) << "Replace node: " << node->DebugString() << " with new_node: " << new_node->DebugString();
      (void)manager->Replace(node, new_node);
      changed = true;
    }
  }

  for (auto &node : manager->all_nodes()) {
    MS_EXCEPTION_IF_NULL(node);
    auto ret = AdaptAbs(node->abstract());
    if (ret) {
      MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with "
                    << ret->ToString();
      node->set_abstract(ret);
      changed = true;
    }
  }
  return changed;
}
}  // namespace opt
}  // namespace mindspore