@@ -104,24 +104,27 @@ from .utils.persistent_cache import PersistentCacheOnServer as _PersistentCacheO
from .version import __version__

logger = get_logger(__name__)
ngpus = get_device_count("gpu")
supported_sm_versions = re.findall(r"sm_(\d+)", _get_supported_sm_versions())
for idx in range(ngpus):
    prop = get_cuda_device_property(idx)
    cur_sm = str(prop.major * 10 + prop.minor)
    if not cur_sm in supported_sm_versions:
        logger.warning(
            "{} with CUDA capability sm_{} is not compatible with the current MegEngine installation. The current MegEngine install supports CUDA {} {}. If you want to use the {} with MegEngine, please check the instructions at https://github.com/MegEngine/MegEngine/blob/master/scripts/cmake-build/BUILD_README.md".format(
                prop.name,
                cur_sm,
                "capabilities" if len(supported_sm_versions) > 1 else "capability",
                " ".join(["sm_" + v for v in supported_sm_versions]),
                prop.name,

def _check_sm_version():
    cur_logger = get_logger(__name__)
    ngpus = get_device_count("gpu")
    supported_sm_versions = re.findall(r"sm_(\d+)", _get_supported_sm_versions())
    for idx in range(ngpus):
        prop = get_cuda_device_property(idx)
        cur_sm = str(prop.major * 10 + prop.minor)
        if not cur_sm in supported_sm_versions:
            cur_logger.warning(
                "{} with CUDA capability sm_{} is not compatible with the current MegEngine installation. The current MegEngine install supports CUDA {} {}. If you want to use the {} with MegEngine, please check the instructions at https://github.com/MegEngine/MegEngine/blob/master/scripts/cmake-build/BUILD_README.md".format(
                    prop.name,
                    cur_sm,
                    "capabilities" if len(supported_sm_versions) > 1 else "capability",
                    " ".join(["sm_" + v for v in supported_sm_versions]),
                    prop.name,
                )
            )
        )


_check_sm_version()
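
Not part of the diff: a minimal sketch of what the `sm_` parsing above does, with an assumed example return value for `_get_supported_sm_versions()`:

```python
import re

# Assumed example value; the real string comes from _get_supported_sm_versions().
supported_sm_versions = re.findall(r"sm_(\d+)", "sm_60 sm_70 sm_75 sm_80")
print(supported_sm_versions)  # ['60', '70', '75', '80']

# A device with compute capability 5.2 yields "52", which is not in the
# list, so the warning in _check_sm_version() would be emitted.
cur_sm = str(5 * 10 + 2)
print(cur_sm in supported_sm_versions)  # False
```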
_set_fork_exec_path_for_timed_func(
    sys.executable,
    os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"),

@@ -16,6 +16,7 @@ from ..core._imperative_rt.core2 import (
    adaptive_pool2d_cpp,
    apply,
    dtype_promotion,
    pixel_shuffle_cpp,
)
from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
from ..core.ops import builtin

@@ -1849,16 +1850,7 @@ def _get_layerPixelShuffle(device, dtype, dim_order):
    return layerPixelShuffle

def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
    """
    Rearranges elements in a tensor of shape (*, C x r^2, H, W) to a tensor of
    shape (*, C, H x r, W x r), where r is an upscale factor, where * is zero
    or more batch dimensions.
    :param inp: input tensor.
    :param upscale_factor: upscale factor of pixel_shuffle.
    :return: output tensor.
    """

def layerPixelShuffle_traceable(inp, upscale_factor):
    assert upscale_factor > 0, "upscale_factor should larger than 0"
    assert inp.ndim >= 3, "the input dimension of pixel_shuffle should be larger than 3"
    assert (

@@ -1899,6 +1891,19 @@ def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
    return outvar

def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
    """
    Rearranges elements in a tensor of shape `(..., C * r^2, H, W)` to a tensor of
    shape `(..., C, H * r, W * r)`, where `r` is an upscale factor, where `...` is
    zero or more batch dimensions.
    :param inp: input tensor.
    :param upscale_factor: upscale factor of pixel_shuffle.
    :return: output tensor.
    """
    return pixel_shuffle_cpp(inp, upscale_factor, layerPixelShuffle_traceable)
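
A quick usage sketch of the new functional API (shapes follow the docstring; the concrete numbers are illustrative):

```python
import numpy as np
import megengine as mge
import megengine.functional as F

# Two leading batch dims; the -3 channel dim 16 = C * r^2 with r = 2, C = 4.
x = mge.Tensor(np.random.rand(2, 3, 16, 3, 4).astype("float32"))
y = F.pixel_shuffle(x, 2)
print(y.shape)  # (2, 3, 4, 6, 8): channels divided by r^2, H and W upscaled by r
```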
from .quantized import conv_bias_activation # isort:skip
from .loss import * # isort:skip
from .metric import * # isort:skip

@@ -349,6 +349,28 @@ std::optional<ValueRefList> removeAxis_grad_rule(
    return imperative::apply(op, inputs);
}

std::optional<ValueRefList> pixelShuffle_grad_rule(
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto&& pixelShuffle = op.cast_final_safe<PixelShuffle>();
    mgb_assert(inputs.size() == 1);
    bool flag = inputs_require_grad[0];
    auto&& grad_op = PixelShuffleBackward::make(pixelShuffle.factor);
    auto maker = CustomGradMaker(backward, inputs.size());
    maker.output_size(1).output_captured(0, false);
    maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
        SmallVector<ValueRef> ret(1);
        if (grad && flag_) {
            ret[0] = imperative::apply(*grad_op_, grad)[0];
        }
        return ret;
    });
    maker.finalize();
    return imperative::apply(op, inputs);
}
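
For intuition about why `PixelShuffleBackward` is the right gradient here: the forward op is a pure permutation of elements, so its vector-Jacobian product is simply the inverse permutation. A NumPy sketch (assumed reference implementations mirroring the dimshuffles used by the ops) checking the adjoint identity `<pixel_shuffle(x), gy> == <x, pixel_shuffle_backward(gy)>`:

```python
import numpy as np

def pixel_shuffle_ref(x, r):
    # (N, C*r^2, H, W) -> (N, C, H*r, W*r), same dimshuffle {0, 1, 4, 2, 5, 3}
    n, c, h, w = x.shape
    t = x.reshape(n, c // (r * r), r, r, h, w).transpose(0, 1, 4, 2, 5, 3)
    return t.reshape(n, c // (r * r), h * r, w * r)

def pixel_shuffle_backward_ref(gy, r):
    # (N, C, H*r, W*r) -> (N, C*r^2, H, W), same dimshuffle {0, 1, 3, 5, 2, 4}
    n, c, hr, wr = gy.shape
    t = gy.reshape(n, c, hr // r, r, wr // r, r).transpose(0, 1, 3, 5, 2, 4)
    return t.reshape(n, c * r * r, hr // r, wr // r)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8, 3, 5)).astype("float32")
gy = rng.standard_normal((2, 2, 6, 10)).astype("float32")
lhs = (pixel_shuffle_ref(x, 2) * gy).sum()
rhs = (x * pixel_shuffle_backward_ref(gy, 2)).sum()
np.testing.assert_allclose(lhs, rhs, rtol=1e-5)  # adjoint identity holds
```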
std::optional<ValueRefList> fastpathcopy_grad_rule(
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {

@@ -382,6 +404,8 @@ struct Init {
                RemoveAxis::typeinfo(), removeAxis_grad_rule);
        CustomBackward::register_grad_rule(
                FastpathCopy::typeinfo(), fastpathcopy_grad_rule);
        CustomBackward::register_grad_rule(
                PixelShuffle::typeinfo(), pixelShuffle_grad_rule);
    }
} _;

@@ -438,6 +438,7 @@ WRAP_FUNC_PY35(batched_matmul_cpp);
WRAP_FUNC_PY35(convert_single_value_cpp);
WRAP_FUNC_PY35(convert_inputs_cpp);
WRAP_FUNC_PY35(astensor1d_cpp);
WRAP_FUNC_PY35(pixel_shuffle_cpp);
#undef WRAP_FUNC_PY35

#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }

@@ -595,6 +596,7 @@ void init_tensor(py::module m) {
            MGE_PY_INTERFACE(convert_single_value_cpp, convert_single_value_cpp),
            MGE_PY_INTERFACE(convert_inputs_cpp, convert_inputs_cpp),
            MGE_PY_INTERFACE(astensor1d_cpp, astensor1d_cpp),
            MGE_PY_INTERFACE(pixel_shuffle_cpp, pixel_shuffle_cpp),
            {nullptr, nullptr, 0, nullptr}};
    for (auto&& def : method_defs) {
        if (def.ml_meth != nullptr) {

@@ -1378,7 +1378,7 @@ py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) {
    } else {
        auto&& inp_ndim = get_ndim_safe(inp_hdl);
        ndim += inp_ndim.first;
        unknown_ndim &= ~inp_ndim.second;
        unknown_ndim &= !inp_ndim.second;
    }
    for (size_t i = 0; i < axis.size(); ++i) {
        if (axis[i] < 0) {

@@ -1446,6 +1446,7 @@ py::object _squeeze_cpp(py::handle inp_hdl, py::handle axis_hdl) {
    py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 2));
    return ret[0];
}

py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
    py::object obj = _expand_args(args);
    py::list lis;

@@ -1562,6 +1563,19 @@ py::object _batched_matmul_cpp(
    }
}

py::object _pixel_shuffle_cpp(py::handle inp, py::handle val, py::handle func) {
    if (enable_fastpath(inp) && PyLong_Check(val.ptr())) {
        std::shared_ptr<OpDef> op = PixelShuffle::make(val.cast<int32_t>());
        py::object Op = py::cast(op);
        PyObject* p[2] = {Op.ptr(), inp.ptr()};
        py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 2));
        return ret[0];
    } else {
        // fallback to traceable subgraph implement
        return func(inp, val);
    }
}
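
A minimal, self-contained Python stand-in for the dispatch above (all names here are placeholders, not MegEngine internals): when the fast path is usable and the factor is a plain Python int, a single fused `PixelShuffle` op is applied; otherwise the traceable reshape/transpose subgraph built in Python is used.

```python
def pixel_shuffle_dispatch(inp, upscale_factor, fast_impl, traceable_impl, fastpath_enabled):
    # fast_impl stands for applying the fused C++ PixelShuffle op,
    # traceable_impl for the Python layerPixelShuffle_traceable subgraph.
    if fastpath_enabled and isinstance(upscale_factor, int):
        return fast_impl(inp, upscale_factor)
    return traceable_impl(inp, upscale_factor)
```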
PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _make_shape_tuple(args[0]).release().ptr();

@@ -1632,6 +1646,13 @@ PyObject* adaptive_pool2d_cpp(PyObject* self, PyObject* const* args, size_t narg
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

PyObject* pixel_shuffle_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _pixel_shuffle_cpp(args[0], args[1], args[2]).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _Const(args[0], args[1], args[2], args[3]).release().ptr();

@@ -40,4 +40,6 @@ PyObject* convert_inputs_cpp(PyObject* self, PyObject* const* args, size_t nargs
PyObject* astensor1d_cpp(PyObject* self, PyObject* const* args, size_t nargs);

PyObject* pixel_shuffle_cpp(PyObject* self, PyObject* const* args, size_t nargs);

} // namespace mgb::imperative::python

@@ -462,3 +462,19 @@ def test_dot():
        grad(y, F.ones_like(y))

    np.testing.assert_equal(np.ones((2, 2), dtype=np.float32), x.grad.numpy())


def test_pixel_shuffle():
    x = np.random.rand(2, 3, 16, 3, 4).astype("float32")
    x = mge.Tensor(x)
    with Grad() as grad:
        grad.wrt(x, callback=save_to(x))

        def f(x):
            p = F.pixel_shuffle(x, 2)
            return p * p

        y = f(x)
        grad(y, F.ones_like(y))

    np.testing.assert_equal(2 * x.numpy(), x.grad.numpy())
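
For reference on the expected value: with `p = pixel_shuffle(x)` and `y = p * p`, the cotangent reaching `pixel_shuffle` is `dy/dp = 2 * p`; since `pixel_shuffle` only permutes elements, `PixelShuffleBackward` moves `2 * pixel_shuffle(x)` back into the input layout, giving exactly `2 * x`, which is what the assertion checks.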
@@ -255,6 +255,7 @@ def test_conv_bias_int4():
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu")


@pytest.mark.require_ngpu(1)
@pytest.mark.skipif(
    get_cuda_compute_capability(0) < 61,
    reason="does not support int8 when gpu compute capability less than 6.1",

@@ -290,6 +290,7 @@ def test_deformable_ps_roi_pooling():
    check_pygraph_dump(fwd, [inp, rois, trans], [result])


@pytest.mark.require_ngpu(1)
@pytest.mark.skipif(
    get_cuda_compute_capability(0) < 61,
    reason="does not support int8 when gpu compute capability less than 6.1",

@@ -0,0 +1,157 @@
#include "../op_trait.h"

#include "megbrain/imperative/ops/autogen.h"

using namespace megdnn;

namespace mgb::imperative {

namespace pixel_shuffle {

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op = def.cast_final_safe<PixelShuffle>();
    auto&& src = inputs[0];
    auto&& layout = src->layout();
    mgb_assert(
            layout.ndim >= 3,
            "the input dimension of pixel_shuffle should be larger than or equal to 3");
    size_t idx = layout.ndim - 3;
    mgb_assert(
            layout[idx] % (op.factor * op.factor) == 0,
            "the -3 dimension should be divided by (upscale_factor ** 2)");
    TensorLayout tlayout;
    TensorShape tshp;  // {N, C, r, r, H, W}
    TensorShape vshp;  // {..., C, Hr, Wr}
    tshp.ndim = 6;
    vshp.ndim = layout.ndim;
    tshp[0] = 1;
    for (size_t i = 0; i < idx; ++i) {
        tshp[0] *= layout[i];
        vshp[i] = layout[i];
    }
    tshp[1] = layout[idx] / (op.factor * op.factor);
    tshp[2] = tshp[3] = op.factor;
    tshp[4] = layout[idx + 1];
    tshp[5] = layout[idx + 2];
    vshp[idx] = tshp[1];
    vshp[idx + 1] = layout[idx + 1] * op.factor;
    vshp[idx + 2] = layout[idx + 2] * op.factor;
    tlayout = layout.reshape(tshp).dimshuffle({0, 1, 4, 2, 5, 3});
    TensorPtr out = Tensor::make(src->blob(), src->offset(), tlayout);
    out->to_contiguous_inplace();  // relayout
    tlayout = out->layout().reshape(vshp);
    return {Tensor::make(out->blob(), out->offset(), tlayout)};
}
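
A NumPy mirror of the shape bookkeeping above (assumed equivalent; the concrete numbers are illustrative): all leading batch dims are folded into `N`, the channel dim is split into `(C, r, r)`, and after the `{0, 1, 4, 2, 5, 3}` dimshuffle the result is viewed with the original batch dims restored.

```python
import numpy as np

r = 2
x = np.random.rand(2, 3, 16, 3, 4).astype(np.float32)       # {..., C*r^2, H, W}
t = x.reshape(2 * 3, 16 // (r * r), r, r, 3, 4)              # tshp = {N, C, r, r, H, W}
t = t.transpose(0, 1, 4, 2, 5, 3)                            # dimshuffle {0, 1, 4, 2, 5, 3}
y = np.ascontiguousarray(t).reshape(2, 3, 4, 3 * r, 4 * r)   # vshp = {..., C, H*r, W*r}
print(y.shape)  # (2, 3, 4, 6, 8)
```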
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op = def.cast_final_safe<PixelShuffle>();
    mgb_assert(op.factor > 0, "upscale_factor should be larger than 0");
    auto&& src = inputs[0];
    if (src.layout.ndim == 0) {
        return {{{TensorLayout(src.layout.dtype), src.comp_node}}, false};
    }
    mgb_assert(
            src.layout.ndim >= 3,
            "the input dimension of pixel_shuffle should be larger than or equal to 3");
    size_t idx = src.layout.ndim - 3;
    mgb_assert(
            src.layout[idx] % (op.factor * op.factor) == 0,
            "the -3 dimension should be divided by (upscale_factor ** 2)");
    TensorShape tshp;
    tshp.ndim = src.layout.ndim;
    for (size_t i = 0; i < idx; ++i) {
        tshp[i] = src.layout[i];
    }
    tshp[idx] = src.layout[idx] / (op.factor * op.factor);
    tshp[idx + 1] = src.layout[idx + 1] * op.factor;
    tshp[idx + 2] = src.layout[idx + 2] * op.factor;
    return {{{TensorLayout(tshp, src.layout.dtype), src.comp_node}}, true};
}

SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
    layout_checker[0] = [](const TensorLayout& layout) {
        return layout.is_contiguous();
    };
    return layout_checker;
}

OP_TRAIT_REG(PixelShuffle, PixelShuffle)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .get_input_layout_constraint(get_input_layout_constraint)
        .fallback();

}  // namespace pixel_shuffle

namespace pixel_shuffle_backward {

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op = def.cast_final_safe<PixelShuffleBackward>();
    auto&& src = inputs[0];
    auto&& layout = src->layout();
    size_t idx = layout.ndim - 3;
    TensorLayout tlayout;
    TensorShape tshp;  // {N, C, H, r, W, r}
    TensorShape vshp;  // {..., Cr^2, H, W}
    tshp.ndim = 6;
    vshp.ndim = layout.ndim;
    tshp[0] = 1;
    for (size_t i = 0; i < idx; ++i) {
        tshp[0] *= layout[i];
        vshp[i] = layout[i];
    }
    tshp[1] = layout[idx];
    tshp[3] = tshp[5] = op.factor;
    tshp[2] = layout[idx + 1] / op.factor;
    tshp[4] = layout[idx + 2] / op.factor;
    vshp[idx] = tshp[1] * op.factor * op.factor;
    vshp[idx + 1] = tshp[2];
    vshp[idx + 2] = tshp[4];
    tlayout = layout.reshape(tshp).dimshuffle({0, 1, 3, 5, 2, 4});
    TensorPtr out = Tensor::make(src->blob(), src->offset(), tlayout);
    out->to_contiguous_inplace();  // relayout
    tlayout = out->layout().reshape(vshp);
    return {Tensor::make(out->blob(), out->offset(), tlayout)};
}
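
Note that the backward kernel's `{0, 1, 3, 5, 2, 4}` dimshuffle is exactly the inverse permutation of the forward kernel's `{0, 1, 4, 2, 5, 3}`, so applying the backward rearrangement to the forward output reproduces the input layout. A quick NumPy check of that relationship:

```python
import numpy as np

forward_perm = np.array([0, 1, 4, 2, 5, 3])
backward_perm = np.argsort(forward_perm)  # inverse permutation
print(backward_perm)                      # [0 1 3 5 2 4]
print(forward_perm[backward_perm])        # [0 1 2 3 4 5] -> identity
```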
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op = def.cast_final_safe<PixelShuffleBackward>();
    auto&& src = inputs[0];
    if (src.layout.ndim == 0) {
        return {{{TensorLayout(src.layout.dtype), src.comp_node}}, false};
    }
    size_t idx = src.layout.ndim - 3;
    TensorShape tshp;
    tshp.ndim = src.layout.ndim;
    for (size_t i = 0; i < idx; ++i) {
        tshp[i] = src.layout[i];
    }
    tshp[idx] = src.layout[idx] * op.factor * op.factor;
    tshp[idx + 1] = src.layout[idx + 1] / op.factor;
    tshp[idx + 2] = src.layout[idx + 2] / op.factor;
    return {{{TensorLayout(tshp, src.layout.dtype), src.comp_node}}, true};
}

SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
    layout_checker[0] = [](const TensorLayout& layout) {
        return layout.is_contiguous();
    };
    return layout_checker;
}

OP_TRAIT_REG(PixelShuffleBackward, PixelShuffleBackward)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .get_input_layout_constraint(get_input_layout_constraint)
        .fallback();

}  // namespace pixel_shuffle_backward

}  // namespace mgb::imperative

@@ -435,6 +435,18 @@ def CheckNonFinite: MgbHashableOp<"CheckNonFinite", [CheckNonFiniteParam]>;
def FastpathCopy: MgbHashableOp<"FastpathCopy">;

def PixelShuffle: MgbHashableOp<"PixelShuffle"> {
  let extraArguments = (ins
    MgbI32Attr:$factor
  );
}

def PixelShuffleBackward: MgbHashableOp<"PixelShuffleBackward"> {
  let extraArguments = (ins
    MgbI32Attr:$factor
  );
}

def ExternOpr: MgbHashableOp<"ExternOpr"> {
  let extraArguments = (ins
    MgbArrayAttr<MgbArrayAttr<MgbSizeTAddr>>:$output_shapes,