GitOrigin-RevId: 16e86a21d7
tags/v1.0.0-rc1
| @@ -254,6 +254,7 @@ class trace: | |||||
| def _compile(self): | def _compile(self): | ||||
| graph = self._graph = G.Graph() | graph = self._graph = G.Graph() | ||||
| graph.options.no_force_inplace = True | |||||
| # graph.options.graph_opt_level = 0 | # graph.options.graph_opt_level = 0 | ||||
| need_reset_nodes = self._need_reset_nodes = [] | need_reset_nodes = self._need_reset_nodes = [] | ||||
| # links enforce ordering of I/O nodes | # links enforce ordering of I/O nodes | ||||
| @@ -105,6 +105,7 @@ void init_graph_rt(py::module m) { | |||||
| DEF_READWRITE(enable_grad_var_static_reshape) | DEF_READWRITE(enable_grad_var_static_reshape) | ||||
| DEF_READWRITE(enable_memory_swap) | DEF_READWRITE(enable_memory_swap) | ||||
| DEF_READWRITE(comp_node_seq_record_level) | DEF_READWRITE(comp_node_seq_record_level) | ||||
| DEF_READWRITE(no_force_inplace) | |||||
| // DEF_READWRITE(eager_evaluation) | // DEF_READWRITE(eager_evaluation) | ||||
| // DEF_READWRITE(imperative_proxy_graph) | // DEF_READWRITE(imperative_proxy_graph) | ||||
| // DEF_READWRITE(extra_vardeps) | // DEF_READWRITE(extra_vardeps) | ||||
| @@ -81,6 +81,18 @@ public: | |||||
| return m_graph; | return m_graph; | ||||
| } | } | ||||
| bool is_same_st(const Hashable& rhs) const override { | |||||
| if (!rhs.same_type<BackwardGraph>()) { | |||||
| return false; | |||||
| } | |||||
| auto& other = rhs.cast_final_safe<BackwardGraph>(); | |||||
| if (this == &other) { | |||||
| return true; | |||||
| } | |||||
| // FIXME | |||||
| return false; | |||||
| } | |||||
| private: | private: | ||||
| InternalGraph m_graph; | InternalGraph m_graph; | ||||
| }; | }; | ||||