| @@ -1,6 +1,6 @@ | |||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||
| from functools import lru_cache | from functools import lru_cache | ||||
| from typing import Iterable, Optional, Sequence, Tuple, Union | |||||
| from typing import Iterable, List, Optional, Sequence, Tuple, Union | |||||
| import numpy as np | import numpy as np | ||||
| @@ -36,6 +36,7 @@ __all__ = [ | |||||
| "full_like", | "full_like", | ||||
| "gather", | "gather", | ||||
| "linspace", | "linspace", | ||||
| "meshgrid", | |||||
| "ones", | "ones", | ||||
| "ones_like", | "ones_like", | ||||
| "repeat", | "repeat", | ||||
| @@ -1205,3 +1206,49 @@ def cumsum(inp: Tensor, axis: int): | |||||
| assert isinstance(inp, Tensor), "input of cumsum must be type of Tensor" | assert isinstance(inp, Tensor), "input of cumsum must be type of Tensor" | ||||
| op = builtin.Cumsum(axis=axis, exclusive=False, reverse=False) | op = builtin.Cumsum(axis=axis, exclusive=False, reverse=False) | ||||
| return apply(op, inp)[0] | return apply(op, inp)[0] | ||||
| def meshgrid(*inputs: Tensor, indexing: str = "xy") -> List[Tensor]: | |||||
| r"""Returns coordinate matrices from coordinate vectors. | |||||
| Args: | |||||
| inputs: an arbitrary number of one-dimensional tensors representing grid | |||||
| coordinates. Each input should have the same numeric data type. | |||||
| indexing: Cartesian ``'xy'`` or matrix ``'ij'`` indexing of output. | |||||
| If provided zero or one one-dimensional vector(s) (i.e., the zero- and one-dimensional | |||||
| cases, respectively), the indexing keyword has no effect and should be ignored. | |||||
| Returns: | |||||
| out: list of N tensors, where N is the number of provided one-dimensional input tensors. | |||||
| Each returned tensor must have rank N. For N one-dimensional tensors having lengths ``Ni = len(xi)``, | |||||
| * if matrix indexing ``ij``, then each returned tensor must have the shape ``(N1, N2, N3, ..., Nn)``. | |||||
| * if Cartesian indexing ``xy``, then each returned tensor must have shape ``(N2, N1, N3, ..., Nn)``. | |||||
| Accordingly, for the two-dimensional case with input one-dimensional tensors of length ``M`` and ``N``, | |||||
| if matrix indexing ``ij``, then each returned tensor must have shape ``(M, N)``, and, if Cartesian indexing ``xy``, | |||||
| then each returned tensor must have shape ``(N, M)``. | |||||
| Similarly, for the three-dimensional case with input one-dimensional tensor of length ``M``, ``N``, and ``P``, | |||||
| if matrix indexing ``ij``, then each returned tensor must have shape ``(M, N, P)``, and, if Cartesian indexing ``xy``, | |||||
| then each returned tensor must have shape ``(N, M, P)``. | |||||
| Each returned tensor should have the same data type as the input tensors. | |||||
| Examples: | |||||
| >>> nx, ny = (3, 2) | |||||
| >>> x = F.linspace(0, 1, nx) | |||||
| >>> y = F.linspace(0, 1, ny) | |||||
| >>> xv, yv = F.meshgrid(x, y) | |||||
| >>> xv | |||||
| Tensor([[0. 0.5 1. ] | |||||
| [0. 0.5 1. ]], device=xpux:0) | |||||
| >>> yv | |||||
| Tensor([[0. 0. 0.] | |||||
| [1. 1. 1.]], device=xpux:0) | |||||
| """ | |||||
| op = builtin.MeshGrid(indexing) | |||||
| return apply(op, *inputs) | |||||
| @@ -1,13 +1,129 @@ | |||||
| #include <numeric> | |||||
| #include "megbrain/graph/helper.h" | |||||
| #include "megbrain/imperative/ops/autogen.h" | #include "megbrain/imperative/ops/autogen.h" | ||||
| #include "megbrain/opr/io.h" | |||||
| #include "megbrain/opr/tensor_manip.h" | #include "megbrain/opr/tensor_manip.h" | ||||
| #include "megbrain/graph/helper.h" | |||||
| #include "../op_trait.h" | #include "../op_trait.h" | ||||
| namespace mgb { | namespace mgb { | ||||
| namespace imperative { | namespace imperative { | ||||
| namespace meshgrid { | |||||
| SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint( | |||||
| const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||||
| return SmallVector<VarNode::LayoutConstraintCallback>(inputs.size()); | |||||
| } | |||||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||||
| const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | |||||
| for (size_t i = 0; i < inputs.size() - 1; i++) { | |||||
| mgb_assert(inputs[i].layout.dtype == inputs[i + 1].layout.dtype); | |||||
| mgb_assert(inputs[i].comp_node == inputs[i + 1].comp_node); | |||||
| } | |||||
| auto&& op = def.cast_final_safe<MeshGrid>(); | |||||
| mgb_assert(op.indexing == "xy" || op.indexing == "ij"); | |||||
| bool success = true; | |||||
| SmallVector<size_t> shp; | |||||
| for (size_t i = 0; i < inputs.size(); i++) { | |||||
| mgb_assert(inputs[i].layout.ndim <= 1); | |||||
| if (inputs[i].layout.ndim == 0) { | |||||
| success = false; | |||||
| } | |||||
| shp.push_back(inputs[i].layout.total_nr_elems()); | |||||
| } | |||||
| if (op.indexing == "xy" and shp.size() >= 2) { | |||||
| std::swap(shp[0], shp[1]); | |||||
| } | |||||
| TensorShape tshp(shp); | |||||
| SmallVector<LogicalTensorDesc> descs; | |||||
| for (size_t i = 0; i < inputs.size(); i++) { | |||||
| if (success) { | |||||
| descs.push_back( | |||||
| {TensorLayout(tshp, inputs[0].layout.dtype), inputs[0].comp_node}); | |||||
| } else { | |||||
| descs.push_back( | |||||
| {TensorLayout(inputs[0].layout.dtype), inputs[0].comp_node}); | |||||
| } | |||||
| } | |||||
| return {descs, success}; | |||||
| } | |||||
| VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
| auto&& op = def.cast_final_safe<MeshGrid>(); | |||||
| std::vector<size_t> indexs(inputs.size()); | |||||
| std::iota(indexs.begin(), indexs.end(), 0); | |||||
| auto cn = inputs[0]->comp_node(); | |||||
| auto graph = inputs[0]->owner_graph(); | |||||
| if (op.indexing == "xy") { | |||||
| if (indexs.size() >= 2) { | |||||
| std::swap(indexs[0], indexs[1]); | |||||
| } | |||||
| } else { | |||||
| mgb_assert(op.indexing == "ij", "meshgrid only support \"ij\" or \"xy\""); | |||||
| } | |||||
| VarNodeArray shps; | |||||
| for (size_t ind = 0; ind < inputs.size(); ind++) { | |||||
| auto&& inp = inputs[indexs[ind]]; | |||||
| shps.push_back(opr::GetVarShape::make(inp).node()); | |||||
| } | |||||
| VarNode* tshp = opr::Concat::make(shps, 0, cn).node(); | |||||
| VarNodeArray results; | |||||
| auto t_ndim = inputs.size(); | |||||
| for (size_t ind = 0; ind < inputs.size(); ind++) { | |||||
| auto axis = indexs[ind]; | |||||
| HostTensorND hv = HostTensorND(cn, {t_ndim}, dtype::Int32()); | |||||
| auto* ptr = hv.ptr<dt_int32>(); | |||||
| std::fill_n(ptr, t_ndim, 1); | |||||
| ptr[axis] = -1; | |||||
| auto shp = opr::ImmutableTensor::make(*graph, hv, cn).node(); | |||||
| auto tmp = opr::Reshape::make(inputs[ind], shp, axis).node(); | |||||
| results.push_back(opr::Broadcast::make(tmp, tshp).node()); | |||||
| } | |||||
| return results; | |||||
| } | |||||
| SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
| const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||||
| SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||||
| auto&& op = def.cast_final_safe<MeshGrid>(); | |||||
| TensorShape tshp; | |||||
| TensorShape view_shp; | |||||
| tshp.ndim = inputs.size(); | |||||
| view_shp.ndim = inputs.size(); | |||||
| std::vector<size_t> indexs(inputs.size()); | |||||
| std::iota(indexs.begin(), indexs.end(), 0); | |||||
| if (op.indexing == "xy") { | |||||
| if (indexs.size() >= 2) { | |||||
| std::swap(indexs[0], indexs[1]); | |||||
| } | |||||
| } else { | |||||
| mgb_assert(op.indexing == "ij", "meshgrid only support \"ij\" or \"xy\""); | |||||
| } | |||||
| for (size_t ind = 0; ind < inputs.size(); ind++) { | |||||
| auto&& inp = inputs[indexs[ind]]; | |||||
| mgb_assert(inp->layout().ndim <= 1); | |||||
| tshp[ind] = inp->layout().total_nr_elems(); | |||||
| view_shp[ind] = 1; | |||||
| } | |||||
| SmallVector<TensorPtr> grids; | |||||
| for (size_t i = 0; i < inputs.size(); i++) { | |||||
| auto&& src = inputs[i]; | |||||
| TensorLayout layout; | |||||
| view_shp[indexs[i]] = src->layout().total_nr_elems(); | |||||
| mgb_assert(src->layout().try_reshape(layout, view_shp)); | |||||
| layout = layout.broadcast(tshp); | |||||
| view_shp[indexs[i]] = 1; | |||||
| grids.push_back(Tensor::make(src->blob(), src->offset(), layout)); | |||||
| } | |||||
| return grids; | |||||
| } | |||||
| OP_TRAIT_REG(MeshGrid, MeshGrid) | |||||
| .apply_on_var_node(apply_on_var_node) | |||||
| .infer_output_attrs_fallible(infer_output_attrs_fallible) | |||||
| .apply_on_physical_tensor(apply_on_physical_tensor) | |||||
| .get_input_layout_constraint(get_input_layout_constraint) | |||||
| .fallback(); | |||||
| } // namespace meshgrid | |||||
| namespace broadcast { | namespace broadcast { | ||||
| std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | ||||
| @@ -211,7 +327,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
| tshp, tshp_nd->get_value().proxy_to_default_cpu()); | tshp, tshp_nd->get_value().proxy_to_default_cpu()); | ||||
| } | } | ||||
| if (op.axis != opr::Reshape::Param::INVALID_AXIS) { | if (op.axis != opr::Reshape::Param::INVALID_AXIS) { | ||||
| mgb_assert(tshp[op.axis] == -1); | |||||
| tshp[op.axis] = 1; | tshp[op.axis] = 1; | ||||
| tshp[op.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems(); | tshp[op.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems(); | ||||
| } | } | ||||
| @@ -237,7 +352,6 @@ SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint( | |||||
| tshp, inputs[1]->get_value().proxy_to_default_cpu()); | tshp, inputs[1]->get_value().proxy_to_default_cpu()); | ||||
| } | } | ||||
| if (op.axis != opr::Reshape::Param::INVALID_AXIS) { | if (op.axis != opr::Reshape::Param::INVALID_AXIS) { | ||||
| mgb_assert(tshp[op.axis] == -1); | |||||
| tshp[op.axis] = 1; | tshp[op.axis] = 1; | ||||
| tshp[op.axis] = layout.total_nr_elems() / tshp.total_nr_elems(); | tshp[op.axis] = layout.total_nr_elems() / tshp.total_nr_elems(); | ||||
| } | } | ||||
| @@ -250,7 +364,7 @@ SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint( | |||||
| return layout_checker; | return layout_checker; | ||||
| } | } | ||||
| OP_TRAIT_REG(Reshape, Reshape) | |||||
| OP_TRAIT_REG(Reshape, Reshape, opr::Reshape) | |||||
| .apply_on_var_node(apply_on_var_node) | .apply_on_var_node(apply_on_var_node) | ||||
| .infer_output_attrs_fallible(infer_output_attrs_fallible) | .infer_output_attrs_fallible(infer_output_attrs_fallible) | ||||
| .apply_on_physical_tensor(apply_on_physical_tensor) | .apply_on_physical_tensor(apply_on_physical_tensor) | ||||
| @@ -1,7 +1,7 @@ | |||||
| 905bdf78e5413b06873be64b4ba55db9 ../../dnn/scripts/opr_param_defs.py | 905bdf78e5413b06873be64b4ba55db9 ../../dnn/scripts/opr_param_defs.py | ||||
| e35e13523f43b7bea4034a0bf75937b7 ../../src/core/include/megbrain/ir/ops.td | |||||
| 240dccd6f8d42cadfd08c6ca90fe61b1 generated/opdef.h.inl | |||||
| a79a4058ff18ffd9593ee5db3deef6c4 generated/opdef.cpp.inl | |||||
| 83c179ee7416824fbfab978a097cd4d3 generated/opdef.py.inl | |||||
| 86f70b1052331130f5e4c0ca53e68423 generated/opdef.cpy.inl | |||||
| 40708c56b1f05fdb7d06cc097a300330 ../../src/core/include/megbrain/ir/ops.td | |||||
| 9f3af118c7fe8d0c9db433825d5ad77b generated/opdef.h.inl | |||||
| 4041e44a8ba3cca3b3affa1ed9ed44a2 generated/opdef.cpp.inl | |||||
| 319e1d170c989fe793a4e9c45decefc4 generated/opdef.py.inl | |||||
| 26a18a7593566128ecce76e8f74dcc5d generated/opdef.cpy.inl | |||||
| 71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h | 71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h | ||||
| @@ -4672,6 +4672,43 @@ OP_TRAIT_REG(MatrixMul, MatrixMul) | |||||
| .props(MatrixMul_props_impl) | .props(MatrixMul_props_impl) | ||||
| .make_name(MatrixMul_make_name_impl); | .make_name(MatrixMul_make_name_impl); | ||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(MeshGrid); | |||||
| namespace { | |||||
| size_t MeshGrid_hash_impl(const OpDef& def_) { | |||||
| auto&& op_ = def_.cast_final_safe<MeshGrid>(); | |||||
| static_cast<void>(op_); | |||||
| size_t val = mgb::hash(op_.dyn_typeinfo()); | |||||
| val = mgb::hash_pair_combine(val, mgb::hash(op_.indexing)); | |||||
| return val; | |||||
| } | |||||
| bool MeshGrid_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) { | |||||
| auto &&a_ = lhs_.cast_final_safe<MeshGrid>(), | |||||
| &&b_ = rhs_.cast_final_safe<MeshGrid>(); | |||||
| static_cast<void>(a_); | |||||
| static_cast<void>(b_); | |||||
| if (a_.indexing != b_.indexing) return false; | |||||
| return true; | |||||
| } | |||||
| std::vector<std::pair<const char*, std::string>> MeshGrid_props_impl(const OpDef& def_) { | |||||
| auto&& op_ = def_.cast_final_safe<MeshGrid>(); | |||||
| static_cast<void>(op_); | |||||
| std::vector<std::pair<const char*, std::string>> props_; | |||||
| props_.emplace_back("indexing", op_.indexing); | |||||
| return props_; | |||||
| } | |||||
| std::string MeshGrid_make_name_impl(const OpDef& def_) { | |||||
| auto&& op_ = def_.cast_final_safe<MeshGrid>(); | |||||
| static_cast<void>(op_); | |||||
| return "MeshGrid"; | |||||
| } | |||||
| } // anonymous namespace | |||||
| OP_TRAIT_REG(MeshGrid, MeshGrid) | |||||
| .hash(MeshGrid_hash_impl) | |||||
| .is_same_st(MeshGrid_is_same_st_impl) | |||||
| .props(MeshGrid_props_impl) | |||||
| .make_name(MeshGrid_make_name_impl); | |||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(MeshIndexing); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(MeshIndexing); | ||||
| namespace { | namespace { | ||||
| @@ -12467,6 +12467,95 @@ void _init_py_MatrixMul(py::module m) { | |||||
| mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(MatrixMul::typeinfo(), &py_type).second); | mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(MatrixMul::typeinfo(), &py_type).second); | ||||
| } | } | ||||
| PyOpDefBegin(MeshGrid) // { | |||||
| static PyGetSetDef py_getsetters[]; | |||||
| static PyMethodDef tp_methods[]; | |||||
| static PyObject* getstate(PyObject* self, PyObject*) { | |||||
| auto& opdef = reinterpret_cast<PyOp(MeshGrid)*>(self)->inst(); | |||||
| static_cast<void>(opdef); | |||||
| std::unordered_map<std::string, py::object> state { | |||||
| {"indexing", serialization<decltype(opdef.indexing)>::dump(opdef.indexing)} | |||||
| }; | |||||
| return py::cast(state).release().ptr(); | |||||
| } | |||||
| static PyObject* setstate(PyObject* self, PyObject* args) { | |||||
| PyObject* dict = PyTuple_GetItem(args, 0); | |||||
| if (!dict) return NULL; | |||||
| auto state = py::cast<std::unordered_map<std::string, py::object>>(dict); | |||||
| auto& opdef = reinterpret_cast<PyOp(MeshGrid)*>(self)->inst(); | |||||
| static_cast<void>(opdef); | |||||
| { | |||||
| auto&& iter = state.find("indexing"); | |||||
| if (iter != state.end()) { | |||||
| opdef.indexing = serialization<decltype(opdef.indexing)>::load(iter->second); | |||||
| } | |||||
| } | |||||
| Py_RETURN_NONE; | |||||
| } | |||||
| static int py_init(PyObject *self, PyObject *args, PyObject *kwds); | |||||
| // }; | |||||
| PyOpDefEnd(MeshGrid) | |||||
| int PyOp(MeshGrid)::py_init(PyObject *self, PyObject *args, PyObject *kwds) { | |||||
| static const char* kwlist[] = {"indexing", "scope", NULL}; | |||||
| PyObject *indexing = NULL, *scope = NULL; | |||||
| if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OO", const_cast<char**>(kwlist), &indexing, &scope)) | |||||
| return -1; | |||||
| if (indexing) { | |||||
| try { | |||||
| // TODO: remove this guard which is used for pybind11 implicit conversion | |||||
| py::detail::loader_life_support guard{}; | |||||
| reinterpret_cast<PyOp(MeshGrid)*>(self)->inst().indexing = | |||||
| py::cast<decltype(MeshGrid::indexing)>(py::handle(indexing)); | |||||
| } CATCH_ALL(-1) | |||||
| } | |||||
| if (scope) { | |||||
| try { | |||||
| reinterpret_cast<PyOp(OpDef)*>(self)->op | |||||
| ->set_scope(py::cast<std::string>(py::handle(scope))); | |||||
| } CATCH_ALL(-1) | |||||
| } | |||||
| return 0; | |||||
| } | |||||
| PyGetSetDef PyOp(MeshGrid)::py_getsetters[] = { | |||||
| {const_cast<char*>("indexing"), py_get_generic(MeshGrid, indexing), py_set_generic(MeshGrid, indexing), const_cast<char*>("indexing"), NULL}, | |||||
| {NULL} /* Sentinel */ | |||||
| }; | |||||
| PyMethodDef PyOp(MeshGrid)::tp_methods[] = { | |||||
| {const_cast<char*>("__getstate__"), PyOp(MeshGrid)::getstate, METH_NOARGS, "MeshGrid getstate"}, | |||||
| {const_cast<char*>("__setstate__"), PyOp(MeshGrid)::setstate, METH_VARARGS, "MeshGrid setstate"}, | |||||
| {NULL} /* Sentinel */ | |||||
| }; | |||||
| void _init_py_MeshGrid(py::module m) { | |||||
| using py_op = PyOp(MeshGrid); | |||||
| auto& py_type = PyOpType(MeshGrid); | |||||
| py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; | |||||
| py_type.tp_name = "megengine.core._imperative_rt.ops.MeshGrid"; | |||||
| py_type.tp_basicsize = sizeof(PyOp(MeshGrid)); | |||||
| py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||||
| py_type.tp_doc = "MeshGrid"; | |||||
| py_type.tp_base = &PyOpType(OpDef); | |||||
| py_type.tp_dealloc = py_dealloc_generic<py_op>; | |||||
| py_type.tp_new = py_new_generic<py_op>; | |||||
| py_type.tp_init = py_op::py_init; | |||||
| py_type.tp_methods = py_op::tp_methods; | |||||
| py_type.tp_getset = py_op::py_getsetters; | |||||
| mgb_assert(PyType_Ready(&py_type) >= 0); | |||||
| PyType_Modified(&py_type); | |||||
| m.add_object("MeshGrid", reinterpret_cast<PyObject*>(&py_type)); | |||||
| mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(MeshGrid::typeinfo(), &py_type).second); | |||||
| } | |||||
| PyOpDefBegin(MeshIndexing) // { | PyOpDefBegin(MeshIndexing) // { | ||||
| static PyGetSetDef py_getsetters[]; | static PyGetSetDef py_getsetters[]; | ||||
| static PyMethodDef tp_methods[]; | static PyMethodDef tp_methods[]; | ||||
| @@ -18594,6 +18683,7 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) { | |||||
| _init_py_MagicMindRuntime(m); \ | _init_py_MagicMindRuntime(m); \ | ||||
| _init_py_MatrixInverse(m); \ | _init_py_MatrixInverse(m); \ | ||||
| _init_py_MatrixMul(m); \ | _init_py_MatrixMul(m); \ | ||||
| _init_py_MeshGrid(m); \ | |||||
| _init_py_MeshIndexing(m); \ | _init_py_MeshIndexing(m); \ | ||||
| _init_py_NMSKeep(m); \ | _init_py_NMSKeep(m); \ | ||||
| _init_py_NvOf(m); \ | _init_py_NvOf(m); \ | ||||
| @@ -1262,6 +1262,15 @@ public: | |||||
| } | } | ||||
| }; | }; | ||||
| class MeshGrid : public OpDefImplBase<MeshGrid> { | |||||
| MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||||
| public: | |||||
| std::string indexing; | |||||
| MeshGrid() = default; | |||||
| MeshGrid(std::string indexing_, std::string scope_ = {}): indexing(indexing_) { set_scope(scope_); } | |||||
| }; | |||||
| class MeshIndexing : public OpDefImplBase<MeshIndexing> { | class MeshIndexing : public OpDefImplBase<MeshIndexing> { | ||||
| MGB_DYN_TYPE_OBJ_FINAL_DECL; | MGB_DYN_TYPE_OBJ_FINAL_DECL; | ||||
| @@ -1365,6 +1365,13 @@ MatrixMulInst | |||||
| .def_readwrite("dimA", &MatrixMul::dimA) | .def_readwrite("dimA", &MatrixMul::dimA) | ||||
| .def_readwrite("dimB", &MatrixMul::dimB); | .def_readwrite("dimB", &MatrixMul::dimB); | ||||
| py::class_<MeshGrid, std::shared_ptr<MeshGrid>, OpDef> MeshGridInst(m, "MeshGrid"); | |||||
| MeshGridInst | |||||
| .def(py::init<std::string, std::string>(), py::arg("indexing"), py::arg("scope") = {}) | |||||
| .def(py::init<>()) | |||||
| .def_readwrite("indexing", &MeshGrid::indexing); | |||||
| py::class_<MeshIndexing, std::shared_ptr<MeshIndexing>, OpDef> MeshIndexingInst(m, "MeshIndexing"); | py::class_<MeshIndexing, std::shared_ptr<MeshIndexing>, OpDef> MeshIndexingInst(m, "MeshIndexing"); | ||||
| MeshIndexingInst | MeshIndexingInst | ||||
| @@ -515,4 +515,9 @@ def Dropout: MgbHashableOp<"Dropout", [DropoutParam]> { | |||||
| let cmpFunction = [{return $0.handle == $1.handle && $0.drop_prob == $1.drop_prob;}]; | let cmpFunction = [{return $0.handle == $1.handle && $0.drop_prob == $1.drop_prob;}]; | ||||
| } | } | ||||
| def MeshGrid: MgbHashableOp<"MeshGrid"> { | |||||
| let extraArguments = (ins | |||||
| MgbStringAttr:$indexing | |||||
| ); | |||||
| } | |||||
| #endif // MGB_OPS | #endif // MGB_OPS | ||||