You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

vmimpl.h 5.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef MINDSPORE_CCSRC_VM_VMIMPL_H_
  19. #define MINDSPORE_CCSRC_VM_VMIMPL_H_
  #include <functional>
  #include <memory>
  #include <set>
  #include <unordered_map>
  #include <vector>

  #include "ir/anf.h"
  #include "ir/manager.h"
  #include "ir/meta_tensor.h"
  #include "utils/base_ref.h"
  28. namespace mindspore {
  29. namespace compile {
  30. using AnfNodePtrList = std::vector<AnfNodePtr>;
  31. using AnfNodePtrToBaseRefMap = std::unordered_map<AnfNodePtr, BaseRef>;
  32. using AnfNodePtrToAnfNodePtrMap = std::unordered_map<AnfNodePtr, AnfNodePtr>;
  33. using FuncGraphPtrToBaseRefMap = std::unordered_map<FuncGraphPtr, BaseRef>;
  34. using TensorList = std::vector<tensor::TensorPtr>;
  35. class Closure;
  36. using ClosurePtr = std::shared_ptr<Closure>;
  37. class VMFrame;
  38. using VMFramePtr = std::shared_ptr<VMFrame>;
  39. using VMFramePtrList = std::vector<VMFramePtr>;
  40. class VM;
  41. using VMPtr = std::shared_ptr<VM>;
  42. class Partial;
  43. using PartialPtr = std::shared_ptr<Partial>;
  44. using RunFunc = std::function<VectorRef(const VectorRef& args)>;
  45. using RunFuncPtr = std::shared_ptr<RunFunc>;
  46. using SuccFunc = std::function<AnfNodePtrList(AnfNodePtr node)>;
  47. class VMImpl {
  48. public:
  49. virtual VectorRef RunGraph(const FuncGraphPtr& fg, const VectorRef& args) = 0;
  50. virtual ~VMImpl() = default;
  51. };
  52. // An execution frame.
  53. // This holds the state for an application of a graph. The nodes list
  54. // must contain free variables of graphs encountered before the
  55. // graph themselves.
  56. // You can index a frame with a node to get its value in the context
  57. // of this frame (if it has already been evaluated).
  58. // Attributes:
  59. // nodes: list of nodes remaining to execute
  60. // values: Mapping of node to their values in this application
  61. // closure: values for the closure if the current application is a closure
  62. class VMFrame {
  63. public:
  64. VMFrame(const AnfNodePtrList& nodes, const AnfNodePtrToBaseRefMap& values, const AnfNodePtrToBaseRefMap& closure);
  65. const BaseRef operator[](const AnfNodePtr& node);
  66. const AnfNodePtrList& todo() const { return todo_; }
  67. AnfNodePtrToBaseRefMap& values() { return values_; }
  68. virtual ~VMFrame() = default;
  69. AnfNodePtrToBaseRefMap values_;
  70. private:
  71. AnfNodePtrList todo_;
  72. AnfNodePtrToBaseRefMap closure_;
  73. };
  74. // Representation of a closure.
  75. class Closure : public Base {
  76. public:
  77. Closure(const FuncGraphPtr& func_graph, const AnfNodePtrToBaseRefMap& values);
  78. BaseRef operator()(const VectorRef& args);
  79. const VMPtr& vm() const { return vm_; }
  80. void set_vm(const VMPtr& vm) { vm_ = vm; }
  81. const FuncGraphPtr& func_graph() const { return func_graph_; }
  82. const AnfNodePtrToBaseRefMap& values() const { return values_; }
  83. virtual ~Closure() = default;
  84. MS_DECLARE_PARENT(Closure, Base)
  85. private:
  86. FuncGraphPtr func_graph_;
  87. AnfNodePtrToBaseRefMap values_;
  88. VMPtr vm_;
  89. };
  90. // Representation of a partial application.
  91. class Partial : public Base {
  92. public:
  93. Partial(const BaseRef& fn, const VectorRef& args, const VMPtr& vm);
  94. BaseRef operator()(const VectorRef& nodes);
  95. const BaseRef& fn() const { return fn_; }
  96. const VectorRef& args() const { return args_; }
  97. virtual ~Partial() = default;
  98. MS_DECLARE_PARENT(Partial, Base)
  99. private:
  100. BaseRef fn_;
  101. VectorRef args_;
  102. VMPtr vm_;
  103. };
  104. // Virtual Machine interface.
  105. class VM : public std::enable_shared_from_this<VM>, public VMImpl {
  106. public:
  107. SetRef ComputeFvs(const FuncGraphPtr& func_graph);
  108. void AcquireGraph(const FuncGraphPtr& func_graph);
  109. VectorRef ExportSequence(const VectorRef& seq);
  110. BaseRef ExportPrimitive(const PrimitivePtr&) const { return kAnyValue; }
  111. ClosurePtr ExportClosure(const ClosurePtr& clos);
  112. // Return an object that executes `fg` when called on arguments.
  113. ClosurePtr ExportGraph(const FuncGraphPtr& fg);
  114. BaseRef ExportObj(const BaseRef& obj) const;
  115. BaseRef Export(const BaseRef& value);
  116. // Run a graph.
  117. // This will evaluate the passed-in graph and return the
  118. // resulting value.
  119. BaseRef Evaluate(const FuncGraphPtr& func_graph, const VectorRef& args,
  120. const AnfNodePtrToBaseRefMap& closure = AnfNodePtrToBaseRefMap());
  121. // Return a visitor for the graph.
  122. SuccFunc SuccVm(const FuncGraphPtr& func_graph);
  123. // Call the `fn` object.
  124. // `fn` can be anything that would be valid as the first element of an apply.
  125. BaseRef Call(const BaseRef& fn, const VectorRef& args);
  126. BaseRef _Call(const BaseRef& graph, const VectorRef& args);
  127. ClosurePtr MakeClosure(const FuncGraphPtr& func_graph, const VMFramePtr& frame);
  128. BaseRef DispatchCall(const AnfNodePtr& node, const VMFramePtr& frame, const BaseRef& fn, const VectorRef& args);
  129. BaseRef HandleNode(const AnfNodePtr& node, const VMFramePtr& frame);
  130. VectorRef RunGraph(const FuncGraphPtr& fg, const VectorRef& args) override;
  131. private:
  132. FuncGraphManagerPtr manager_;
  133. FuncGraphPtrToBaseRefMap vars_;
  134. };
  135. extern BaseRef RunOperation(const PrimitivePtr& prim, const VectorRef& args);
  136. } // namespace compile
  137. } // namespace mindspore
  138. #endif // MINDSPORE_CCSRC_VM_VMIMPL_H_