You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

vmimpl.h 5.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef MINDSPORE_CCSRC_VM_VMIMPL_H_
  19. #define MINDSPORE_CCSRC_VM_VMIMPL_H_
#include <functional>
#include <memory>
#include <set>
#include <unordered_map>
#include <vector>

#include "ir/anf.h"
#include "ir/manager.h"
#include "ir/meta_tensor.h"
#include "utils/base_ref.h"
  28. namespace mindspore {
  29. namespace compile {
  30. using AnfNodePtrList = std::vector<AnfNodePtr>;
  31. using AnfNodePtrToBaseRefMap = std::unordered_map<AnfNodePtr, BaseRef>;
  32. using AnfNodePtrToAnfNodePtrMap = std::unordered_map<AnfNodePtr, AnfNodePtr>;
  33. using FuncGraphPtrToBaseRefMap = std::unordered_map<FuncGraphPtr, BaseRef>;
  34. using TensorList = std::vector<tensor::TensorPtr>;
  35. class Closure;
  36. using ClosurePtr = std::shared_ptr<Closure>;
  37. class VMFrame;
  38. using VMFramePtr = std::shared_ptr<VMFrame>;
  39. using VMFramePtrList = std::vector<VMFramePtr>;
  40. class VM;
  41. using VMPtr = std::shared_ptr<VM>;
  42. class Partial;
  43. using PartialPtr = std::shared_ptr<Partial>;
  44. using RunFunc = std::function<VectorRef(const VectorRef &args)>;
  45. using RunFuncPtr = std::shared_ptr<RunFunc>;
  46. using SuccFunc = std::function<AnfNodePtrList(AnfNodePtr node)>;
  47. class VMImpl {
  48. public:
  49. virtual VectorRef RunGraph(const FuncGraphPtr &fg, const VectorRef &args) = 0;
  50. virtual ~VMImpl() = default;
  51. };
  52. // An execution frame.
  53. // This holds the state for an application of a graph. The nodes list
  54. // must contain free variables of graphs encountered before the
  55. // graph themselves.
  56. // You can index a frame with a node to get its value in the context
  57. // of this frame (if it has already been evaluated).
  58. // Attributes:
  59. // nodes: list of nodes remaining to execute
  60. // values: Mapping of node to their values in this application
  61. // closure: values for the closure if the current application is a closure
  62. class VMFrame {
  63. public:
  64. VMFrame(const AnfNodePtrList &nodes, const AnfNodePtrToBaseRefMap &values, const AnfNodePtrToBaseRefMap &closure);
  65. const BaseRef operator[](const AnfNodePtr &node);
  66. const AnfNodePtrList &todo() const { return todo_; }
  67. AnfNodePtrToBaseRefMap &values() { return values_; }
  68. virtual ~VMFrame() = default;
  69. AnfNodePtrToBaseRefMap values_;
  70. private:
  71. AnfNodePtrList todo_;
  72. AnfNodePtrToBaseRefMap closure_;
  73. };
  74. // Representation of a closure.
  75. class Closure : public Base {
  76. public:
  77. Closure(const FuncGraphPtr &func_graph, const AnfNodePtrToBaseRefMap &values);
  78. BaseRef operator()(const VectorRef &args);
  79. const VMPtr &vm() const { return vm_; }
  80. void set_vm(const VMPtr &vm) { vm_ = vm; }
  81. const FuncGraphPtr &func_graph() const { return func_graph_; }
  82. const AnfNodePtrToBaseRefMap &values() const { return values_; }
  83. virtual ~Closure() = default;
  84. MS_DECLARE_PARENT(Closure, Base)
  85. private:
  86. FuncGraphPtr func_graph_;
  87. AnfNodePtrToBaseRefMap values_;
  88. VMPtr vm_;
  89. };
  90. // Representation of a partial application.
  91. class Partial : public Base {
  92. public:
  93. Partial(const BaseRef &fn, const VectorRef &args, const VMPtr &vm);
  94. BaseRef operator()(const VectorRef &nodes);
  95. const BaseRef &fn() const { return fn_; }
  96. const VectorRef &args() const { return args_; }
  97. virtual ~Partial() = default;
  98. MS_DECLARE_PARENT(Partial, Base)
  99. private:
  100. BaseRef fn_;
  101. VectorRef args_;
  102. VMPtr vm_;
  103. };
  104. // Virtual Machine interface.
  105. class VM : public std::enable_shared_from_this<VM>, public VMImpl {
  106. public:
  107. SetRef ComputeFvs(const FuncGraphPtr &func_graph);
  108. void AcquireGraph(const FuncGraphPtr &func_graph);
  109. VectorRef ExportSequence(const VectorRef &seq);
  110. BaseRef ExportPrimitive(const PrimitivePtr &) const { return kAnyValue; }
  111. ClosurePtr ExportClosure(const ClosurePtr &clos);
  112. // Return an object that executes `fg` when called on arguments.
  113. ClosurePtr ExportGraph(const FuncGraphPtr &fg);
  114. BaseRef ExportObj(const BaseRef &obj) const;
  115. BaseRef Export(const BaseRef &value);
  116. // Run a graph.
  117. // This will evaluate the passed-in graph and return the
  118. // resulting value.
  119. BaseRef Evaluate(const FuncGraphPtr &func_graph, const VectorRef &args,
  120. const AnfNodePtrToBaseRefMap &closure = AnfNodePtrToBaseRefMap());
  121. // Return a visitor for the graph.
  122. SuccFunc SuccVm(const FuncGraphPtr &func_graph);
  123. // Call the `fn` object.
  124. // `fn` can be anything that would be valid as the first element of an apply.
  125. BaseRef Call(const BaseRef &fn, const VectorRef &args);
  126. BaseRef _Call(const BaseRef &graph, const VectorRef &args);
  127. ClosurePtr MakeClosure(const FuncGraphPtr &func_graph, const VMFramePtr &frame);
  128. BaseRef DispatchCall(const AnfNodePtr &node, const VMFramePtr &frame, const BaseRef &fn, const VectorRef &args);
  129. BaseRef HandleNode(const AnfNodePtr &node, const VMFramePtr &frame);
  130. VectorRef RunGraph(const FuncGraphPtr &fg, const VectorRef &args) override;
  131. private:
  132. FuncGraphManagerPtr manager_;
  133. FuncGraphPtrToBaseRefMap vars_;
  134. };
  135. extern BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args);
  136. } // namespace compile
  137. } // namespace mindspore
  138. #endif // MINDSPORE_CCSRC_VM_VMIMPL_H_