You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

unpack_call.cc 4.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "operator/composite/unpack_call.h"
  17. #include <algorithm>
  18. #include <utility>
  19. #include "./common.h"
  20. #include "pipeline/static_analysis/abstract_value.h"
  21. #include "pipeline/static_analysis/dshape.h"
  22. #include "pipeline/static_analysis/param_validator.h"
  23. #include "operator/cc_implementations.h"
  24. #include "ir/anf.h"
  25. #include "optimizer/opt.h"
  26. #include "utils/symbolic.h"
  27. #include "pybind_api/api_register.h"
  28. namespace mindspore {
  29. // namespace to support composite operators definition
  30. namespace prim {
  31. using mindspore::abstract::AbstractAttribute;
  32. using mindspore::abstract::AbstractBase;
  33. using mindspore::abstract::AbstractDictionary;
  34. using mindspore::abstract::AbstractDictionaryPtr;
  35. using mindspore::abstract::AbstractFunction;
  36. using mindspore::abstract::AbstractKeywordArg;
  37. using mindspore::abstract::AbstractTuple;
  38. using mindspore::abstract::AbstractTuplePtr;
  39. FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  40. // slice a tensor
  41. // args: tensor, slice or slice tuple
  42. const std::string op_name = std::string("UnpackCall");
  43. size_t arg_length = args_spec_list.size();
  44. if (arg_length < 2) {
  45. MS_LOG(EXCEPTION) << op_name << " requires at least two args, but got " << arg_length << ".";
  46. }
  47. (void)abstract::CheckArg<AbstractFunction>(op_name, args_spec_list, 0);
  48. auto ret_graph = std::make_shared<FuncGraph>();
  49. ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  50. AnfNodePtr fnNode = ret_graph->add_parameter();
  51. std::vector<AnfNodePtr> elems;
  52. elems.push_back(fnNode);
  53. for (size_t index = 1; index < arg_length; index++) {
  54. MS_EXCEPTION_IF_NULL(args_spec_list[index]);
  55. if (args_spec_list[index]->isa<AbstractTuple>()) {
  56. auto arg_tuple = args_spec_list[index]->cast<AbstractTuplePtr>();
  57. AnfNodePtr para_tuple = ret_graph->add_parameter();
  58. for (size_t i = 0; i < arg_tuple->size(); ++i) {
  59. elems.push_back(
  60. ret_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), para_tuple, NewValueNode(SizeToInt(i))}));
  61. }
  62. } else if (args_spec_list[index]->isa<AbstractDictionary>()) {
  63. AbstractDictionaryPtr arg_dict = args_spec_list[index]->cast<AbstractDictionaryPtr>();
  64. AnfNodePtr para_dict = ret_graph->add_parameter();
  65. auto dict_elems = arg_dict->elements();
  66. (void)std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(elems),
  67. [ret_graph, para_dict](const AbstractAttribute &item) {
  68. auto dict_get_item = ret_graph->NewCNode(
  69. {NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(item.first)});
  70. return ret_graph->NewCNode(
  71. {NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(item.first), dict_get_item});
  72. });
  73. } else {
  74. MS_LOG(EXCEPTION) << op_name << " require args should be tuple or dict, but got "
  75. << args_spec_list[index]->ToString();
  76. }
  77. }
  78. ret_graph->set_output(ret_graph->NewCNode(elems));
  79. return ret_graph;
  80. }
  81. REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module *m) {
  82. (void)py::class_<UnpackCall, MetaFuncGraph, std::shared_ptr<UnpackCall>>(*m, "UnpackCall_")
  83. .def(py::init<std::string &>());
  84. }));
  85. } // namespace prim
  86. } // namespace mindspore