GitOrigin-RevId: 29e069fb23
tags/v1.3.0
| @@ -8,6 +8,7 @@ | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # pylint: disable=redefined-builtin | # pylint: disable=redefined-builtin | ||||
| from .elemwise import * | from .elemwise import * | ||||
| from .img_proc import * | |||||
| from .math import * | from .math import * | ||||
| from .nn import * | from .nn import * | ||||
| from .tensor import * | from .tensor import * | ||||
| @@ -0,0 +1,50 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from ..core._imperative_rt.core2 import apply | |||||
| from ..core.ops import builtin | |||||
| from ..tensor import Tensor | |||||
| __all__ = [ | |||||
| "cvt_color", | |||||
| ] | |||||
| def cvt_color(inp: Tensor, mode: str = ""): | |||||
| r""" | |||||
| Convert images from one format to another | |||||
| :param inp: input images. | |||||
| :param mode: format mode. | |||||
| :return: convert result. | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| import megengine as mge | |||||
| import megengine.functional as F | |||||
| x = mge.tensor(np.array([[[[-0.58675045, 1.7526233, 0.10702174]]]]).astype(np.float32)) | |||||
| y = F.img_proc.cvt_color(x, mode="RGB2GRAY") | |||||
| print(y.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [[[[0.86555195]]]] | |||||
| """ | |||||
| assert mode in builtin.CvtColor.Mode.__dict__, "unspport mode for cvt_color" | |||||
| mode = getattr(builtin.CvtColor.Mode, mode) | |||||
| assert isinstance(mode, builtin.CvtColor.Mode) | |||||
| op = builtin.CvtColor(mode=mode) | |||||
| (out,) = apply(op, inp) | |||||
| return out | |||||
| @@ -704,3 +704,14 @@ def test_argmxx_on_inf(): | |||||
| assert all(run_argmax() >= 0) | assert all(run_argmax() >= 0) | ||||
| assert all(run_argmin() >= 0) | assert all(run_argmin() >= 0) | ||||
| def test_cvt_color(): | |||||
| def rgb2gray(rgb): | |||||
| return np.dot(rgb[..., :3], [0.299, 0.587, 0.114]) | |||||
| inp = np.random.randn(3, 3, 3, 3).astype(np.float32) | |||||
| out = np.expand_dims(rgb2gray(inp), 3).astype(np.float32) | |||||
| x = tensor(inp) | |||||
| y = F.img_proc.cvt_color(x, mode="RGB2GRAY") | |||||
| np.testing.assert_allclose(y.numpy(), out, atol=1e-5) | |||||
| @@ -0,0 +1,33 @@ | |||||
| /** | |||||
| * \file imperative/src/impl/ops/img_proc.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "megbrain/imperative/ops/autogen.h" | |||||
| #include "megbrain/opr/imgproc.h" | |||||
| #include "../op_trait.h" | |||||
| namespace mgb { | |||||
| namespace imperative { | |||||
| namespace { | |||||
| auto apply_on_var_node( | |||||
| const OpDef& def, | |||||
| const VarNodeArray& inputs) { | |||||
| auto&& op = static_cast<const CvtColor&>(def); | |||||
| mgb_assert(inputs.size() == 1); | |||||
| return opr::CvtColor::make(inputs[0], op.param()); | |||||
| } | |||||
| OP_TRAIT_REG(CvtColor, CvtColor) | |||||
| .apply_on_var_node(apply_on_var_node) | |||||
| .fallback(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -254,4 +254,6 @@ def TensorRTRuntime: MgbHashableOp<"TensorRTRuntime"> { | |||||
| ); | ); | ||||
| } | } | ||||
| def CvtColor: MgbHashableOp<"CvtColor", [CvtColorParam]>; | |||||
| #endif // MGB_OPS | #endif // MGB_OPS | ||||