| @@ -12,6 +12,7 @@ | |||
| #include "./grad.h" | |||
| #include "megbrain/imperative/proxy_graph_detail.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/imperative/ops/utility.h" | |||
| #include "megbrain/utils/mempool.h" | |||
| #include "range/v3/all.hpp" | |||
| @@ -21,6 +22,9 @@ namespace views = ranges::views; | |||
| namespace mgb::imperative::python { | |||
| using scoped_disable = ApplyContext::scoped_disable; | |||
| using Flags = Tensor::Flags; | |||
| namespace { | |||
| struct GradSlotWeakPtr { | |||
| @@ -78,6 +82,21 @@ std::shared_ptr<BackwardGraphResult> make_backward_graph( | |||
| return result; | |||
| } | |||
| struct BackwardContext { | |||
| PyTypeObject* pytype = nullptr; | |||
| auto wrap_tensor(std::shared_ptr<Tensor> t) { | |||
| if (pytype) { | |||
| return TensorWrapper::make(pytype, std::move(t)); | |||
| } | |||
| return TensorWrapper::make(std::move(t)); | |||
| } | |||
| auto wrap_tensor(Tensor* t) { | |||
| return wrap_tensor(t->shared_from_this()); | |||
| } | |||
| }; | |||
| struct BackwardGraphWithClosure { | |||
| std::shared_ptr<BackwardGraphResult> backward_graph; | |||
| SmallVector<std::shared_ptr<Tensor>> closure; | |||
| @@ -119,7 +138,7 @@ struct BackwardGraphWithClosure { | |||
| } | |||
| template <typename T, typename R> | |||
| void operator()(T&& grads, R&& receiver) { | |||
| void operator()(BackwardContext&, T&& grads, R&& receiver) { | |||
| Tensor* args[closure.size() + grads.size()]; | |||
| size_t nargs = 0; | |||
| for (auto&& t : closure) { | |||
| @@ -143,7 +162,7 @@ struct BackwardGraphWithClosure { | |||
| ApplyContext ctx; | |||
| ctx.op = backward_graph->backward; | |||
| ctx.flags = is_tracing ? Tensor::Flags::TRACE : 0; | |||
| ctx.flags = is_tracing ? Flags::TRACE : 0; | |||
| ctx.nargs = nargs; | |||
| ctx.args = args; | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| @@ -174,6 +193,47 @@ struct BackwardGraphWithClosure { | |||
| } | |||
| }; | |||
| struct PythonBackward { | |||
| py::object pyfunc; | |||
| size_t input_size; | |||
| PythonBackward(py::object f, size_t nin) | |||
| : pyfunc(f), input_size(nin) {} | |||
| template <typename T, typename R> | |||
| void operator()(BackwardContext& ctx, T&& grads, R&& receiver) { | |||
| auto args = py::tuple(grads.size()); | |||
| for (size_t i = 0; i < grads.size(); ++i) { | |||
| auto&& g = grads[i]; | |||
| args[i] = g ? ctx.wrap_tensor(g) : py::none(); | |||
| } | |||
| auto input_grads = py::reinterpret_steal<py::object>(PyObject_Call(pyfunc.ptr(), args.ptr(), nullptr)); | |||
| if (input_grads.is_none()) return; | |||
| if (auto* tw = TensorWrapper::try_cast(input_grads.ptr())) { | |||
| if (input_size != 1) { | |||
| throw py::value_error("custom grad rule returned wrong number of grads"); | |||
| } | |||
| receiver(0, tw->m_tensor); | |||
| return; | |||
| } | |||
| if (py::len(input_grads) != input_size) { | |||
| throw py::value_error("custom grad rule returned wrong number of grads"); | |||
| } | |||
| for (auto [i, g] : views::enumerate(input_grads)) { | |||
| if (g.is_none()) continue; | |||
| auto* tw = TensorWrapper::try_cast(g.ptr()); | |||
| if (!tw) { | |||
| throw py::type_error("custom grad rule returned non-tensor"); | |||
| } | |||
| receiver(i, tw->m_tensor); | |||
| } | |||
| } | |||
| static constexpr bool input_has_grad(size_t) {return true;} | |||
| static constexpr bool output_requires_grad(size_t) {return true;} | |||
| static constexpr bool output_captured(size_t) {return true;} | |||
| }; | |||
| } // namespace | |||
| struct GradProducerRecord : intrusive_list::Node<GradProducerRecord> { | |||
| @@ -210,7 +270,7 @@ struct GradFn : std::enable_shared_from_this<GradFn> { | |||
| // same length as inputs (of forward op) | |||
| SmallVector<GradSlotProducerPtr> dsts; | |||
| // encapsules actual function to compute gradient | |||
| std::variant<std::monostate, BackwardGraphWithClosure> backward; | |||
| std::variant<std::monostate, BackwardGraphWithClosure, PythonBackward> backward; | |||
| // a flag used during backward | |||
| bool in_ref_keeper = false; | |||
| @@ -268,6 +328,30 @@ apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_gra | |||
| return outputs; | |||
| } | |||
| apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) { | |||
| auto* op = ctx.op->try_cast_final<GenericPyOp>(); | |||
| py::tuple pyin(ctx.nargs); | |||
| for (size_t i = 0; i < ctx.nargs; ++i) { | |||
| pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this()); | |||
| } | |||
| auto grad_rule = py::getattr(op->obj, "_grad_rule"); | |||
| auto pyret = (scoped_disable(Flags::GRAD), | |||
| py::reinterpret_steal<py::object>(PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr))); // comma expression | |||
| auto [outputs, backward] = py::cast<std::tuple<py::object, py::function>>(pyret); | |||
| ret_grad_fn.emplace<PythonBackward>(std::move(backward), ctx.nargs); | |||
| if (auto* tw = TensorWrapper::try_cast(outputs.ptr())) { | |||
| return {tw->m_tensor}; | |||
| } | |||
| apply_result_t ret; | |||
| ret.reserve(py::len(outputs)); | |||
| for (auto&& i : outputs) { | |||
| auto* tw = TensorWrapper::try_cast(i.ptr()); | |||
| mgb_assert(tw); | |||
| ret.push_back(tw->m_tensor); | |||
| } | |||
| return ret; | |||
| } | |||
| } // namespace | |||
| apply_result_t apply_grad(ApplyContext& ctx) { | |||
| @@ -290,21 +374,23 @@ apply_result_t apply_grad(ApplyContext& ctx) { | |||
| // cleanup stale grad info | |||
| // under what condition? | |||
| tensor->m_grad_info = {}; | |||
| tensor->m_flags &= ~Tensor::Flags::GRAD; | |||
| tensor->m_flags &= ~Flags::GRAD; | |||
| } | |||
| } else { | |||
| tensor->m_flags &= ~Tensor::Flags::GRAD; | |||
| tensor->m_flags &= ~Flags::GRAD; | |||
| } | |||
| } | |||
| ctx.flags &= ~Tensor::Flags::GRAD; | |||
| ctx.flags &= ~Flags::GRAD; | |||
| if (!grad_key) { | |||
| return apply(ctx); | |||
| } | |||
| GradFnHelper grad_fn_holder; | |||
| auto outputs = backward_graph_grad_rule(ctx, grad_fn_holder); | |||
| auto outputs = ctx.op->same_type<GenericPyOp>() ? | |||
| python_grad_rule(ctx, grad_fn_holder) : | |||
| backward_graph_grad_rule(ctx, grad_fn_holder); | |||
| auto& grad_fn = grad_fn_holder.grad_fn; | |||
| if (!grad_fn) { | |||
| @@ -341,7 +427,7 @@ apply_result_t apply_grad(ApplyContext& ctx) { | |||
| grad_info.grad_fn = grad_fn; | |||
| grad_info.idx = i; | |||
| grad_info.insert_after(grad_key->free_vars_head); | |||
| outputs[i]->m_flags |= Tensor::Flags::GRAD; | |||
| outputs[i]->m_flags |= Flags::GRAD; | |||
| } | |||
| } | |||
| } | |||
| @@ -357,7 +443,7 @@ void GradKeyWrapper::attach(PyObject*const* args, size_t nargs) { | |||
| if (nargs != 2) { | |||
| throw py::type_error("expect 2 arguments"); | |||
| } | |||
| auto* tw = TensorWrapper::cast_safe(args[0]); | |||
| auto* tw = TensorWrapper::try_cast(args[0]); | |||
| if (!tw) { | |||
| throw py::type_error("argument 1 must be Tensor"); | |||
| } | |||
| @@ -390,14 +476,15 @@ void GradKey::attach(Tensor* tensor, pybind11::object callback) { | |||
| grad_fn->key = shared_from_this(); | |||
| grad_fn->slots.resize(1); | |||
| tensor->m_grad_info.insert_after(free_vars_head); | |||
| tensor->m_flags |= Tensor::Flags::GRAD; | |||
| tensor->m_flags |= Flags::GRAD; | |||
| } | |||
| tensor->m_grad_info.grad_fn->slots[0].callback = std::move(callback); | |||
| } | |||
| void accum_grad(std::shared_ptr<Tensor>& grad, std::shared_ptr<Tensor>&& delta) { | |||
| template<typename T> | |||
| void accum_grad(std::shared_ptr<Tensor>& grad, T&& delta) { | |||
| if (!grad) { | |||
| grad = std::forward<decltype(delta)>(delta); | |||
| grad = std::forward<T>(delta); | |||
| return; | |||
| } | |||
| static ApplyContext ctx; | |||
| @@ -409,7 +496,7 @@ void accum_grad(std::shared_ptr<Tensor>& grad, std::shared_ptr<Tensor>&& delta) | |||
| ctx.args = args; | |||
| ctx.flags = grad->m_flags | delta->m_flags; | |||
| if (is_tracing) { | |||
| ctx.flags |= Tensor::Flags::TRACE; | |||
| ctx.flags |= Flags::TRACE; | |||
| } | |||
| grad = apply(ctx)[0]; | |||
| } | |||
| @@ -440,6 +527,7 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr | |||
| } | |||
| } | |||
| BackwardContext bctx{pytype}; | |||
| std::vector<std::shared_ptr<GradFn>> ref_keeper; | |||
| ref_keeper.reserve(tape.size()); | |||
| // back-propagation in reverse order | |||
| @@ -456,7 +544,7 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr | |||
| mgb_assert(0); | |||
| } else { | |||
| auto&& grads = views::transform(grad_fn->slots, [](auto&& slot) {return slot.grad.get();}); | |||
| backward(std::forward<decltype(grads)>(grads), grad_receiver); | |||
| backward(bctx, std::forward<decltype(grads)>(grads), grad_receiver); | |||
| } | |||
| }, grad_fn->backward); | |||
| @@ -14,6 +14,7 @@ | |||
| #include "megbrain/imperative.h" | |||
| #include "megbrain/imperative/ops/backward_graph.h" | |||
| #include "megbrain/imperative/ops/opr_attr.h" | |||
| #include "megbrain/imperative/ops/utility.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include <Python.h> | |||
| @@ -245,6 +246,35 @@ void _init_py_backward_graph(py::module m) { | |||
| mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(BackwardGraph::typeinfo(), &py_type).second); | |||
| } | |||
| struct PyOpBase : PyOpDef { | |||
| static PyTypeObject py_type; | |||
| static PyObject* tp_new(PyTypeObject* type, PyObject*, PyObject*) { | |||
| auto* obj = type->tp_alloc(type, 0); | |||
| if (obj) { | |||
| auto* self = reinterpret_cast<PyOpBase*>(obj); | |||
| new(&self->op) decltype(self->op); | |||
| } | |||
| return obj; | |||
| } | |||
| }; | |||
| PyTypeObject PyOpBase::py_type; | |||
| void _init_py_op_base(py::module m) { | |||
| using py_op = PyOpBase; | |||
| auto& py_type = PyOpBase::py_type; | |||
| py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; | |||
| py_type.tp_name = "megengine.core._imperative_rt.ops.PyOpBase"; | |||
| py_type.tp_basicsize = sizeof(py_op); | |||
| py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
| py_type.tp_doc = "PyOpBase"; | |||
| py_type.tp_base = &PyOpType(OpDef); | |||
| py_type.tp_dealloc = py_dealloc_generic<py_op>; | |||
| py_type.tp_new = py_op::tp_new; | |||
| mgb_assert(PyType_Ready(&py_type) >= 0); | |||
| m.add_object("PyOpBase", reinterpret_cast<PyObject*>(&py_type)); | |||
| } | |||
| /*********** end of hand-write opdefs **************/ | |||
| // auto generated opdefs | |||
| @@ -260,9 +290,16 @@ bool type_caster<OpDef>::load(handle src, bool convert) { | |||
| return false; | |||
| } | |||
| value = reinterpret_cast<PyOp(OpDef)*>(obj)->op; | |||
| if (!value) { | |||
| // opdef only defined in Python | |||
| value = std::make_shared<GenericPyOp>(reinterpret_borrow<object>(src)); | |||
| } | |||
| return true; | |||
| } | |||
| handle type_caster<OpDef>::cast(const OpDef& op, return_value_policy, handle) { | |||
| if (auto* pyop = op.try_cast_final<GenericPyOp>()) { | |||
| return object(pyop->obj).release(); | |||
| } | |||
| PyTypeObject* pytype; | |||
| auto& c2p = PyOp(OpDef)::ctype2pytype; | |||
| auto&& iter = c2p.find(op.dyn_typeinfo()); | |||
| @@ -283,5 +320,6 @@ handle type_caster<OpDef>::cast(const OpDef& op, return_value_policy, handle) { | |||
| void init_ops(py::module m) { | |||
| _init_py_op_def(m); | |||
| _init_py_backward_graph(m); | |||
| _init_py_op_base(m); | |||
| INIT_ALL_OP(m) | |||
| } | |||
| @@ -11,6 +11,7 @@ | |||
| #include "megbrain/dtype.h" | |||
| #include "megbrain/common.h" | |||
| #include "megbrain/imperative/ops/utility.h" | |||
| #include "./tensor.h" | |||
| #include "./grad.h" | |||
| @@ -22,10 +23,12 @@ | |||
| #include <pybind11/numpy.h> | |||
| #include <pybind11/operators.h> | |||
| #include <range/v3/all.hpp> | |||
| #include <unordered_map> | |||
| namespace py = pybind11; | |||
| namespace views = ranges::views; | |||
| namespace mgb::imperative::python { | |||
| @@ -69,21 +72,45 @@ SET_UNSET_PROP(compiled) | |||
| bool skip_tracing = false; | |||
| Tensor::flags_t ApplyContext::global_disable = 0; | |||
| apply_result_t apply(ApplyContext& ctx) { | |||
| // emulating scalar should be put to specific op's apply, e.g., | |||
| // elementwise, reduce, typecvt. Currently it's still handled at python | |||
| // side. It could be move to C++ side if it has an impact on performance | |||
| if (ctx.flags & Tensor::Flags::SCALAR) { | |||
| auto flags = ctx.flags & ~ApplyContext::global_disable; | |||
| if (flags & Tensor::Flags::SCALAR) { | |||
| // TODO: emulate scalar | |||
| } | |||
| if (ctx.flags & Tensor::Flags::GRAD) { | |||
| if (flags & Tensor::Flags::GRAD) { | |||
| return apply_grad(ctx); | |||
| } | |||
| if (ctx.flags & Tensor::Flags::TRACE) { | |||
| if (flags & Tensor::Flags::TRACE) { | |||
| return apply_trace(ctx); | |||
| } else { | |||
| if (auto* op = ctx.op->try_cast_final<GenericPyOp>()) { | |||
| py::tuple pyin(ctx.nargs); | |||
| for (size_t i = 0; i < ctx.nargs; ++i) { | |||
| pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this()); | |||
| } | |||
| auto f = py::getattr(op->obj, "_default_rule"); | |||
| auto pyout = py::reinterpret_steal<py::object>(PyObject_Call(f.ptr(), pyin.ptr(), nullptr)); | |||
| if (auto* tw = TensorWrapper::try_cast(pyout.ptr())) { | |||
| return {tw->m_tensor}; | |||
| } | |||
| apply_result_t ret; | |||
| ret.reserve(py::len(pyout)); | |||
| for (auto&& i : pyout) { | |||
| auto* tw = TensorWrapper::try_cast(i.ptr()); | |||
| mgb_assert(tw); | |||
| ret.push_back(tw->m_tensor); | |||
| } | |||
| return ret; | |||
| } | |||
| SmallVector<interpreter::Interpreter::Handle> handles(ctx.nargs); | |||
| for (size_t i = 0; i < ctx.nargs; ++i) { | |||
| handles[i] = ctx.args[i]->m_handle.get(); | |||
| @@ -125,12 +152,13 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje | |||
| SmallVector<Tensor*, 64> tensors(nargs); | |||
| ctx.args = &tensors[0]; | |||
| ctx.nargs = nargs; | |||
| ctx.pytype = pytype; | |||
| if (strstr(op->ob_type->tp_name, "BackwardGraph")) { | |||
| ctx.backward = true; | |||
| } | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| if (TensorWrapper* tw = TensorWrapper::cast_safe(args[i])) { | |||
| if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) { | |||
| auto* t = tensors[i] = tw->m_tensor.get(); | |||
| ctx.flags |= t->m_flags; | |||
| } else { | |||
| @@ -166,7 +194,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { | |||
| if (nargs == 0) { | |||
| throw py::type_error("too few arguments"); | |||
| } | |||
| if (auto* t = cast_safe(tup[0].ptr())) { | |||
| if (auto* t = try_cast(tup[0].ptr())) { | |||
| if (nargs > 1) { | |||
| throw py::type_error("expect 1 argument"); | |||
| } | |||
| @@ -211,7 +239,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { | |||
| auto ret = pyf(*tup); | |||
| auto py_ret = py::reinterpret_borrow<py::list>(ret); | |||
| if (auto* t = cast_safe(py_ret[0].ptr())) { | |||
| if (auto* t = try_cast(py_ret[0].ptr())) { | |||
| m_tensor = t->m_tensor; | |||
| } | |||
| return; | |||
| @@ -349,7 +377,7 @@ PyObject* TensorWrapper::varnode() { | |||
| } | |||
| void TensorWrapper::reset(PyObject* tensor) { | |||
| TensorWrapper* t = TensorWrapper::cast_safe(tensor); | |||
| TensorWrapper* t = TensorWrapper::try_cast(tensor); | |||
| if (!t) { | |||
| throw py::type_error("expect Tensor"); | |||
| } | |||
| @@ -446,7 +474,7 @@ uint8_t max_priority(SmallVector<PyArray_Descr*> types) { | |||
| } | |||
| } | |||
| // Returns the data type with sufficient size to hold all types of | |||
| // Returns the data type with sufficient size to hold all types of | |||
| // category `cat` in the list `types`. | |||
| PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) { | |||
| // Return value: New reference | |||
| @@ -507,7 +535,7 @@ PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) { | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i]; | |||
| if (handle == Py_None) continue; | |||
| TensorWrapper* tw = TensorWrapper::cast_safe(handle); | |||
| TensorWrapper* tw = TensorWrapper::try_cast(handle); | |||
| if (tw) { | |||
| mgb::DType type = tw->m_tensor->dtype(); | |||
| auto&& descr = npy::dtype_mgb2np_descr(type); | |||
| @@ -562,7 +590,7 @@ CompNode _get_device(PyObject*const* args, size_t nargs) { | |||
| CompNode cn; | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i]; | |||
| TensorWrapper* tw = TensorWrapper::cast_safe(handle); | |||
| TensorWrapper* tw = TensorWrapper::try_cast(handle); | |||
| if (tw) { | |||
| if (!valid) { | |||
| cn = tw->m_tensor->comp_node(); | |||
| @@ -124,7 +124,7 @@ struct TensorWrapper { | |||
| friend wrap_t; | |||
| inline static TensorWrapper* cast(PyObject* op) {return reinterpret_cast<wrap_t*>(op)->inst();} | |||
| inline static TensorWrapper* cast_safe(PyObject* op) { | |||
| inline static TensorWrapper* try_cast(PyObject* op) { | |||
| if (!wrap_t::type().isinstance(op)) return nullptr; | |||
| return cast(op); | |||
| } | |||
| @@ -173,11 +173,26 @@ struct TensorWrapper { | |||
| PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */); | |||
| struct ApplyContext { | |||
| static Tensor::flags_t global_disable; | |||
| Tensor::flags_t flags; | |||
| std::shared_ptr<OpDef> op; | |||
| Tensor*const* args; | |||
| size_t nargs; | |||
| PyTypeObject* pytype = nullptr; | |||
| bool backward = false; | |||
| class scoped_disable : NonCopyableObj { | |||
| Tensor::flags_t saved_flags; | |||
| public: | |||
| scoped_disable(Tensor::flags_t flags) : saved_flags(ApplyContext::global_disable) { | |||
| ApplyContext::global_disable |= flags; | |||
| } | |||
| ~scoped_disable() { | |||
| ApplyContext::global_disable = saved_flags; | |||
| } | |||
| }; | |||
| }; | |||
| using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>; | |||
| @@ -85,7 +85,7 @@ apply_result_t apply_trace(ApplyContext& ctx) { | |||
| // assumption: python function always returns PyList | |||
| auto tup = py::reinterpret_borrow<py::list>(ret); | |||
| for (auto i = 0; i < tup.size(); i++) { | |||
| auto tw = TensorWrapper::cast_safe(tup[i].ptr()); | |||
| auto tw = TensorWrapper::try_cast(tup[i].ptr()); | |||
| outputs.emplace_back(tw->m_tensor); | |||
| } | |||
| return outputs; | |||
| @@ -0,0 +1,21 @@ | |||
| /** | |||
| * \file imperative/src/impl/ops/utility.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "megbrain/imperative/ops/utility.h" | |||
| #include "megbrain/imperative/ops/opr_attr.h" | |||
| #include "megbrain/opr/utility.h" | |||
| #include "../op_trait.h" | |||
| namespace mgb::imperative { | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(GenericPyOp); | |||
| } // namespace mgb::imperative | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * \file imperative/src/include/megbrain/imperative/ops/utility.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "megbrain/imperative/op_def.h" | |||
| #include "megbrain/utils/hash.h" | |||
| #include <pybind11/pybind11.h> | |||
| namespace mgb::imperative { | |||
| struct GenericPyOp final : OpDefImplBase<GenericPyOp> { | |||
| pybind11::object obj; | |||
| GenericPyOp(pybind11::object obj_) : obj(std::move(obj_)) {}; | |||
| size_t hash() const override { | |||
| return pybind11::hash(obj); | |||
| } | |||
| bool is_same_st(const Hashable& rhs) const override { | |||
| return obj.equal(static_cast<const GenericPyOp&>(rhs).obj); | |||
| } | |||
| MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
| }; | |||
| } // namespace mgb::imperative | |||