GitOrigin-RevId: e4d944343d
tags/v1.1.0
| @@ -406,3 +406,16 @@ def test_clip(): | |||||
| for i in range(3): | for i in range(3): | ||||
| f(x, tensor([0]), tensor([1])) | f(x, tensor([0]), tensor([1])) | ||||
| # test returning noncontiguous tensor from trace | |||||
| def test_slice(): | |||||
| @trace | |||||
| def f(x): | |||||
| return x[:, 1::2] | |||||
| x = F.arange(8).reshape(2, 4) | |||||
| f(x) | |||||
| y = f(x) | |||||
| np.testing.assert_array_equal(y.numpy(), x.numpy()[:, 1::2]) | |||||
| y + y | |||||
| @@ -156,6 +156,12 @@ cg::OperatorNodeBase::NodeProp* OutputCallback::do_make_node_prop() const { | |||||
| return prop; | return prop; | ||||
| } | } | ||||
| void OutputCallback::add_input_layout_constraint() { | |||||
| if (m_param.require_contiguous) { | |||||
| input(0)->add_layout_constraint_contiguous(); | |||||
| } | |||||
| } | |||||
| void OutputCallback::scn_do_execute() { | void OutputCallback::scn_do_execute() { | ||||
| if (m_use_host_value) { | if (m_use_host_value) { | ||||
| m_param.callback(owner_graph()->static_infer_manager().infer_value(input(0))); | m_param.callback(owner_graph()->static_infer_manager().infer_value(input(0))); | ||||
| @@ -62,6 +62,7 @@ public: | |||||
| callback_t callback; | callback_t callback; | ||||
| bool borrow = false; // do not obtain shared ownership on DeviceTensorND | bool borrow = false; // do not obtain shared ownership on DeviceTensorND | ||||
| bool prefer_host_value = false; // use host value when possible | bool prefer_host_value = false; // use host value when possible | ||||
| bool require_contiguous = true; | |||||
| }; | }; | ||||
| OutputCallback(Param param, | OutputCallback(Param param, | ||||
| const VarNodeArray& inputs, | const VarNodeArray& inputs, | ||||
| @@ -80,6 +81,7 @@ protected: | |||||
| void scn_do_execute() override; | void scn_do_execute() override; | ||||
| void init_output_static_infer_desc() override; | void init_output_static_infer_desc() override; | ||||
| NodeProp* do_make_node_prop() const override; | NodeProp* do_make_node_prop() const override; | ||||
| void add_input_layout_constraint() override; | |||||
| private: | private: | ||||
| Param m_param; | Param m_param; | ||||
| mutable bool m_use_host_value; | mutable bool m_use_host_value; | ||||