You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_transform.cc 2.1 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. /**
  2. * Copyright 2020-2022 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "frontend/optimizer/graph_transform.h"
  17. #include <vector>
  18. #include <algorithm>
  19. #include "ir/graph_utils.h"
  20. namespace mindspore {
  21. /* namespace to support opt */
  22. namespace opt {
  23. bool FuncGraphHasTupleInput(const FuncGraphPtr &fg) {
  24. auto is_tuple = [](const AnfNodePtr &param) {
  25. return param->abstract() != nullptr && param->abstract()->isa<abstract::AbstractTuple>();
  26. };
  27. return std::any_of(fg->parameters().cbegin(), fg->parameters().cend(), is_tuple);
  28. }
  29. std::vector<AnfNodePtr> TransformTupleArgument(const FuncGraphPtr &fg, const AnfNodePtr &node,
  30. const abstract::AbstractTuplePtr &abs) {
  31. auto &elements = abs->elements();
  32. std::vector<AnfNodePtr> tuple_node_expanded;
  33. for (size_t i = 0; i < elements.size(); i++) {
  34. auto idx = NewValueNode(SizeToLong(i));
  35. auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(SizeToLong(i)));
  36. idx->set_abstract(abstract_scalar);
  37. auto elem_node = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
  38. elem_node->set_abstract(elements[i]);
  39. if (elements[i]->isa<abstract::AbstractTuple>()) {
  40. auto nodes = TransformTupleArgument(fg, elem_node, elements[i]->cast<abstract::AbstractTuplePtr>());
  41. tuple_node_expanded.insert(tuple_node_expanded.end(), nodes.begin(), nodes.end());
  42. } else {
  43. tuple_node_expanded.push_back(elem_node);
  44. }
  45. }
  46. return tuple_node_expanded;
  47. }
  48. } // namespace opt
  49. } // namespace mindspore