You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

meta_func_graph.h 3.3 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef MINDSPORE_CCSRC_IR_META_FUNC_GRAPH_H_
  19. #define MINDSPORE_CCSRC_IR_META_FUNC_GRAPH_H_
  20. #include <unordered_map>
  21. #include <string>
  22. #include <map>
  23. #include <memory>
  24. #include <vector>
  25. #include <algorithm>
  26. #include "pybind11/pybind11.h"
  27. #include "ir/dtype.h"
  28. #include "ir/anf.h"
  29. #include "ir/func_graph.h"
  30. #include "ir/signature.h"
  31. #include "pipeline/static_analysis/abstract_value.h"
  32. namespace py = pybind11;
  33. namespace mindspore {
  34. // namespace to support intermediate representation definition
  35. // Graph generator.
  36. // Can be called with a pipeline's resources and a list of argument types to
  37. // generate a graph corresponding to these types.
  38. class MetaFuncGraph : public FuncGraphBase {
  39. public:
  40. explicit MetaFuncGraph(const std::string &name) : name_(name) { cache_.clear(); }
  41. ~MetaFuncGraph() override = default;
  42. MS_DECLARE_PARENT(MetaFuncGraph, FuncGraphBase);
  43. abstract::AbstractBasePtr MakeAbstractClosure(const AnfNodePtr &anf_node);
  44. // Return normalized versions of the arguments.
  45. // By default, this returns args unchanged.
  46. virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const {
  47. return args_spec_list;
  48. }
  49. const std::vector<Signature> &signatures() const { return signatures_; }
  50. void set_signatures(const std::vector<Signature> &signatures) { signatures_ = signatures; }
  51. // Generate a Graph for the given abstract arguments.
  52. virtual FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list);
  53. // Generate a Graph for this type signature.
  54. virtual FuncGraphPtr GenerateFromTypes(const TypePtrList &) {
  55. MS_LOG(EXCEPTION) << "Undefine the method of generating graph from types.";
  56. }
  57. std::string name() { return name_; }
  58. std::string ToString() const override { return name_; }
  59. std::size_t hash() const override { return tid(); }
  60. virtual bool operator==(const MetaFuncGraph &other) const { return &other == this; }
  61. bool operator==(const Value &other) const override {
  62. if (other.isa<MetaFuncGraph>()) {
  63. return &other == this;
  64. } else {
  65. return false;
  66. }
  67. }
  68. const bool parse_info_ = true;
  69. protected:
  70. template <typename Derived>
  71. std::shared_ptr<Derived> shared_from_base() {
  72. return std::static_pointer_cast<Derived>(shared_from_this());
  73. }
  74. std::string name_;
  75. std::vector<Signature> signatures_;
  76. std::unordered_map<TypePtrList, FuncGraphPtr, TypeListHasher, TypeListEqual> cache_;
  77. };
  78. using MetaFuncGraphPtr = std::shared_ptr<MetaFuncGraph>;
  79. } // namespace mindspore
  80. #endif // MINDSPORE_CCSRC_IR_META_FUNC_GRAPH_H_