You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

zip_operation.cc 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "operator/composite/zip_operation.h"
  19. #include <algorithm>
  20. #include <utility>
  21. #include "pipeline/static_analysis/abstract_value.h"
  22. #include "ir/anf.h"
  23. #include "pipeline/static_analysis/dshape.h"
  24. #include "pipeline/static_analysis/param_validator.h"
  25. #include "operator/cc_implementations.h"
  26. #include "optimizer/opt.h"
  27. #include "utils/symbolic.h"
  28. #include "./common.h"
  29. #include "pybind_api/api_register.h"
  30. namespace mindspore {
  31. // namespace to support composite operators definition
  32. namespace prim {
  33. using mindspore::abstract::AbstractBase;
  34. using mindspore::abstract::AbstractTuple;
  35. FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  36. // zip operation:
  37. // input: tuple arguments
  38. // output: tuple of items of input iterated on every input
  39. if (args_spec_list.size() == 0) {
  40. MS_LOG(EXCEPTION) << "zip arguments input should not be empty";
  41. }
  42. auto is_all_tuple = std::all_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr &abs) -> bool {
  43. MS_EXCEPTION_IF_NULL(abs);
  44. return abs->isa<AbstractTuple>();
  45. });
  46. if (!is_all_tuple) {
  47. MS_LOG(EXCEPTION) << "zip input args should be tuple";
  48. }
  49. auto min_abs = std::min_element(args_spec_list.begin(), args_spec_list.end(),
  50. [](const AbstractBasePtr &x, const AbstractBasePtr &y) {
  51. return (x->cast<AbstractTuplePtr>()->size() < y->cast<AbstractTuplePtr>()->size());
  52. });
  53. FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
  54. ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  55. for (size_t idx = 0; idx < args_spec_list.size(); idx++) {
  56. (void)ret_graph->add_parameter();
  57. }
  58. // generate tuple output of ziped arguments input
  59. std::vector<AnfNodePtr> make_tuple_nodes;
  60. make_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple));
  61. for (size_t idx = 0; idx < (*min_abs)->cast<AbstractTuplePtr>()->size(); idx++) {
  62. std::vector<AnfNodePtr> make_tuple_zip_nodes;
  63. make_tuple_zip_nodes.push_back(NewValueNode(prim::kPrimMakeTuple));
  64. for (size_t arg_idx = 0; arg_idx < args_spec_list.size(); arg_idx++) {
  65. std::vector<AnfNodePtr> tuple_get_item_nodes{NewValueNode(prim::kPrimTupleGetItem),
  66. ret_graph->parameters()[arg_idx], NewValueNode(SizeToInt(idx))};
  67. auto tuple_get_item_op = ret_graph->NewCNode(tuple_get_item_nodes);
  68. make_tuple_zip_nodes.push_back(tuple_get_item_op);
  69. }
  70. auto make_tuple_zip_op = ret_graph->NewCNode(make_tuple_zip_nodes);
  71. make_tuple_nodes.push_back(make_tuple_zip_op);
  72. }
  73. ret_graph->set_output(ret_graph->NewCNode(make_tuple_nodes));
  74. return ret_graph;
  75. }
  76. REGISTER_PYBIND_DEFINE(ZipOperation_, ([](const py::module *m) {
  77. (void)py::class_<ZipOperation, MetaFuncGraph, std::shared_ptr<ZipOperation>>(*m,
  78. "ZipOperation_")
  79. .def(py::init<std::string &>());
  80. }));
  81. } // namespace prim
  82. } // namespace mindspore