GitOrigin-RevId: d5fbd59a30
@@ -197,7 +197,11 @@ public:
 protected:
     //! get origin coord
-    std::pair<float, int> get_origin_coord(float scale, int size, int idx, bool cubic=false);
+    std::pair<float, int> get_cubic_coord(float scale, int idx);
+    std::tuple<float, int, float, int> get_nearest_linear_coord(
+            InterpolationMode imode, float scale, int size, int idx);
     //! get nearest index in src
     int get_nearest_src(float scale, int size, int idx);
@@ -6,12 +6,14 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #include "src/arm_common/resize/opr_impl.h"
 #include "src/arm_common/handle.h"
 #include "src/arm_common/resize/resize_cv.h"
+#include "src/arm_common/simd_macro/marm_neon.h"
 
 using namespace megdnn;
 using namespace arm_common;
@@ -19,9 +21,58 @@ using namespace arm_common;
 void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
                       _megdnn_workspace workspace) {
     check_exec(src.layout, dst.layout, workspace.size);
-    if (param().format == param::Resize::Format::NCHW ||
-        (src.layout[3] != 1 && src.layout[3] != 3) ||
-        !is_nhwc_contig_wc(src.layout)) {
+    if (param().format == param::Resize::Format::NCHW44 ||
+        param().format == param::Resize::Format::NCHW88) {
+        bool is_contiguous =
+                src.layout.is_contiguous() && dst.layout.is_contiguous();
+        bool dtype_same = src.layout.dtype == dst.layout.dtype;
+        bool nchw44_enable = param().format == param::Resize::Format::NCHW44 &&
+                             src.layout.dtype == dtype::Float32();
+        bool nchw88_enable =
+                param().format == param::Resize::Format::NCHW88 &&
+                DNN_FLOAT16_SELECT(src.layout.dtype == dtype::Float16(), false);
+        bool interp_supported =
+                param().imode ==
+                        param::Resize::InterpolationMode::INTER_NEAREST ||
+                param().imode == param::Resize::InterpolationMode::INTER_LINEAR;
+        bool is_upsample2 =
+                param().imode ==
+                        param::Resize::InterpolationMode::INTER_NEAREST &&
+                src.layout.shape[2] * 2 == dst.layout.shape[2] &&
+                src.layout.shape[3] * 2 == dst.layout.shape[3];
+        bool need_fallback = !is_contiguous || !dtype_same ||
+                             !interp_supported ||
+                             (!nchw44_enable && !nchw88_enable);
+        if (need_fallback) {
+            fallback::ResizeImpl::exec(src, dst, workspace);
+        } else if (nchw44_enable) {
+            auto kern_param = KernParam<float>::from_tensors(
+                    param().format, param().imode, src, dst, workspace);
+            if (is_upsample2) {
+                MEGDNN_DISPATCH_CPU_KERN_OPR(
+                        kern_nearest_upsample2_pack_simd_width(src, dst));
+            } else {
+                MEGDNN_DISPATCH_CPU_KERN_OPR(kern_nchw44_fp32(kern_param));
+            }
+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+        } else if (nchw88_enable) {
+            auto kern_param = KernParam<dt_float16>::from_tensors(
+                    param().format, param().imode, src, dst, workspace);
+            if (is_upsample2) {
+                MEGDNN_DISPATCH_CPU_KERN_OPR(
+                        kern_nearest_upsample2_pack_simd_width(src, dst));
+            } else {
+                MEGDNN_DISPATCH_CPU_KERN_OPR(kern_nchw88_fp16(kern_param));
+            }
+#endif
+        } else {
+            fallback::ResizeImpl::exec(src, dst, workspace);
+        }
+    } else if (param().format == param::Resize::Format::NCHW ||
+               (src.layout[3] != 1 && src.layout[3] != 3) ||
+               !is_nhwc_contig_wc(src.layout)) {
         fallback::ResizeImpl::exec(src, dst, workspace);
     } else {
         megdnn_assert(param().format == param::Resize::Format::NHWC,
@@ -30,4 +81,143 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
     }
 }
 
+template <typename ctype>
+void ResizeImpl::kern_nchw44_fp32(const KernParam<ctype>& kern_param) {
+    UNPACK_RESIZE_FWD_KERN_PARAM(kern_param);
+    float scale_h = static_cast<float>(OH) / IH;
+    float scale_w = static_cast<float>(OW) / IW;
+    for (size_t n = 0; n < N; ++n) {
+        for (size_t c = 0; c < C / 4; ++c) {
+            for (size_t oh = 0; oh < OH; ++oh) {
+                for (size_t ow = 0; ow < OW; ++ow) {
+                    int ih0, ih1, iw0, iw1;
+                    float ah0, ah1, aw0, aw1;
+
+                    std::tie(ah0, ih0, ah1, ih1) = get_nearest_linear_coord(
+                            kern_param.imode, scale_h, IH, oh);
+                    std::tie(aw0, iw0, aw1, iw1) = get_nearest_linear_coord(
+                            kern_param.imode, scale_w, IW, ow);
+
+#define SRC_ADDRESS(ih, iw) \
+    (sptr + n * C * IH * IW + (c * IH * IW + ih * IW + iw) * 4)
+#define DST_ADDRESS(oh, ow) \
+    (dptr + n * C * OH * OW + (c * OH * OW + oh * OW + ow) * 4)
+                    float32x4_t r0 = vld1q_f32(SRC_ADDRESS(ih0, iw0));
+                    float32_t a0 = ah0 * aw0;
+                    float32x4_t r1 = vld1q_f32(SRC_ADDRESS(ih0, iw1));
+                    float32_t a1 = ah0 * aw1;
+                    float32x4_t r2 = vld1q_f32(SRC_ADDRESS(ih1, iw0));
+                    float32_t a2 = ah1 * aw0;
+                    float32x4_t r3 = vld1q_f32(SRC_ADDRESS(ih1, iw1));
+                    float32_t a3 = ah1 * aw1;
+
+                    r0 = vmulq_n_f32(r0, a0);
+#if defined(__ARM_FEATURE_FMA) && defined(__aarch64__)
+                    r0 = vfmaq_n_f32(r0, r1, a1);
+                    r0 = vfmaq_n_f32(r0, r2, a2);
+                    r0 = vfmaq_n_f32(r0, r3, a3);
+#else
+                    r0 = vaddq_f32(r0, vmulq_n_f32(r1, a1));
+                    r0 = vaddq_f32(r0, vmulq_n_f32(r2, a2));
+                    r0 = vaddq_f32(r0, vmulq_n_f32(r3, a3));
+#endif
+                    vst1q_f32(DST_ADDRESS(oh, ow), r0);
+#undef SRC_ADDRESS
+#undef DST_ADDRESS
+                }
+            }
+        }
+    }
+}
+
+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template <typename ctype>
+void ResizeImpl::kern_nchw88_fp16(const KernParam<ctype>& kern_param) {
+    UNPACK_RESIZE_FWD_KERN_PARAM(kern_param);
+    float scale_h = static_cast<float>(OH) / IH;
+    float scale_w = static_cast<float>(OW) / IW;
+    const float16_t* src_ptr = reinterpret_cast<float16_t*>(sptr);
+    float16_t* dst_ptr = reinterpret_cast<float16_t*>(dptr);
+    for (size_t n = 0; n < N; ++n) {
+        for (size_t c = 0; c < C / 8; ++c) {
+            for (size_t oh = 0; oh < OH; ++oh) {
+                for (size_t ow = 0; ow < OW; ++ow) {
+                    int ih0, ih1, iw0, iw1;
+                    float ah0, ah1, aw0, aw1;
+
+                    std::tie(ah0, ih0, ah1, ih1) = get_nearest_linear_coord(
+                            kern_param.imode, scale_h, IH, oh);
+                    std::tie(aw0, iw0, aw1, iw1) = get_nearest_linear_coord(
+                            kern_param.imode, scale_w, IW, ow);
+
+#define SRC_ADDRESS(ih, iw) \
+    (src_ptr + n * C * IH * IW + (c * IH * IW + ih * IW + iw) * 8)
+#define DST_ADDRESS(oh, ow) \
+    (dst_ptr + n * C * OH * OW + (c * OH * OW + oh * OW + ow) * 8)
+                    float16x8_t r0 = vld1q_f16(SRC_ADDRESS(ih0, iw0));
+                    float32_t a0 = ah0 * aw0;
+                    float16x8_t r1 = vld1q_f16(SRC_ADDRESS(ih0, iw1));
+                    float32_t a1 = ah0 * aw1;
+                    float16x8_t r2 = vld1q_f16(SRC_ADDRESS(ih1, iw0));
+                    float32_t a2 = ah1 * aw0;
+                    float16x8_t r3 = vld1q_f16(SRC_ADDRESS(ih1, iw1));
+                    float32_t a3 = ah1 * aw1;
+
+                    r0 = vmulq_n_f16(r0, a0);
+#if defined(__ARM_FEATURE_FMA) && defined(__aarch64__)
+                    r0 = vfmaq_n_f16(r0, r1, a1);
+                    r0 = vfmaq_n_f16(r0, r2, a2);
+                    r0 = vfmaq_n_f16(r0, r3, a3);
+#else
+                    r0 = vaddq_f16(r0, vmulq_n_f16(r1, a1));
+                    r0 = vaddq_f16(r0, vmulq_n_f16(r2, a2));
+                    r0 = vaddq_f16(r0, vmulq_n_f16(r3, a3));
+#endif
+                    vst1q_f16(DST_ADDRESS(oh, ow), r0);
+#undef SRC_ADDRESS
+#undef DST_ADDRESS
+                }
+            }
+        }
+    }
+}
+#endif
+
+void ResizeImpl::kern_nearest_upsample2_pack_simd_width(
+        _megdnn_tensor_in src, _megdnn_tensor_out dst) {
+    const uint8_t* src_ptr = reinterpret_cast<uint8_t*>(src.raw_ptr);
+    uint8_t* dst_ptr = reinterpret_cast<uint8_t*>(dst.raw_ptr);
+
+    size_t S = 2;
+    size_t N = src.layout.shape[0];
+    size_t IC = src.layout.shape[1];
+    size_t IH = src.layout.shape[2];
+    size_t IW = src.layout.shape[3];
+    size_t OH = dst.layout.shape[2];
+    size_t OW = dst.layout.shape[3];
+
+    for (size_t i = 0; i < N * IC; ++i) {
+        for (size_t ih = 0; ih < IH; ++ih) {
+            for (size_t iw = 0; iw < IW; ++iw) {
+                size_t oh = ih * S;
+                size_t ow = iw * S;
+                uint8x16_t r0 = vld1q_u8(src_ptr + i * IH * IW * 16 +
+                                         ih * IW * 16 + iw * 16);
+                for (size_t fh = 0; fh < S; ++fh) {
+                    for (size_t fw = 0; fw < S; ++fw) {
+                        vst1q_u8(dst_ptr + i * OH * OW * 16 +
+                                         (oh + fh) * OW * 16 + (ow + fw) * 16,
+                                 r0);
+                    }
+                }
+            }
+        }
+    }
+}
+
 // vim: syntax=cpp.doxygen
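
Reviewer note (illustrative, not part of the patch): kern_nearest_upsample2_pack_simd_width serves both NCHW44 fp32 and NCHW88 fp16 because one packed pixel is 16 bytes in either format (4 x f32 or 8 x f16), so each source pixel can be copied as a single q-register into a 2x2 output block. A portable scalar sketch of the same idea, with hypothetical names:

#include <cstddef>
#include <cstring>

// 2x nearest-neighbour upsample of a packed (N*C/P, H, W, P) tensor where one
// packed pixel occupies pack_bytes == 16 bytes, mirroring the NEON kernel above.
void upsample2_packed_ref(const unsigned char* src, unsigned char* dst,
                          size_t n_groups, size_t ih, size_t iw) {
    const size_t pack_bytes = 16;
    for (size_t g = 0; g < n_groups; ++g) {
        for (size_t h = 0; h < ih; ++h) {
            for (size_t w = 0; w < iw; ++w) {
                const unsigned char* s =
                        src + ((g * ih + h) * iw + w) * pack_bytes;
                // each source pixel feeds a 2x2 block of the output
                for (size_t fh = 0; fh < 2; ++fh) {
                    for (size_t fw = 0; fw < 2; ++fw) {
                        unsigned char* d =
                                dst + ((g * ih * 2 + h * 2 + fh) * iw * 2 +
                                       w * 2 + fw) * pack_bytes;
                        std::memcpy(d, s, pack_bytes);
                    }
                }
            }
        }
    }
}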
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #pragma once
 #include "megdnn/oprs.h"
@@ -25,6 +26,16 @@ public:
                                   const TensorLayout&) override {
         return 0;
     }
 
+private:
+    template <typename ctype>
+    void kern_nchw44_fp32(const KernParam<ctype>& kern_param);
+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+    template <typename ctype>
+    void kern_nchw88_fp16(const KernParam<ctype>& kern_param);
+#endif
+    void kern_nearest_upsample2_pack_simd_width(_megdnn_tensor_in src,
+                                                _megdnn_tensor_out dst);
+
 };
 
 } // namespace arm_common
@@ -40,11 +40,29 @@ void ResizeBase::check_layout_fwd(const TensorLayout& src,
         megdnn_assert(src.dtype.enumv() == DTypeEnum::QuantizedS8);
         megdnn_assert(src.shape[4] == 4);
         megdnn_assert(dst.shape[4] == 4);
+    } else if (param().format == Param::Format::NCHW44) {
+        megdnn_assert(src.ndim == 5);
+        megdnn_assert(src.shape[4] == 4);
+        megdnn_assert(dst.shape[4] == 4);
+        megdnn_assert(param().imode ==
+                              param::Resize::InterpolationMode::INTER_LINEAR ||
+                      param().imode ==
+                              param::Resize::InterpolationMode::INTER_NEAREST);
+    } else if (param().format == Param::Format::NCHW88) {
+        megdnn_assert(src.ndim == 5);
+        megdnn_assert(src.shape[4] == 8);
+        megdnn_assert(dst.shape[4] == 8);
+        megdnn_assert(param().imode ==
+                              param::Resize::InterpolationMode::INTER_LINEAR ||
+                      param().imode ==
+                              param::Resize::InterpolationMode::INTER_NEAREST);
     } else {
         megdnn_assert(param().format == Param::Format::NHWCD4,
                       "invalid resize tensor format");
         megdnn_assert(param().imode ==
-                      param::Resize::InterpolationMode::INTER_LINEAR);
+                              param::Resize::InterpolationMode::INTER_LINEAR ||
+                      param().imode ==
+                              param::Resize::InterpolationMode::INTER_NEAREST);
         megdnn_assert(dst.shape[2] == src.shape[2], "%s", errmsg().c_str());
     }
 }
@@ -67,24 +85,39 @@ void ResizeBackward::check_exec(const TensorLayout& diff,
                   "Backward resize only supports Float32 and NCHW.");
 }
 
-std::pair<float, int> ResizeBase::get_origin_coord(float scale, int size,
-                                                   int idx, bool cubic) {
+//! copy from resize_cv.cpp
+std::pair<float, int> ResizeBase::get_cubic_coord(float scale, int idx) {
     float alpha = (idx + 0.5f) / scale - 0.5f;
     int origin_idx = static_cast<int>(floor(alpha));
     alpha -= origin_idx;
-    if (!cubic) {
-        if (origin_idx < 0) {
-            origin_idx = 0;
-            alpha = 0;
-        } else if (origin_idx + 1 >= size) {
-            origin_idx = size - 2;
-            alpha = 1;
-        }
-    }
     return {alpha, origin_idx};
 }
 
+std::tuple<float, int, float, int> ResizeBase::get_nearest_linear_coord(
+        InterpolationMode imode, float scale, int size, int idx) {
+    if (size == 1) {
+        return std::make_tuple(1.0f, 0, 0.0f, 0);
+    }
+
+    float alpha = (idx + 0.5f) / scale - 0.5f;
+    int origin_idx = static_cast<int>(floor(alpha));
+    alpha -= origin_idx;
+
+    if (imode == InterpolationMode::INTER_NEAREST) {
+        origin_idx = get_nearest_src(scale, size, idx);
+        alpha = 0;
+    }
+
+    if (origin_idx < 0) {
+        origin_idx = 0;
+        alpha = 0;
+    } else if (origin_idx + 1 >= size) {
+        origin_idx = size - 2;
+        alpha = 1;
+    }
+
+    return std::make_tuple(1 - alpha, origin_idx, alpha, origin_idx + 1);
+}
+
 int ResizeBase::get_nearest_src(float scale, int size, int idx) {
     return std::min(static_cast<int>(idx / scale), size - 1);
 }
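
Reviewer note (illustrative sketch, not part of the patch): get_nearest_linear_coord returns two (weight, index) taps per axis. For INTER_LINEAR the weights are 1 - alpha and alpha under the half-pixel coordinate mapping; for INTER_NEAREST the nearest tap gets weight 1 and the other weight 0, so the same four-tap sum covers both modes. A standalone reproduction of the INTER_LINEAR path, with hypothetical names:

#include <cmath>
#include <cstdio>
#include <tuple>

// half-pixel source coordinate, clamped so that both taps stay in [0, size - 1]
std::tuple<float, int, float, int> linear_coord(float scale, int size, int idx) {
    if (size == 1)
        return std::make_tuple(1.0f, 0, 0.0f, 0);
    float alpha = (idx + 0.5f) / scale - 0.5f;
    int i0 = static_cast<int>(std::floor(alpha));
    alpha -= i0;
    if (i0 < 0) {
        i0 = 0;
        alpha = 0;
    } else if (i0 + 1 >= size) {
        i0 = size - 2;
        alpha = 1;
    }
    return std::make_tuple(1.0f - alpha, i0, alpha, i0 + 1);
}

int main() {
    // 1-D linear resize of 4 samples to 8 samples using the (weight, index) pairs
    const float src[4] = {0.f, 10.f, 20.f, 30.f};
    const int isize = 4, osize = 8;
    float scale = static_cast<float>(osize) / isize;
    for (int o = 0; o < osize; ++o) {
        float w0, w1;
        int i0, i1;
        std::tie(w0, i0, w1, i1) = linear_coord(scale, isize, o);
        std::printf("dst[%d] = %g\n", o, w0 * src[i0] + w1 * src[i1]);
    }
    return 0;
}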
@@ -6,13 +6,14 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #include "src/fallback/resize/opr_impl.h"
 #include <vector>
-#include "src/fallback/handle.h"
 #include "src/common/rounding_converter.cuh"
+#include "src/fallback/handle.h"
 
 using namespace megdnn;
 using namespace fallback;
@@ -30,37 +31,36 @@ void ResizeImpl::kern_fallback(const KernParam<ctype>& kern_param) {
     float scale_h = static_cast<float>(OH) / IH;
     float scale_w = static_cast<float>(OW) / IW;
 
-    auto build_table = [this](float scale, int isize,
-                              int osize) -> std::vector<std::pair<float, int>> {
-        std::vector<std::pair<float, int>> table;
-        rep(i, osize) { table.push_back(get_origin_coord(scale, isize, i)); }
+    auto build_table = [this](InterpolationMode imode, float scale, int isize,
+                              int osize) {
+        std::vector<std::tuple<float, int, float, int>> table;
+        rep(i, osize) {
+            table.push_back(get_nearest_linear_coord(imode, scale, isize, i));
+        }
         return table;
     };
-    auto table_h = build_table(scale_h, IH, OH);
-    auto table_w = build_table(scale_w, IW, OW);
+    auto table_h = build_table(kern_param.imode, scale_h, IH, OH);
+    auto table_w = build_table(kern_param.imode, scale_w, IW, OW);
 
     rep(n, N) {
         rep(c, static_cast<int>(C)) {
             rep(oh, OH) {
-                auto coord_h = table_h[oh];
-                float alphah = coord_h.first;
-                int ih0 = coord_h.second;
-                int ih1 = ih0 + 1;
+                float ah0, ah1, aw0, aw1;
+                int ih0, ih1, iw0, iw1;
+                std::tie(ah0, ih0, ah1, ih1) = table_h[oh];
                 rep(ow, OW) {
-                    auto coord_w = table_w[ow];
-                    float alphaw = coord_w.first;
-                    int iw0 = coord_w.second;
-                    int iw1 = iw0 + 1;
+                    std::tie(aw0, iw0, aw1, iw1) = table_w[ow];
                     dptr[c * OH * OW + oh * OW + ow] = output_converter(
-                            sptr[c * S_IC + ih0 * S_IH + iw0 * S_IW] *
-                                    (1.0f - alphaw) * (1.0f - alphah) +
-                            sptr[c * S_IC + ih0 * S_IH + iw1 * S_IW] *
-                                    alphaw * (1.0f - alphah) +
-                            sptr[c * S_IC + ih1 * S_IH + iw0 * S_IW] *
-                                    (1.0f - alphaw) * alphah +
-                            sptr[c * S_IC + ih1 * S_IH + iw1 * S_IW] *
-                                    alphaw * alphah);
+                            sptr[c * S_IC + ih0 * S_IH + iw0 * S_IW] * ah0 *
+                                    aw0 +
+                            sptr[c * S_IC + ih0 * S_IH + iw1 * S_IW] * ah0 *
+                                    aw1 +
+                            sptr[c * S_IC + ih1 * S_IH + iw0 * S_IW] * ah1 *
+                                    aw0 +
+                            sptr[c * S_IC + ih1 * S_IH + iw1 * S_IW] * ah1 *
+                                    aw1);
                 }
             }
         }
@@ -76,35 +76,31 @@ void ResizeImpl::kern_fallback_nhwc(const KernParam<ctype>& kern_param) {
     float scale_h = static_cast<float>(OH) / IH;
     float scale_w = static_cast<float>(OW) / IW;
 
-    auto build_table = [this](float scale, int isize,
-                              int osize) -> std::vector<std::pair<float, int>> {
-        std::vector<std::pair<float, int>> table;
-        rep(i, osize) { table.push_back(get_origin_coord(scale, isize, i)); }
+    auto build_table = [this](InterpolationMode imode, float scale, int isize,
+                              int osize) {
+        std::vector<std::tuple<float, int, float, int>> table;
+        rep(i, osize) {
+            table.push_back(get_nearest_linear_coord(imode, scale, isize, i));
+        }
         return table;
     };
-    auto table_h = build_table(scale_h, IH, OH);
-    auto table_w = build_table(scale_w, IW, OW);
+    auto table_h = build_table(kern_param.imode, scale_h, IH, OH);
+    auto table_w = build_table(kern_param.imode, scale_w, IW, OW);
 
     rep(n, N) {
         rep(oh, OH) {
-            auto coord_h = table_h[oh];
-            float alphah = coord_h.first;
-            int ih0 = coord_h.second;
-            int ih1 = ih0 + 1;
+            float ah0, ah1, aw0, aw1;
+            int ih0, ih1, iw0, iw1;
+            std::tie(ah0, ih0, ah1, ih1) = table_h[oh];
             rep(ow, OW) {
-                auto coord_w = table_w[ow];
-                float alphaw = coord_w.first;
-                int iw0 = coord_w.second;
-                int iw1 = iw0 + 1;
+                std::tie(aw0, iw0, aw1, iw1) = table_w[ow];
                 rep(c, C) {
                     dptr[(oh * OW + ow) * C + c] = output_converter(
-                            sptr[(ih0 * IW + iw0) * C + c] * (1.0f - alphaw) *
-                                    (1.0f - alphah) +
-                            sptr[(ih0 * IW + iw1) * C + c] * alphaw *
-                                    (1.0f - alphah) +
-                            sptr[(ih1 * IW + iw0) * C + c] * (1.0f - alphaw) *
-                                    alphah +
-                            sptr[(ih1 * IW + iw1) * C + c] * alphaw * alphah);
+                            sptr[(ih0 * IW + iw0) * C + c] * ah0 * aw0 +
+                            sptr[(ih0 * IW + iw1) * C + c] * ah0 * aw1 +
+                            sptr[(ih1 * IW + iw0) * C + c] * ah1 * aw0 +
+                            sptr[(ih1 * IW + iw1) * C + c] * ah1 * aw1);
                 }
             }
         }
@@ -117,6 +113,8 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
                       _megdnn_workspace workspace) {
     check_exec(src.layout, dst.layout, workspace.size);
     if (param().format == param::Resize::Format::NCHW4 ||
+        param().format == param::Resize::Format::NCHW44 ||
+        param().format == param::Resize::Format::NCHW88 ||
         (param().format == param::Resize::Format::NCHW &&
          param().imode != param::Resize::InterpolationMode::INTER_LINEAR)) {
         naive::ResizeImpl::exec(src, dst, workspace);
@@ -125,12 +123,12 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
     if ((param().format == param::Resize::Format::NCHW ||
          (src.layout[3] != 1 && src.layout[3] != 3)) ||
         (param().imode == param::Resize::InterpolationMode::LINEAR)) {
-#define cb(dt, ct)                                                            \
-    case DTypeTrait<dt>::enumv: {                                             \
-        auto kparam = KernParam<ct>::from_tensors(param().format, src, dst,   \
-                                                  workspace);                 \
-        MEGDNN_DISPATCH_CPU_KERN_OPR(kern_fallback(kparam));                  \
-        return;                                                               \
+#define cb(dt, ct)                                                           \
+    case DTypeTrait<dt>::enumv: {                                            \
+        auto kparam = KernParam<ct>::from_tensors(                           \
+                param().format, param().imode, src, dst, workspace);         \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(kern_fallback(kparam));                 \
+        return;                                                              \
     }
 
         switch (src.layout.dtype.enumv()) {
@@ -141,10 +139,9 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
             cb(dtype::Uint8, uint8_t);
             cb(dtype::Quantized8Asymm, uint8_t);
             default:
-                megdnn_throw(
-                        ssprintf("Unsupported input DType in Resize: %s",
-                                 src.layout.dtype.name())
-                                .c_str());
+                megdnn_throw(ssprintf("Unsupported input DType in Resize: %s",
+                                      src.layout.dtype.name())
+                                     .c_str());
                 return;
         }
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 
 #include "src/naive/resize/opr_impl.h"
@@ -27,10 +28,11 @@ using namespace resize;
 
 template <typename ctype>
 ResizeImpl::KernParam<ctype> ResizeImpl::KernParam<ctype>::from_tensors(
-        Format format, _megdnn_tensor_in src, _megdnn_tensor_out dst,
-        _megdnn_workspace workspace) {
+        Format format, InterpolationMode imode, _megdnn_tensor_in src,
+        _megdnn_tensor_out dst, _megdnn_workspace workspace) {
     KernParam<ctype> ret;
     ret.format = format;
+    ret.imode = imode;
     ret.n = src.layout.shape[0];
     if (format == Format::NCHW) {
         ret.c = src.layout.shape[1];
@@ -54,6 +56,18 @@ ResizeImpl::KernParam<ctype> ResizeImpl::KernParam<ctype>::from_tensors(
         ret.iw = src.layout.shape[3];
         ret.oh = dst.layout.shape[2];
         ret.ow = dst.layout.shape[3];
+    } else if (format == Format::NCHW44) {
+        ret.c = src.layout.shape[1] * 4;
+        ret.ih = src.layout.shape[2];
+        ret.iw = src.layout.shape[3];
+        ret.oh = dst.layout.shape[2];
+        ret.ow = dst.layout.shape[3];
+    } else if (format == Format::NCHW88) {
+        ret.c = src.layout.shape[1] * 8;
+        ret.ih = src.layout.shape[2];
+        ret.iw = src.layout.shape[3];
+        ret.oh = dst.layout.shape[2];
+        ret.ow = dst.layout.shape[3];
     } else {
         megdnn_assert(format == Format::NHWCD4);
         ret.c = src.layout.shape[2] * 4;
@@ -115,33 +129,30 @@ void ResizeImpl::kern_nchw(const KernParam<ctype>& kern_param,
                     break;
                 }
                 case InterpolationMode::INTER_LINEAR: {
-                    auto coord_h = get_origin_coord(scale_h, IH, oh);
-                    auto coord_w = get_origin_coord(scale_w, IW, ow);
-
-                    float alphah = coord_h.first;
-                    float alphaw = coord_w.first;
+                    int ih0, ih1, iw0, iw1;
+                    float ah0, ah1, aw0, aw1;
 
-                    int ih0 = coord_h.second;
-                    int ih1 = ih0 + 1;
-                    int iw0 = coord_w.second;
-                    int iw1 = iw0 + 1;
+                    std::tie(ah0, ih0, ah1, ih1) = get_nearest_linear_coord(
+                            kern_param.imode, scale_h, IH, oh);
+                    std::tie(aw0, iw0, aw1, iw1) = get_nearest_linear_coord(
+                            kern_param.imode, scale_w, IW, ow);
 
                     rep(c, static_cast<int>(C)) {
                         dptr[c * OH * OW + oh * OW + ow] = output_converter(
-                                sptr[c * S_IC + ih0 * S_IH + iw0 * S_IW] *
-                                        (1.0f - alphaw) * (1.0f - alphah) +
-                                sptr[c * S_IC + ih0 * S_IH + iw1 * S_IW] *
-                                        alphaw * (1.0f - alphah) +
-                                sptr[c * S_IC + ih1 * S_IH + iw0 * S_IW] *
-                                        (1.0f - alphaw) * alphah +
-                                sptr[c * S_IC + ih1 * S_IH + iw1 * S_IW] *
-                                        alphaw * alphah);
+                                sptr[c * S_IC + ih0 * S_IH + iw0 * S_IW] * ah0 *
+                                        aw0 +
+                                sptr[c * S_IC + ih0 * S_IH + iw1 * S_IW] * ah0 *
+                                        aw1 +
+                                sptr[c * S_IC + ih1 * S_IH + iw0 * S_IW] * ah1 *
+                                        aw0 +
+                                sptr[c * S_IC + ih1 * S_IH + iw1 * S_IW] * ah1 *
+                                        aw1);
                     }
                     break;
                 }
                 case InterpolationMode::INTER_CUBIC: {
-                    auto coord_h = get_origin_coord(scale_h, IH, oh, true);
-                    auto coord_w = get_origin_coord(scale_w, IW, ow, true);
+                    auto coord_h = get_cubic_coord(scale_h, oh);
+                    auto coord_w = get_cubic_coord(scale_w, ow);
 
                     float alphah = coord_h.first;
                     float alphaw = coord_w.first;
@@ -193,7 +204,19 @@ void ResizeImpl::kern_naive(const KernParam<ctype>& kern_param) {
         return;
     } else if (kern_param.format == Format::NCHW4) {
        MIDOUT_BEGIN(megdnn_naive_resize_layout, midout_iv(2)) {
-            kern_naive_nchw4(kern_param);
+            kern_naive_nchwx<ctype, 4>(kern_param);
+        }
+        MIDOUT_END();
+        return;
+    } else if (kern_param.format == Format::NCHW44) {
+        MIDOUT_BEGIN(megdnn_naive_resize_layout, midout_iv(3)) {
+            kern_naive_nchwx<ctype, 4>(kern_param);
+        }
+        MIDOUT_END();
+        return;
+    } else if (kern_param.format == Format::NCHW88) {
+        MIDOUT_BEGIN(megdnn_naive_resize_layout, midout_iv(4)) {
+            kern_naive_nchwx<ctype, 8>(kern_param);
         }
         MIDOUT_END();
         return;
@@ -209,25 +232,20 @@ void ResizeImpl::kern_naive_nhwc(const KernParam<ctype>& kern_param) {
     rep(n, N) {
         rep(oh, OH) rep(ow, OW) {
-            auto coord_h = get_origin_coord(scale_h, IH, oh);
-            auto coord_w = get_origin_coord(scale_w, IW, ow);
+            int ih0, ih1, iw0, iw1;
+            float ah0, ah1, aw0, aw1;
 
-            float alphah = coord_h.first;
-            float alphaw = coord_w.first;
+            std::tie(ah0, ih0, ah1, ih1) =
+                    get_nearest_linear_coord(kern_param.imode, scale_h, IH, oh);
+            std::tie(aw0, iw0, aw1, iw1) =
+                    get_nearest_linear_coord(kern_param.imode, scale_w, IW, ow);
 
-            int ih0 = coord_h.second;
-            int ih1 = ih0 + 1;
-            int iw0 = coord_w.second;
-            int iw1 = iw0 + 1;
             rep(c, C) {
                 dptr[(oh * OW + ow) * C + c] = output_converter(
-                        sptr[(ih0 * IW + iw0) * C + c] * (1.0f - alphaw) *
-                                (1.0f - alphah) +
-                        sptr[(ih0 * IW + iw1) * C + c] * alphaw *
-                                (1.0f - alphah) +
-                        sptr[(ih1 * IW + iw0) * C + c] * (1.0f - alphaw) *
-                                alphah +
-                        sptr[(ih1 * IW + iw1) * C + c] * alphaw * alphah);
+                        sptr[(ih0 * IW + iw0) * C + c] * ah0 * aw0 +
+                        sptr[(ih0 * IW + iw1) * C + c] * ah0 * aw1 +
+                        sptr[(ih1 * IW + iw0) * C + c] * ah1 * aw0 +
+                        sptr[(ih1 * IW + iw1) * C + c] * ah1 * aw1);
             }
         }
         sptr += C * IH * IW;
@@ -251,26 +269,20 @@ void ResizeImpl::kern_naive_nhwcd4(const KernParam<ctype>& kern_param) {
     rep(n, N) {
         rep(oh, OH) rep(ow, OW) {
-            auto coord_h = get_origin_coord(scale_h, IH, oh);
-            auto coord_w = get_origin_coord(scale_w, IW, ow);
+            int ih0, ih1, iw0, iw1;
+            float ah0, ah1, aw0, aw1;
 
-            float alphah = coord_h.first;
-            float alphaw = coord_w.first;
+            std::tie(ah0, ih0, ah1, ih1) =
+                    get_nearest_linear_coord(kern_param.imode, scale_h, IH, oh);
+            std::tie(aw0, iw0, aw1, iw1) =
+                    get_nearest_linear_coord(kern_param.imode, scale_w, IW, ow);
 
-            int ih0 = coord_h.second;
-            int ih1 = ih0 + 1;
-            int iw0 = coord_w.second;
-            int iw1 = iw0 + 1;
             rep(c, C) {
                 dptr[get_tensor_addr(oh, ow, c, OW, C)] = output_converter(
-                        sptr[get_tensor_addr(ih0, iw0, c, IW, C)] *
-                                (1.0f - alphaw) * (1.0f - alphah) +
-                        sptr[get_tensor_addr(ih0, iw1, c, IW, C)] * alphaw *
-                                (1.0f - alphah) +
-                        sptr[get_tensor_addr(ih1, iw0, c, IW, C)] *
-                                (1.0f - alphaw) * alphah +
-                        sptr[get_tensor_addr(ih1, iw1, c, IW, C)] * alphaw *
-                                alphah);
+                        sptr[get_tensor_addr(ih0, iw0, c, IW, C)] * ah0 * aw0 +
+                        sptr[get_tensor_addr(ih0, iw1, c, IW, C)] * ah0 * aw1 +
+                        sptr[get_tensor_addr(ih1, iw0, c, IW, C)] * ah1 * aw0 +
+                        sptr[get_tensor_addr(ih1, iw1, c, IW, C)] * ah1 * aw1);
             }
         }
         sptr += IH * (C / 4) * IW * 4;
@@ -278,41 +290,46 @@ void ResizeImpl::kern_naive_nhwcd4(const KernParam<ctype>& kern_param) {
     }
 }
 
-template <typename ctype>
-void ResizeImpl::kern_naive_nchw4(const KernParam<ctype>& kern_param) {
+template <typename ctype, size_t pack_size>
+void ResizeImpl::kern_naive_nchwx(const KernParam<ctype>& kern_param) {
     UNPACK_RESIZE_FWD_KERN_PARAM(kern_param);
     rounding::RoundingConverter<ctype> output_converter;
     float scale_h = static_cast<float>(OH) / IH;
     float scale_w = static_cast<float>(OW) / IW;
 
+    megdnn_assert(pack_size == 4 || pack_size == 8);
+    size_t log_pack_size = 2;
+    if (pack_size == 8) {
+        log_pack_size = 3;
+    }
+
     auto get_tensor_addr = [&](size_t h, size_t w, size_t c, size_t H, size_t W,
                                size_t C) -> size_t {
-        megdnn_assert((C & 0x3) == 0);
-        return (((c >> 2) * H * W + h * W + w) << 2) + (c & 0b11);
+        megdnn_assert((C & (pack_size - 1)) == 0);
+        return (((c >> log_pack_size) * H * W + h * W + w) << log_pack_size) +
+               (c & (pack_size - 1));
     };
 
     rep(n, N) {
         rep(oh, OH) rep(ow, OW) {
-            auto coord_h = get_origin_coord(scale_h, IH, oh);
-            auto coord_w = get_origin_coord(scale_w, IW, ow);
+            int ih0, ih1, iw0, iw1;
+            float ah0, ah1, aw0, aw1;
 
-            float alphah = coord_h.first;
-            float alphaw = coord_w.first;
+            std::tie(ah0, ih0, ah1, ih1) =
+                    get_nearest_linear_coord(kern_param.imode, scale_h, IH, oh);
+            std::tie(aw0, iw0, aw1, iw1) =
+                    get_nearest_linear_coord(kern_param.imode, scale_w, IW, ow);
 
-            int ih0 = coord_h.second;
-            int ih1 = ih0 + 1;
-            int iw0 = coord_w.second;
-            int iw1 = iw0 + 1;
             rep(c, C) {
                 dptr[get_tensor_addr(oh, ow, c, OH, OW, C)] = output_converter(
-                        sptr[get_tensor_addr(ih0, iw0, c, IH, IW, C)] *
-                                (1.0f - alphaw) * (1.0f - alphah) +
-                        sptr[get_tensor_addr(ih0, iw1, c, IH, IW, C)] * alphaw *
-                                (1.0f - alphah) +
-                        sptr[get_tensor_addr(ih1, iw0, c, IH, IW, C)] *
-                                (1.0f - alphaw) * alphah +
-                        sptr[get_tensor_addr(ih1, iw1, c, IH, IW, C)] * alphaw *
-                                alphah);
+                        sptr[get_tensor_addr(ih0, iw0, c, IH, IW, C)] * ah0 *
+                                aw0 +
+                        sptr[get_tensor_addr(ih0, iw1, c, IH, IW, C)] * ah0 *
+                                aw1 +
+                        sptr[get_tensor_addr(ih1, iw0, c, IH, IW, C)] * ah1 *
+                                aw0 +
+                        sptr[get_tensor_addr(ih1, iw1, c, IH, IW, C)] * ah1 *
+                                aw1);
             }
         }
         sptr += IH * IW * C;
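
Reviewer note (illustrative check, not part of the patch): in a packed NCHWxx layout with pack size P, element (c, h, w) of one batch lives at ((c / P) * H * W + h * W + w) * P + c % P; the shift-and-mask form in get_tensor_addr above is the power-of-two equivalent. A small standalone check, with hypothetical names:

#include <cassert>
#include <cstddef>

// plain div/mod indexing for a packed layout
size_t addr_divmod(size_t h, size_t w, size_t c, size_t H, size_t W, size_t P) {
    return ((c / P) * H * W + h * W + w) * P + c % P;
}

// shift/mask form used by the kernel (valid when P is a power of two)
size_t addr_shift(size_t h, size_t w, size_t c, size_t H, size_t W,
                  size_t log_P, size_t P) {
    return (((c >> log_P) * H * W + h * W + w) << log_P) + (c & (P - 1));
}

int main() {
    const size_t H = 5, W = 7, C = 16;
    for (size_t c = 0; c < C; ++c)
        for (size_t h = 0; h < H; ++h)
            for (size_t w = 0; w < W; ++w) {
                assert(addr_divmod(h, w, c, H, W, 4) ==
                       addr_shift(h, w, c, H, W, 2, 4));  // NCHW44
                assert(addr_divmod(h, w, c, H, W, 8) ==
                       addr_shift(h, w, c, H, W, 3, 8));  // NCHW88
            }
    return 0;
}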
@@ -327,8 +344,8 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
 #define cb(dt, ct, _midout_iv)                                                \
     case DTypeTrait<dt>::enumv: {                                             \
         MIDOUT_BEGIN(megdnn_naive_resize_nchw, midout_iv(_midout_iv)) {       \
-            auto kparam = KernParam<ct>::from_tensors(param().format, src,    \
-                                                      dst, workspace);        \
+            auto kparam = KernParam<ct>::from_tensors(                        \
+                    param().format, param().imode, src, dst, workspace);      \
             MEGDNN_DISPATCH_CPU_KERN_OPR(kern_nchw(kparam, param().imode));   \
         }                                                                     \
         MIDOUT_END();                                                         \
@@ -356,15 +373,15 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
     if (((src.layout[3] != 1 && src.layout[3] != 3) ||
          !is_nhwc_contig_wc(src.layout)) ||
         (param().imode == param::Resize::InterpolationMode::LINEAR)) {
-#define cb(dt, ct, _midout_iv)                                                \
-    case DTypeTrait<dt>::enumv: {                                             \
-        MIDOUT_BEGIN(megdnn_naive_resize_layout, midout_iv(_midout_iv)) {     \
-            auto kparam = KernParam<ct>::from_tensors(param().format, src,    \
-                                                      dst, workspace);        \
-            MEGDNN_DISPATCH_CPU_KERN_OPR(kern_naive(kparam));                 \
-        }                                                                     \
-        MIDOUT_END();                                                         \
-        return;                                                               \
+#define cb(dt, ct, _midout_iv)                                               \
+    case DTypeTrait<dt>::enumv: {                                            \
+        MIDOUT_BEGIN(megdnn_naive_resize_layout, midout_iv(_midout_iv)) {    \
+            auto kparam = KernParam<ct>::from_tensors(                       \
+                    param().format, param().imode, src, dst, workspace);     \
+            MEGDNN_DISPATCH_CPU_KERN_OPR(kern_naive(kparam));                \
+        }                                                                    \
+        MIDOUT_END();                                                        \
+        return;                                                              \
     }
 
         switch (src.layout.dtype.enumv()) {
@@ -409,27 +426,24 @@ void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
         rep(oh, OH) rep(ow, OW) {
             switch (param().imode) {
                 case InterpolationMode::INTER_LINEAR: {
-                    auto coord_h = get_origin_coord(scale_h, IH, oh);
-                    auto coord_w = get_origin_coord(scale_w, IW, ow);
-                    float alphah = coord_h.first;
-                    float alphaw = coord_w.first;
+                    int ih0, ih1, iw0, iw1;
+                    float ah0, ah1, aw0, aw1;
 
-                    int ih0 = coord_h.second;
-                    int ih1 = ih0 + 1;
-                    int iw0 = coord_w.second;
-                    int iw1 = iw0 + 1;
+                    std::tie(ah0, ih0, ah1, ih1) = get_nearest_linear_coord(
+                            param().imode, scale_h, IH, oh);
+                    std::tie(aw0, iw0, aw1, iw1) = get_nearest_linear_coord(
+                            param().imode, scale_w, IW, ow);
 
                     rep(c, C) {
                         float hidden = hptr[c * OH * OW + oh * OW + ow];
                         sptr[c * IH * IW + ih0 * IW + iw0] +=
-                                (1.0f - alphaw) * (1.0f - alphah) * hidden;
+                                ah0 * aw0 * hidden;
                         sptr[c * IH * IW + ih1 * IW + iw0] +=
-                                (1.0f - alphaw) * alphah * hidden;
+                                ah1 * aw0 * hidden;
                         sptr[c * IH * IW + ih0 * IW + iw1] +=
-                                alphaw * (1.0f - alphah) * hidden;
+                                ah0 * aw1 * hidden;
                         sptr[c * IH * IW + ih1 * IW + iw1] +=
-                                alphaw * alphah * hidden;
+                                ah1 * aw1 * hidden;
                     }
                     break;
                 }
@@ -443,8 +457,8 @@ void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
                     break;
                 }
                 case InterpolationMode::INTER_CUBIC: {
-                    auto coord_h = get_origin_coord(scale_h, IH, oh, true);
-                    auto coord_w = get_origin_coord(scale_w, IW, ow, true);
+                    auto coord_h = get_cubic_coord(scale_h, oh);
+                    auto coord_w = get_cubic_coord(scale_w, ow);
 
                     float alphah = coord_h.first;
                     float alphaw = coord_w.first;
@@ -460,7 +474,8 @@ void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
                     rep(kh, ksize) {
                         int h = saturate<int, int>(ih0 + kh, 0, IH - 1);
                         rep(kw, ksize) {
-                            int w = saturate<int, int>(iw0 + kw, 0, IW - 1);
+                            int w = saturate<int, int>(iw0 + kw, 0,
+                                                       IW - 1);
                             sptr[c * IH * IW + h * IW + w] +=
                                     hptr[c * OH * OW + oh * OW + ow] *
                                     h_coeff[kh] * w_coeff[kw];
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 
 #pragma once
@@ -19,15 +20,18 @@ namespace naive {
 class ResizeImpl : public Resize {
 public:
     using Format = Param::Format;
+    using InterpolationMode = Param::InterpolationMode;
 
     template <typename ctype>
     struct KernParam {
         Format format;
+        InterpolationMode imode;
         size_t n, c, ih, iw, oh, ow;
         ptrdiff_t s_in, s_ic, s_ih, s_iw;
         ctype *sptr, *dptr;
         Workspace workspace;
 
-        static KernParam from_tensors(Format format, _megdnn_tensor_in src,
+        static KernParam from_tensors(Format format, InterpolationMode imode,
+                                      _megdnn_tensor_in src,
                                       _megdnn_tensor_out dst,
                                       _megdnn_workspace workspace);
     };
@@ -41,6 +45,7 @@ public:
                                   const TensorLayout&) override {
         return 0;
     }
+
 private:
     // ctype: C type of input data type.
     template <typename ctype>
@@ -55,8 +60,8 @@ private:
     template <typename ctype>
     void kern_naive_nhwcd4(const KernParam<ctype>& kern_param);
 
-    template <typename ctype>
-    void kern_naive_nchw4(const KernParam<ctype>& kern_param);
+    template <typename ctype, size_t pack_size>
+    void kern_naive_nchwx(const KernParam<ctype>& kern_param);
 
 };  // class ResizeImpl
@@ -65,15 +70,15 @@ private:
     ctype* __restrict sptr = p.sptr;                 \
     ctype* __restrict dptr = p.dptr;
 
-#define UNPACK_RESIZE_FWD_KERN_PARAM_WITH_STRIDE(p)  \
-    UNPACK_RESIZE_FWD_KERN_PARAM(p)                  \
+#define UNPACK_RESIZE_FWD_KERN_PARAM_WITH_STRIDE(p) \
+    UNPACK_RESIZE_FWD_KERN_PARAM(p)                 \
     auto S_IN = p.s_in, S_IC = p.s_ic, S_IH = p.s_ih, S_IW = p.s_iw;
 
-class ResizeBackwardImpl: public ResizeBackward {
+class ResizeBackwardImpl : public ResizeBackward {
 public:
     using ResizeBackward::ResizeBackward;
-    void exec(_megdnn_tensor_in diff,
-              _megdnn_tensor_out grad, _megdnn_workspace workspace) override;
+    void exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
+              _megdnn_workspace workspace) override;
     size_t get_workspace_in_bytes(const TensorLayout&,
                                   const TensorLayout&) override {
         return 0;
@@ -6,40 +6,66 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
-#include "test/arm_common/fixture.h"
 #include "test/common/resize.h"
+#include "test/arm_common/fixture.h"
 #include "test/common/checker.h"
 
 namespace megdnn {
 namespace test {
 
-TEST_F(ARM_COMMON, RESIZE_CV)
-{
+TEST_F(ARM_COMMON, RESIZE_CV) {
     using namespace resize;
     std::vector<TestArg> args = get_cv_args();
     Checker<Resize> checker(handle());
 
-    for (auto &&arg: args) {
+    for (auto&& arg : args) {
         checker.set_param(arg.param)
-            .set_epsilon(1 + 1e-3)
-            .set_dtype(0, dtype::Uint8())
-            .set_dtype(1, dtype::Uint8())
-            .execs({arg.src, arg.dst});
+                .set_epsilon(1 + 1e-3)
+                .set_dtype(0, dtype::Uint8())
+                .set_dtype(1, dtype::Uint8())
+                .execs({arg.src, arg.dst});
     }
 
-    for (auto &&arg: args) {
+    for (auto&& arg : args) {
         checker.set_param(arg.param)
-            .set_dtype(0, dtype::Float32())
-            .set_dtype(1, dtype::Float32())
-            .execs({arg.src, arg.dst});
+                .set_dtype(0, dtype::Float32())
+                .set_dtype(1, dtype::Float32())
+                .execs({arg.src, arg.dst});
     }
+}
+
+TEST_F(ARM_COMMON, RESIZE_NCHW44) {
+    using namespace resize;
+    std::vector<TestArg> args = get_nchw44_args();
+    Checker<Resize> checker(handle());
+
+    for (auto&& arg : args) {
+        checker.set_param(arg.param)
+                .set_dtype(0, dtype::Float32())
+                .set_dtype(1, dtype::Float32())
+                .execs({arg.src, arg.dst});
+    }
+}
+
+TEST_F(ARM_COMMON, RESIZE_NCHW88) {
+    using namespace resize;
+    std::vector<TestArg> args = get_nchw88_args();
+    Checker<Resize> checker(handle());
+
+    for (auto&& arg : args) {
+        checker.set_param(arg.param)
+                .set_epsilon(0.01)
+                .set_dtype(0, dtype::Float16())
+                .set_dtype(1, dtype::Float16())
+                .execs({arg.src, arg.dst});
+    }
 }
 
-} // namespace test
-} // namespace megdnn
+}  // namespace test
+}  // namespace megdnn
 
 // vim: syntax=cpp.doxygen
@@ -6,12 +6,13 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #pragma once
-#include "megdnn/opr_param_defs.h"
-#include "megdnn/basic_types.h"
 #include <iostream>
+#include "megdnn/basic_types.h"
+#include "megdnn/opr_param_defs.h"
 #include "./rng.h"
 
 namespace megdnn {
@@ -68,13 +69,15 @@ static inline std::vector<TestArg> get_args(IMode imode = IMode::INTER_LINEAR) {
     std::vector<TestArg> args;
     set_nchw_args(args);
 
-    if(imode == IMode::INTER_LINEAR) {
-    //! test NHWC with ch != 1 or ch != 3
+    if (imode == IMode::INTER_LINEAR) {
+        //! test NHWC with ch != 1 or ch != 3
         param::Resize param;
         param.format = param::Resize::Format::NHWC;
         param.imode = imode;
-        args.emplace_back(param, TensorShape{2, 2, 3, 4}, TensorShape{2, 4, 6, 4});
-        args.emplace_back(param, TensorShape{2, 4, 6, 4}, TensorShape{2, 2, 3, 4});
+        args.emplace_back(param, TensorShape{2, 2, 3, 4},
+                          TensorShape{2, 4, 6, 4});
+        args.emplace_back(param, TensorShape{2, 4, 6, 4},
+                          TensorShape{2, 2, 3, 4});
     }
     return args;
 }
@@ -108,6 +111,48 @@ static inline std::vector<TestArg> get_nchw4_args() {
     return args;
 }
 
+static inline std::vector<TestArg> get_nchw44_args() {
+    std::vector<TestArg> args;
+
+    param::Resize param;
+    param.format = param::Resize::Format::NCHW44;
+    param.imode = param::Resize::InterpolationMode::LINEAR;
+    rep(n, 4ul) rep(c, 4ul) rep(ih, 4ul) rep(iw, 4ul) rep(oh, 4ul) rep(ow, 4ul)
+            args.emplace_back(
+                    param,
+                    TensorShape{n + 1ul, c + 1ul, ih + 1ul, iw + 1ul, 4ul},
+                    TensorShape{n + 1ul, c + 1ul, oh + 1ul, ow + 1ul, 4ul});
+    param.imode = param::Resize::InterpolationMode::NEAREST;
+    rep(n, 4ul) rep(c, 4ul) rep(ih, 4ul) rep(iw, 4ul) rep(oh, 4ul) rep(ow, 4ul)
+            args.emplace_back(
+                    param,
+                    TensorShape{n + 1ul, c + 1ul, ih + 1ul, iw + 1ul, 4ul},
+                    TensorShape{n + 1ul, c + 1ul, oh + 1ul, ow + 1ul, 4ul});
+    return args;
+}
+
+static inline std::vector<TestArg> get_nchw88_args() {
+    std::vector<TestArg> args;
+
+    param::Resize param;
+    param.format = param::Resize::Format::NCHW88;
+    param.imode = param::Resize::InterpolationMode::LINEAR;
+    rep(n, 4ul) rep(c, 4ul) rep(ih, 4ul) rep(iw, 4ul) rep(oh, 4ul) rep(ow, 4ul)
+            args.emplace_back(
+                    param,
+                    TensorShape{n + 1ul, c + 1ul, ih + 1ul, iw + 1ul, 8ul},
+                    TensorShape{n + 1ul, c + 1ul, oh + 1ul, ow + 1ul, 8ul});
+    param.imode = param::Resize::InterpolationMode::NEAREST;
+    rep(n, 4ul) rep(c, 4ul) rep(ih, 4ul) rep(iw, 4ul) rep(oh, 4ul) rep(ow, 4ul)
+            args.emplace_back(
+                    param,
+                    TensorShape{n + 1ul, c + 1ul, ih + 1ul, iw + 1ul, 8ul},
+                    TensorShape{n + 1ul, c + 1ul, oh + 1ul, ow + 1ul, 8ul});
+    return args;
+}
+
 static inline std::vector<TestArg> get_cv_args() {
     std::vector<TestArg> args;
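
Reviewer note (illustrative, not part of the patch): the shapes built by get_nchw44_args()/get_nchw88_args() carry the pack size in the trailing dimension, so a packed 5-D shape folds back into a dense NCHW shape by multiplying it into the channel count. A tiny hypothetical helper showing the correspondence:

#include <array>
#include <cassert>
#include <cstddef>

// {N, C/pack, H, W, pack} -> {N, C, H, W}
std::array<size_t, 4> packed_to_dense_nchw(const std::array<size_t, 5>& s) {
    return {s[0], s[1] * s[4], s[2], s[3]};
}

int main() {
    // e.g. an NCHW44 tensor {2, 3, 4, 5, 4} describes 12 logical channels
    std::array<size_t, 4> d = packed_to_dense_nchw({2, 3, 4, 5, 4});
    assert(d == (std::array<size_t, 4>{2, 12, 4, 5}));
    return 0;
}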
| @@ -68,87 +68,90 @@ using namespace gopt; | |||||
| * oprs should not get involved in any actual computing. | * oprs should not get involved in any actual computing. | ||||
| */ | */ | ||||
| MGB_DEFINE_OPR_CLASS(TensorReformatPass::RelayoutPlaceholder, | MGB_DEFINE_OPR_CLASS(TensorReformatPass::RelayoutPlaceholder, | ||||
| cg::SingleCNOperatorNodeBase) // { | |||||
| cg::SingleCNOperatorNodeBase) // { | |||||
| public: | public: | ||||
| //! relayout type of this opr | |||||
| enum class LayoutType { | |||||
| NCHW4_TO_NCHW32, //!< from nchw4 layout to nchw32 layout | |||||
| NCHW32_TO_NCHW4, //!< from nchw32 layout to nchw4 layout | |||||
| NCHW4_TO_CHWN4, //!< from nchw4 layout to chwn4 layout | |||||
| CHWN4_TO_NCHW4, //!< from chwn4 layout to nchw4 layout | |||||
| NCHW_TO_NCHW4, //!< from nchw layout to nchw4 layout | |||||
| NCHW_TO_NCHW4_IC_SMALL_CONV, ///< from nchw layout to nchw4 whose | |||||
| ///< channel size less than 4 | |||||
| NCHW4_TO_NCHW, //!< from nchw4 layout to nchw layout | |||||
| NCHW_TO_NCHW88, //!< from nchw layout to nchw88 layout | |||||
| NCHW88_TO_NCHW, //!< from nchw88 layout to nchw layout | |||||
| WEIGHT_NCHW_TO_NCHW4_DENSE, //!< weight from nchw layout to nchw4 | |||||
| //!< layout | |||||
| WEIGHT_NCHW_TO_NCHW4_GROUP, //!< group weight from nchw layout to | |||||
| //!< nchw4 layout | |||||
| WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV, //!< weight from nchw layout | |||||
| //!< to nchw4 layout whose | |||||
| //! channel size less than 4 | |||||
| WEIGHT_NCHW_TO_NCHW88_DENSE, //!< weight from nchw layout to nchw88 | |||||
| //!< layout | |||||
| WEIGHT_NCHW_TO_NCHW88_GROUP, //!< group weight from nchw layout to | |||||
| //!< nchw88 layout | |||||
| WEIGHT_NCHW_TO_NCHW88_CHAN, //!< channel wise weight from nchw layout | |||||
| //!< to nchw88 layout | |||||
| //!< the weight layout of input is nchw output is nchw88, special for | |||||
| //!< shape weight in nchw like {64, 2, 3, 3} to {8, 3, 3, 2, 8} | |||||
| WEIGHT_HYBIRD_NCHW_NCHW88, | |||||
| WEIGHT_NCHW_TO_NCHW44_DENSE, //!< weight from nchw layout to nchw44 | |||||
| //!< layout | |||||
| WEIGHT_NCHW_TO_NCHW44_GROUP, //!< group weight from nchw layout to | |||||
| //!< nchw44 layout | |||||
| WEIGHT_NCHW_TO_NCHW44_CHAN, //!< channel wise weight from nchw layout | |||||
| //!< to nchw44 layout | |||||
| //!< the weight layout of input is nchw output is nchw44, special for | |||||
| //!< shape weight in nchw like {64, 2, 3, 3} to {16, 3, 3, 2, 4} | |||||
| WEIGHT_HYBIRD_NCHW_NCHW44, | |||||
| WEIGHT_NCHW_TO_NCHW44_DOT_DENSE, //!< weight from NCHW44 layout to | |||||
| //!< NCHW44_DOT layout dense | |||||
| WEIGHT_NCHW_TO_NCHW44_DOT_GROUP, //!< weight from NCHW44 layout to | |||||
| //!< NCHW44_DOT layout group | |||||
| NCHW32_TO_NCHW, //! <from nchw32 layout to nchw layout | |||||
| NCHW32_TO_NCHW64, //! <from nchw32 layout to nchw64 layout | |||||
| NCHW64_TO_NCHW, //! <from nchw64 layout to nchw layout | |||||
| NCHW64_TO_NCHW4, //! <from nchw64 layout to nchw4 layout | |||||
| NCHW64_TO_NCHW32, //! <from nchw64 layout to nchw32 layout | |||||
| NCHW_TO_NCHW64, //! <from nchw layout to nchw64 layout | |||||
| NCHW_TO_NCHW32, //! <from nchw layout to nchw64 layout | |||||
| NCHW4_TO_NCHW64, //! <from nchw4 layout to nchw64 layout | |||||
| NCHW_TO_NHWC, //! <NHWC related layout transformation | |||||
| NCHW4_TO_NHWC, | |||||
| NCHW32_TO_NHWC, | |||||
| NCHW64_TO_NHWC, | |||||
| NHWC_TO_NCHW, | |||||
| NHWC_TO_NCHW4, | |||||
| NHWC_TO_NCHW32, | |||||
| NHWC_TO_NCHW64, | |||||
| }; | |||||
| RelayoutPlaceholder(VarNode* src_var, LayoutType layout_type); | |||||
| /*! | |||||
| * \param src_var the input var | |||||
| * \param layout_type tensor layout transform type of this relayout | |||||
| * placeholder as described in LayoutType | |||||
| */ | |||||
| static SymbolVar make(VarNode* src_var, LayoutType layout_type); | |||||
| LayoutType layout_type() const { return m_layout_type; } | |||||
| private: | private: | ||||
| void init_output_static_infer_desc() override; | |||||
| void scn_do_execute() override; | |||||
| void init_output_comp_node() override; | |||||
| const LayoutType m_layout_type; | |||||
| }; | |||||
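For orientation, a minimal sketch of how the placeholder declared above is meant to be consumed inside the conversion passes: a var is wrapped with RelayoutPlaceholder::make and a LayoutType tag, and translate_pass later expands the placeholder into the concrete reformat oprs. The variable names here are invented for illustration; only make() and LayoutType come from the declaration above.

// Hedged sketch (inside the pass, with an existing VarNode* src_var):
// wrap an NCHW var so it is later expanded into the NCHW -> NCHW88 reformat.
VarNode* packed_var =
        RelayoutPlaceholder::make(
                src_var, RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW88)
                .node();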
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(TensorReformatPass::RelayoutPlaceholder); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(TensorReformatPass::RelayoutPlaceholder); | ||||
| TensorReformatPass::RelayoutPlaceholder::RelayoutPlaceholder( | TensorReformatPass::RelayoutPlaceholder::RelayoutPlaceholder( | ||||
| @@ -1023,8 +1026,7 @@ void TensorReformatPass::translate_pass(OptState& opt) const { | |||||
| auto sub = [&xshp, &cv](int idx) { | auto sub = [&xshp, &cv](int idx) { | ||||
| return opr::IndexAt::make(xshp, {{0, cv(idx)}}); | return opr::IndexAt::make(xshp, {{0, cv(idx)}}); | ||||
| }; | }; | ||||
| auto tshp0 = | |||||
| opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 4}, 0); | |||||
| auto tshp0 = opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 4}, 0); | |||||
| auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4}); | auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4}); | ||||
| auto y1 = opr::Reshape::make(y0, tshp0); | auto y1 = opr::Reshape::make(y0, tshp0); | ||||
| return y1.node(); | return y1.node(); | ||||
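To make the relayout in this hunk concrete: Dimshuffle{0, 2, 3, 1, 4} followed by the Reshape to {N, H, W, C/4 * 4} sends element (n, c, h, w, i) of the packed 5-d tensor to channel c*4 + i of an NHWC tensor, which matches the NCHW4 -> NHWC case of translate_pass. Below is a standalone sketch of the equivalent index mapping in plain C++; it is not the MegBrain graph API, and the shapes and buffer layout are assumptions made only for illustration.

#include <cstddef>
#include <vector>

// src is contiguous in (N, C/4, H, W, 4) order; dst is contiguous NHWC.
std::vector<float> nchw4_to_nhwc(const std::vector<float>& src, size_t N,
                                 size_t C4, size_t H, size_t W) {
    std::vector<float> dst(N * H * W * C4 * 4);
    for (size_t n = 0; n < N; ++n)
        for (size_t c = 0; c < C4; ++c)
            for (size_t h = 0; h < H; ++h)
                for (size_t w = 0; w < W; ++w)
                    for (size_t i = 0; i < 4; ++i) {
                        size_t s = (((n * C4 + c) * H + h) * W + w) * 4 + i;
                        size_t d = ((n * H + h) * W + w) * (C4 * 4) + c * 4 + i;
                        dst[d] = src[s];  // channel index becomes c * 4 + i
                    }
    return dst;
}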
| @@ -1036,7 +1038,8 @@ void TensorReformatPass::translate_pass(OptState& opt) const { | |||||
| auto sub = [&xshp, &cv](int idx) { | auto sub = [&xshp, &cv](int idx) { | ||||
| return opr::IndexAt::make(xshp, {{0, cv(idx)}}); | return opr::IndexAt::make(xshp, {{0, cv(idx)}}); | ||||
| }; | }; | ||||
| auto tshp0 = opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 32}, 0); | |||||
| auto tshp0 = | |||||
| opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 32}, 0); | |||||
| auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4}); | auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4}); | ||||
| auto y1 = opr::Reshape::make(y0, tshp0); | auto y1 = opr::Reshape::make(y0, tshp0); | ||||
| return y1.node(); | return y1.node(); | ||||
| @@ -1048,7 +1051,8 @@ void TensorReformatPass::translate_pass(OptState& opt) const { | |||||
| auto sub = [&xshp, &cv](int idx) { | auto sub = [&xshp, &cv](int idx) { | ||||
| return opr::IndexAt::make(xshp, {{0, cv(idx)}}); | return opr::IndexAt::make(xshp, {{0, cv(idx)}}); | ||||
| }; | }; | ||||
| auto tshp0 = opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 64}, 0); | |||||
| auto tshp0 = | |||||
| opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 64}, 0); | |||||
| auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4}); | auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4}); | ||||
| auto y1 = opr::Reshape::make(y0, tshp0); | auto y1 = opr::Reshape::make(y0, tshp0); | ||||
| return y1.node(); | return y1.node(); | ||||
| @@ -1865,8 +1869,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { | |||||
| }; | }; | ||||
| auto replace_deconv_opr = [trans_nchw4, conv_format]( | auto replace_deconv_opr = [trans_nchw4, conv_format]( | ||||
| OperatorNodeBase* opr, | |||||
| const VarNodeArray& new_inp) { | |||||
| OperatorNodeBase* opr, | |||||
| const VarNodeArray& new_inp) { | |||||
| if (new_inp[1]->dtype().enumv() == DTypeEnum::Float32) { | if (new_inp[1]->dtype().enumv() == DTypeEnum::Float32) { | ||||
| return serialization::copy_opr_shallow(*opr, new_inp, | return serialization::copy_opr_shallow(*opr, new_inp, | ||||
| opr->config()); | opr->config()); | ||||
| @@ -1881,7 +1885,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { | |||||
| opr->config()); | opr->config()); | ||||
| } | } | ||||
| VarNode *deconv_src = new_inp[1], *deconv_filter = new_inp[0]; | VarNode *deconv_src = new_inp[1], *deconv_filter = new_inp[0]; | ||||
| auto deconv_mode = trans_nchw4(deconv_opr.param().sparse, deconv_filter); | |||||
| auto deconv_mode = | |||||
| trans_nchw4(deconv_opr.param().sparse, deconv_filter); | |||||
| // src: NCHW --> NCWH4 | // src: NCHW --> NCWH4 | ||||
| if (deconv_src->shape().ndim != 5) { | if (deconv_src->shape().ndim != 5) { | ||||
| mgb_assert(deconv_src->shape().ndim == 4); | mgb_assert(deconv_src->shape().ndim == 4); | ||||
| @@ -2028,10 +2033,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { | |||||
| conv_bias_src, conv_bias_filter, new_param, | conv_bias_src, conv_bias_filter, new_param, | ||||
| conv_bias_opr.execution_policy(), conv_bias_opr.config()); | conv_bias_opr.execution_policy(), conv_bias_opr.config()); | ||||
| OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr(); | OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr(); | ||||
| mgb_assert( | |||||
| new_conv_bias_opr.node()->dtype().enumv() == DTypeEnum::Float32 || | |||||
| new_conv_bias_opr.shape().ndim == 5, | |||||
| "The conv_bias dst dim is not trans to nchw4"); | |||||
| mgb_assert(new_conv_bias_opr.node()->dtype().enumv() == | |||||
| DTypeEnum::Float32 || | |||||
| new_conv_bias_opr.shape().ndim == 5, | |||||
| "The conv_bias dst dim is not trans to nchw4"); | |||||
| return new_opr; | return new_opr; | ||||
| } | } | ||||
| // bias: NCHW --> NCHW4 when bias_dtype is not Float32 | // bias: NCHW --> NCHW4 when bias_dtype is not Float32 | ||||
| @@ -2047,10 +2052,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { | |||||
| conv_bias_src, conv_bias_filter, conv_bias_bias, new_param, | conv_bias_src, conv_bias_filter, conv_bias_bias, new_param, | ||||
| conv_bias_opr.execution_policy(), conv_bias_opr.config()); | conv_bias_opr.execution_policy(), conv_bias_opr.config()); | ||||
| OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr(); | OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr(); | ||||
| mgb_assert( | |||||
| new_conv_bias_opr.node()->dtype().enumv() == DTypeEnum::Float32 || | |||||
| new_conv_bias_opr.shape().ndim == 5, | |||||
| "The conv_bias dst dim is not trans to nchw4"); | |||||
| mgb_assert(new_conv_bias_opr.node()->dtype().enumv() == | |||||
| DTypeEnum::Float32 || | |||||
| new_conv_bias_opr.shape().ndim == 5, | |||||
| "The conv_bias dst dim is not trans to nchw4"); | |||||
| return new_opr; | return new_opr; | ||||
| } | } | ||||
| // z_inp: NCHW --> NCHW4 when bias_dtype is not Float32 | // z_inp: NCHW --> NCHW4 when bias_dtype is not Float32 | ||||
| @@ -2066,10 +2071,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { | |||||
| new_param, conv_bias_opr.execution_policy(), | new_param, conv_bias_opr.execution_policy(), | ||||
| conv_bias_opr.config()); | conv_bias_opr.config()); | ||||
| OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr(); | OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr(); | ||||
| mgb_assert( | |||||
| new_conv_bias_opr.node()->dtype().enumv() == DTypeEnum::Float32 || | |||||
| new_conv_bias_opr.shape().ndim == 5, | |||||
| "The conv_bias dst dim is not trans to nchw4"); | |||||
| mgb_assert(new_conv_bias_opr.node()->dtype().enumv() == | |||||
| DTypeEnum::Float32 || | |||||
| new_conv_bias_opr.shape().ndim == 5, | |||||
| "The conv_bias dst dim is not trans to nchw4"); | |||||
| return new_opr; | return new_opr; | ||||
| }; | }; | ||||
| auto replace_elemwise_opr = [=](OperatorNodeBase* opr, | auto replace_elemwise_opr = [=](OperatorNodeBase* opr, | ||||
| @@ -2210,8 +2215,7 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { | |||||
| auto&& replace_func = ret->m_opr_replace_func; | auto&& replace_func = ret->m_opr_replace_func; | ||||
| //! supported nchw4 | //! supported nchw4 | ||||
| replace_func[opr::Convolution::typeinfo()] = replace_conv_opr; | replace_func[opr::Convolution::typeinfo()] = replace_conv_opr; | ||||
| replace_func[opr::ConvolutionBackwardData::typeinfo()] = | |||||
| replace_deconv_opr; | |||||
| replace_func[opr::ConvolutionBackwardData::typeinfo()] = replace_deconv_opr; | |||||
| replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr; | replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr; | ||||
| replace_func[opr::BatchConvBias::typeinfo()] = replace_batch_conv_bias_opr; | replace_func[opr::BatchConvBias::typeinfo()] = replace_batch_conv_bias_opr; | ||||
| replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr; | replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr; | ||||
| @@ -2348,6 +2352,8 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||||
| megdnn::param::Convolution::Format::NCHW88; | megdnn::param::Convolution::Format::NCHW88; | ||||
| megdnn::param::Pooling::Format pooling_format = | megdnn::param::Pooling::Format pooling_format = | ||||
| megdnn::param::Pooling::Format::NCHW88; | megdnn::param::Pooling::Format::NCHW88; | ||||
| megdnn::param::Resize::Format resize_format = | |||||
| megdnn::param::Resize::Format::NCHW88; | |||||
| std::string convter_pass_name = "conv_format_nchw88"; | std::string convter_pass_name = "conv_format_nchw88"; | ||||
| if (pack_c_size == 4) { | if (pack_c_size == 4) { | ||||
| @@ -2360,6 +2366,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||||
| conv_bias_format = megdnn::param::ConvBias::Format::NCHW44; | conv_bias_format = megdnn::param::ConvBias::Format::NCHW44; | ||||
| conv_format = megdnn::param::Convolution::Format::NCHW44; | conv_format = megdnn::param::Convolution::Format::NCHW44; | ||||
| pooling_format = megdnn::param::Pooling::Format::NCHW44; | pooling_format = megdnn::param::Pooling::Format::NCHW44; | ||||
| resize_format = megdnn::param::Resize::Format::NCHW44; | |||||
| convter_pass_name = "conv_format_nchw44"; | convter_pass_name = "conv_format_nchw44"; | ||||
| } | } | ||||
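The pack_c_size switch above now also selects the resize format alongside the conv, conv_bias and pooling formats. A hedged restatement of that selection as a standalone helper (illustrative only; the pass keeps these as local variables rather than a function):

// 4-channel packing -> NCHW44; otherwise keep the NCHW88 default set above.
megdnn::param::Resize::Format pick_resize_format(size_t pack_c_size) {
    return pack_c_size == 4 ? megdnn::param::Resize::Format::NCHW44
                            : megdnn::param::Resize::Format::NCHW88;
}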
| auto test_trans_nchwxx = | auto test_trans_nchwxx = | ||||
| @@ -2634,6 +2641,43 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||||
| return new_opr; | return new_opr; | ||||
| } | } | ||||
| }; | }; | ||||
| auto replace_resize_opr = [=](OperatorNodeBase* opr, | |||||
| const VarNodeArray& new_inp) { | |||||
| mgb_assert(opr->input().size() == new_inp.size()); | |||||
| auto& resize_opr = opr->cast_final_safe<opr::ResizeForward>(); | |||||
| mgb_throw_if( | |||||
| resize_opr.param().format != | |||||
| megdnn::param::Resize::Format::NCHW && | |||||
| resize_opr.param().format != | |||||
| megdnn::param::Resize::Format::NHWC, | |||||
| MegBrainError, | |||||
| "ConvertFormat Pass only support converting NCHW to NCHWxx"); | |||||
| VarNode* inp = new_inp[0]; | |||||
| if (resize_opr.param().format == megdnn::param::Resize::Format::NHWC) { | |||||
| auto temp_inp = new_inp; | |||||
| if (inp->shape().ndim == 5) { | |||||
| auto new_var = RelayoutPlaceholder::make(inp, src_to_nchw_mode); | |||||
| temp_inp[0] = new_var.node(); | |||||
| } | |||||
| return serialization::copy_opr_shallow(*opr, temp_inp, | |||||
| opr->config()); | |||||
| } else { | |||||
| auto temp_inp = new_inp; | |||||
| if (inp->shape().ndim == 5) { | |||||
| auto new_param = resize_opr.param(); | |||||
| new_param.format = resize_format; | |||||
| auto new_resize_opr = opr::ResizeForward::make( | |||||
| new_inp[0], new_inp[1], new_param, opr->config()); | |||||
| return new_resize_opr.node()->owner_opr(); | |||||
| } else { | |||||
| return serialization::copy_opr_shallow(*opr, new_inp, | |||||
| opr->config()); | |||||
| } | |||||
| } | |||||
| }; | |||||
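For context, a minimal sketch of how this conversion is typically reached from user code, so that ResizeForward oprs are rewritten by replace_resize_opr instead of being relayouted back to NCHW. The optimize_for_inference entry point and the enable_nchw88/enable_nchw44 option names are assumptions based on the existing gopt interface and are not verified by this diff; network_output is an invented SymbolVar.

// Hedged sketch (entry-point names are assumptions, not verified here):
// request the NCHW88 conversion pass when optimizing a graph for inference.
mgb::gopt::OptimizeForInferenceOptions opt;
opt.enable_nchw88();  // or enable_nchw44() for the 4-channel packing
auto optimized_outputs =
        mgb::gopt::optimize_for_inference({network_output}, opt);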
| //! When the inputs change and all of them can be converted to nchwxx, this | //! When the inputs change and all of them can be converted to nchwxx, this | ||||
| //! opr runs in nchwxx mode; otherwise it runs in nchw mode (for example, | //! opr runs in nchwxx mode; otherwise it runs in nchw mode (for example, | ||||
| //! concat and elemwise oprs) | //! concat and elemwise oprs) | ||||
| @@ -2704,6 +2748,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||||
| replace_func[opr::Convolution::typeinfo()] = replace_conv_opr; | replace_func[opr::Convolution::typeinfo()] = replace_conv_opr; | ||||
| replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr; | replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr; | ||||
| replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr; | replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr; | ||||
| replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr; | |||||
| replace_func[opr::Concat::typeinfo()] = replace_multi_inp_opr; | replace_func[opr::Concat::typeinfo()] = replace_multi_inp_opr; | ||||
| replace_func[opr::Elemwise::typeinfo()] = replace_multi_inp_opr; | replace_func[opr::Elemwise::typeinfo()] = replace_multi_inp_opr; | ||||
| replace_func[opr::TypeCvt::typeinfo()] = replace_multi_inp_opr; | replace_func[opr::TypeCvt::typeinfo()] = replace_multi_inp_opr; | ||||
| @@ -2718,7 +2763,6 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||||
| replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_nchw; | replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_nchw; | ||||
| replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw; | replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw; | ||||
| replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw; | replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw; | ||||
| replace_func[opr::ResizeForward::typeinfo()] = relayout_inp_to_nchw; | |||||
| replace_func[opr::WarpPerspectiveForward::typeinfo()] = | replace_func[opr::WarpPerspectiveForward::typeinfo()] = | ||||
| relayout_inp_to_nchw; | relayout_inp_to_nchw; | ||||
| replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw; | replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw; | ||||
| @@ -3236,26 +3280,27 @@ public: | |||||
| MGB_DEFINE_OPR_CLASS(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr, | MGB_DEFINE_OPR_CLASS(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr, | ||||
| cg::SingleCNOperatorNodeBase) // { | cg::SingleCNOperatorNodeBase) // { | ||||
| public: | public: | ||||
| AbstractShuffleOpr(VarNode* inpvar, TensorFormat inp_format, | |||||
| TensorFormat out_format); | |||||
| static SymbolVar make(VarNode* inpvar, TensorFormat inp_format, | |||||
| TensorFormat out_format); | |||||
| TensorFormat inp_format() const { | |||||
| return m_inp_format; | |||||
| } | |||||
| TensorFormat out_format() const { | |||||
| return m_out_format; | |||||
| } | |||||
| private: | private: | ||||
| void init_output_static_infer_desc() override; | |||||
| void scn_do_execute() override; | |||||
| const TensorFormat m_inp_format; | |||||
| const TensorFormat m_out_format; | |||||
| }; | |||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr); | ||||
| @@ -3910,8 +3955,8 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { | |||||
| opr_set.insert(opr); | opr_set.insert(opr); | ||||
| // check dimshuffle | // check dimshuffle | ||||
| auto shuffle = try_cast_as_op<opr::Dimshuffle>( | |||||
| reshape->input(0)->owner_opr()); | |||||
| auto shuffle = | |||||
| try_cast_as_op<opr::Dimshuffle>(reshape->input(0)->owner_opr()); | |||||
| if (shuffle == nullptr) | if (shuffle == nullptr) | ||||
| return false; | return false; | ||||
| auto&& param = shuffle->param(); | auto&& param = shuffle->param(); | ||||
| @@ -3981,10 +4026,9 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { | |||||
| auto conv_bias_shuffle = opr::ConvBias::make( | auto conv_bias_shuffle = opr::ConvBias::make( | ||||
| src, filter, new_bias, new_param, conv_bias->execution_policy(), | src, filter, new_bias, new_param, conv_bias->execution_policy(), | ||||
| OperatorNodeConfig{out_dtype}); | OperatorNodeConfig{out_dtype}); | ||||
| rewriter.replace_var( | |||||
| opr->output(0), conv_bias_shuffle.node(), | |||||
| mgb_cstr_log("replace conv_bias + " | |||||
| "reformat to conv_bias(NCHW4_NHWC)")); | |||||
| rewriter.replace_var(opr->output(0), conv_bias_shuffle.node(), | |||||
| mgb_cstr_log("replace conv_bias + " | |||||
| "reformat to conv_bias(NCHW4_NHWC)")); | |||||
| return true; | return true; | ||||
| }; | }; | ||||
| @@ -4036,8 +4080,8 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { | |||||
| return false; | return false; | ||||
| auto inp_dtype = conv_bias->input(0)->dtype(); | auto inp_dtype = conv_bias->input(0)->dtype(); | ||||
| bool is_s8nchw32 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && | bool is_s8nchw32 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && | ||||
| conv_bias->param().format == | |||||
| megdnn::param::ConvBias::Format::NCHW32; | |||||
| conv_bias->param().format == | |||||
| megdnn::param::ConvBias::Format::NCHW32; | |||||
| if (!is_s8nchw32) | if (!is_s8nchw32) | ||||
| return false; | return false; | ||||
| if (conv_bias->input().size() != 3) | if (conv_bias->input().size() != 3) | ||||
| @@ -4078,9 +4122,8 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { | |||||
| &rewriter](OperatorNodeBase* opr) { | &rewriter](OperatorNodeBase* opr) { | ||||
| if (!try_conv_dimshuffle_reshape_typecvt(opr) && | if (!try_conv_dimshuffle_reshape_typecvt(opr) && | ||||
| !try_conv_reformat_nchw42nchw32(opr) && | !try_conv_reformat_nchw42nchw32(opr) && | ||||
| !try_conv_reformat_nchw42nhwc(opr) | |||||
| && !try_conv_reformat_nchw322nchw4(opr) | |||||
| ) { | |||||
| !try_conv_reformat_nchw42nhwc(opr) && | |||||
| !try_conv_reformat_nchw322nchw4(opr)) { | |||||
| rewriter.auto_replace_outputs(opr); | rewriter.auto_replace_outputs(opr); | ||||
| } | } | ||||
| }; | }; | ||||
| @@ -4497,7 +4540,7 @@ void PaddingChannelPass::apply(OptState& opt) const { | |||||
| /* ================ EnableNCHW64Pass =============== */ | /* ================ EnableNCHW64Pass =============== */ | ||||
| VarNode* EnableNCHW64Pass::on_graph_endpoint_var(VarNode* new_var, | VarNode* EnableNCHW64Pass::on_graph_endpoint_var(VarNode* new_var, | ||||
| VarNode* orig_var) const { | |||||
| VarNode* orig_var) const { | |||||
| if (!orig_var->shape().eq_shape(new_var->shape())) { | if (!orig_var->shape().eq_shape(new_var->shape())) { | ||||
| auto iter = m_opr_format_map.find(new_var->owner_opr()); | auto iter = m_opr_format_map.find(new_var->owner_opr()); | ||||
| mgb_assert(iter != m_opr_format_map.end(), | mgb_assert(iter != m_opr_format_map.end(), | ||||
| @@ -4532,8 +4575,7 @@ VarNode* EnableNCHW64Pass::on_graph_endpoint_var(VarNode* new_var, | |||||
| return new_var; | return new_var; | ||||
| } | } | ||||
| std::unique_ptr<EnableNCHW64Pass> | |||||
| EnableNCHW64Pass::make_nchw64_converter() { | |||||
| std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() { | |||||
| MIDOUT_B("EnableNCHW64Pass::make") | MIDOUT_B("EnableNCHW64Pass::make") | ||||
| auto ret = std::make_unique<EnableNCHW64Pass>(); | auto ret = std::make_unique<EnableNCHW64Pass>(); | ||||
| ret->set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^ | ret->set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^ | ||||
| @@ -4618,15 +4660,15 @@ EnableNCHW64Pass::make_nchw64_converter() { | |||||
| [make_new_conv, &format_map]( | [make_new_conv, &format_map]( | ||||
| OperatorNodeBase* opr, | OperatorNodeBase* opr, | ||||
| const VarNodeArray& new_inp) -> VarNode* { | const VarNodeArray& new_inp) -> VarNode* { | ||||
| mgb_assert(opr->input().size()==new_inp.size()); | |||||
| mgb_assert(opr->input().size() == new_inp.size()); | |||||
| bool check_dtype = | bool check_dtype = | ||||
| new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8 && | new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8 && | ||||
| new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS8; | new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS8; | ||||
| mgb_assert(opr->output().size() > 0); | mgb_assert(opr->output().size() > 0); | ||||
| bool dst_float = opr->output(0)->dtype().enumv() == DTypeEnum::Float32; | bool dst_float = opr->output(0)->dtype().enumv() == DTypeEnum::Float32; | ||||
| if (opr->input().size() >= 3) { | if (opr->input().size() >= 3) { | ||||
| auto dtype_expect = dst_float ? DTypeEnum::Float32 | |||||
| : DTypeEnum::QuantizedS32; | |||||
| auto dtype_expect = | |||||
| dst_float ? DTypeEnum::Float32 : DTypeEnum::QuantizedS32; | |||||
| check_dtype &= new_inp[2]->dtype().enumv() == dtype_expect; | check_dtype &= new_inp[2]->dtype().enumv() == dtype_expect; | ||||
| } | } | ||||
| if (opr->input().size() >= 4) { | if (opr->input().size() >= 4) { | ||||
| @@ -4677,12 +4719,13 @@ EnableNCHW64Pass::make_nchw64_converter() { | |||||
| for (size_t i = 0; i < inps.size(); ++i) { | for (size_t i = 0; i < inps.size(); ++i) { | ||||
| // do not format bias and z when dst_float is true | // do not format bias and z when dst_float is true | ||||
| bool skip = dst_float && i >= 2; | bool skip = dst_float && i >= 2; | ||||
| if (!skip) inps[i] = process(i); | |||||
| if (!skip) | |||||
| inps[i] = process(i); | |||||
| } | } | ||||
| auto& conv_bias = opr->cast_final_safe<opr::ConvBiasForward>(); | auto& conv_bias = opr->cast_final_safe<opr::ConvBiasForward>(); | ||||
| auto ret = make_new_conv( | |||||
| inps, &conv_bias, | |||||
| dst_float ? Format::NCHW4_NCHW : Format::NCHW4); | |||||
| auto ret = | |||||
| make_new_conv(inps, &conv_bias, | |||||
| dst_float ? Format::NCHW4_NCHW : Format::NCHW4); | |||||
| if (!dst_float) | if (!dst_float) | ||||
| format_map.insert(std::make_pair(ret->owner_opr(), Format::NCHW4)); | format_map.insert(std::make_pair(ret->owner_opr(), Format::NCHW4)); | ||||
| return ret; | return ret; | ||||
| @@ -4692,7 +4735,7 @@ EnableNCHW64Pass::make_nchw64_converter() { | |||||
| [make_new_conv, &format_map]( | [make_new_conv, &format_map]( | ||||
| OperatorNodeBase* opr, | OperatorNodeBase* opr, | ||||
| const VarNodeArray& new_inp) -> VarNode* { | const VarNodeArray& new_inp) -> VarNode* { | ||||
| mgb_assert(opr->input().size()==new_inp.size()); | |||||
| mgb_assert(opr->input().size() == new_inp.size()); | |||||
| bool check_dtype = | bool check_dtype = | ||||
| new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8 && | new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8 && | ||||
| new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS8; | new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS8; | ||||
| @@ -4754,18 +4797,17 @@ EnableNCHW64Pass::make_nchw64_converter() { | |||||
| OperatorNodeBase* opr, | OperatorNodeBase* opr, | ||||
| const VarNodeArray& new_inp) -> VarNode* { | const VarNodeArray& new_inp) -> VarNode* { | ||||
| // fint4XWint4 and fuint4XWint4 | // fint4XWint4 and fuint4XWint4 | ||||
| mgb_assert(opr->input().size()==new_inp.size()); | |||||
| mgb_assert(opr->input().size() == new_inp.size()); | |||||
| bool check_dtype = | bool check_dtype = | ||||
| (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 || | (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 || | ||||
| new_inp[0]->dtype().enumv() == | |||||
| DTypeEnum::Quantized4Asymm) && | |||||
| new_inp[0]->dtype().enumv() == DTypeEnum::Quantized4Asymm) && | |||||
| new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS4; | new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS4; | ||||
| if (opr->input().size() >= 3) | if (opr->input().size() >= 3) | ||||
| check_dtype &= | check_dtype &= | ||||
| new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32; | new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32; | ||||
| if (opr->input().size() >= 4) | if (opr->input().size() >= 4) | ||||
| check_dtype &= new_inp[3]->dtype().enumv() == | |||||
| new_inp[0]->dtype().enumv(); | |||||
| check_dtype &= | |||||
| new_inp[3]->dtype().enumv() == new_inp[0]->dtype().enumv(); | |||||
| if (!check_dtype) | if (!check_dtype) | ||||
| return nullptr; | return nullptr; | ||||
| size_t out_channels = opr->input(1)->shape()[0]; | size_t out_channels = opr->input(1)->shape()[0]; | ||||
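The dtype gate above can be read as a single predicate: the feature map may be QuantizedS4 or Quantized4Asymm, the weight must be QuantizedS4, an optional bias must be QuantizedS32, and an optional z input must match the feature-map dtype. A hedged standalone restatement follows; the helper name and flattened parameters are invented for illustration and are not part of the pass.

// Returns true when the int4 conv_bias path above may take over this opr.
bool int4_conv_dtypes_ok(DTypeEnum src, DTypeEnum filter, DTypeEnum bias,
                         DTypeEnum z, size_t ninputs) {
    bool ok = (src == DTypeEnum::QuantizedS4 ||
               src == DTypeEnum::Quantized4Asymm) &&
              filter == DTypeEnum::QuantizedS4;
    if (ninputs >= 3)
        ok &= bias == DTypeEnum::QuantizedS32;
    if (ninputs >= 4)
        ok &= z == src;
    return ok;
}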
| @@ -4818,18 +4860,17 @@ EnableNCHW64Pass::make_nchw64_converter() { | |||||
| OperatorNodeBase* opr, | OperatorNodeBase* opr, | ||||
| const VarNodeArray& new_inp) -> VarNode* { | const VarNodeArray& new_inp) -> VarNode* { | ||||
| // fint4XWint4 and fuint4XWint4 | // fint4XWint4 and fuint4XWint4 | ||||
| mgb_assert(opr->input().size()==new_inp.size()); | |||||
| mgb_assert(opr->input().size() == new_inp.size()); | |||||
| bool check_dtype = | bool check_dtype = | ||||
| (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 || | (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 || | ||||
| new_inp[0]->dtype().enumv() == | |||||
| DTypeEnum::Quantized4Asymm) && | |||||
| new_inp[0]->dtype().enumv() == DTypeEnum::Quantized4Asymm) && | |||||
| new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS4; | new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS4; | ||||
| if (opr->input().size() >= 3) | if (opr->input().size() >= 3) | ||||
| check_dtype &= | check_dtype &= | ||||
| new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32; | new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32; | ||||
| if (opr->input().size() >= 4) | if (opr->input().size() >= 4) | ||||
| check_dtype &= new_inp[3]->dtype().enumv() == | |||||
| new_inp[0]->dtype().enumv(); | |||||
| check_dtype &= | |||||
| new_inp[3]->dtype().enumv() == new_inp[0]->dtype().enumv(); | |||||
| if (!check_dtype) | if (!check_dtype) | ||||
| return nullptr; | return nullptr; | ||||
| size_t out_channels = opr->input(1)->shape()[0]; | size_t out_channels = opr->input(1)->shape()[0]; | ||||
| @@ -4842,8 +4883,7 @@ EnableNCHW64Pass::make_nchw64_converter() { | |||||
| auto iter = format_map.find(new_inp[i]->owner_opr()); | auto iter = format_map.find(new_inp[i]->owner_opr()); | ||||
| if (iter == format_map.end()) { | if (iter == format_map.end()) { | ||||
| auto ovar = RelayoutPlaceholder::make( | auto ovar = RelayoutPlaceholder::make( | ||||
| inps[i], | |||||
| RelayoutPlaceholder::LayoutType::NCHW_TO_NHWC); | |||||
| inps[i], RelayoutPlaceholder::LayoutType::NCHW_TO_NHWC); | |||||
| return ovar.node(); | return ovar.node(); | ||||
| } else { | } else { | ||||
| const auto& fmt = iter->second; | const auto& fmt = iter->second; | ||||
| @@ -4973,7 +5013,7 @@ EnableNCHW64Pass::make_nchw64_converter() { | |||||
| default: | default: | ||||
| mgb_assert(cur == Format::NCHW4); | mgb_assert(cur == Format::NCHW4); | ||||
| } | } | ||||
| auto param = deconv.param(); | auto param = deconv.param(); | ||||
| param.format = Format::NCHW4; | param.format = Format::NCHW4; | ||||
| auto new_deconv = opr::ConvolutionBackwardData::make( | auto new_deconv = opr::ConvolutionBackwardData::make( | ||||
| @@ -4990,7 +5030,7 @@ EnableNCHW64Pass::make_nchw64_converter() { | |||||
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| mgb_assert(!shape_changed, | |||||
| mgb_assert(!shape_changed, | |||||
| "EnableNCHW64Pass won't change format of output tensor " | "EnableNCHW64Pass won't change format of output tensor " | ||||
| "of non quantized deconv operator(name:%s)", | "of non quantized deconv operator(name:%s)", | ||||
| opr->cname()); | opr->cname()); | ||||
| @@ -5000,8 +5040,9 @@ EnableNCHW64Pass::make_nchw64_converter() { | |||||
| }; | }; | ||||
| // replace rule for elemwise like opr | // replace rule for elemwise like opr | ||||
| auto replace_elemwise_like_opr = [&format_map](OperatorNodeBase* opr, | |||||
| const VarNodeArray& new_inp) { | |||||
| auto replace_elemwise_like_opr = [&format_map]( | |||||
| OperatorNodeBase* opr, | |||||
| const VarNodeArray& new_inp) { | |||||
| mgb_assert(opr->input().size() == new_inp.size()); | mgb_assert(opr->input().size() == new_inp.size()); | ||||
| ThinHashMap<Format, size_t> format_size; | ThinHashMap<Format, size_t> format_size; | ||||
| bool same_format = true; | bool same_format = true; | ||||
| @@ -5073,7 +5114,7 @@ EnableNCHW64Pass::make_nchw64_converter() { | |||||
| cur = Format::NCHW; | cur = Format::NCHW; | ||||
| } | } | ||||
| if (cur != max_format) { | if (cur != max_format) { | ||||
| inps[i] = map.at(std::make_pair(cur, max_format))(inps[i]); | |||||
| inps[i] = map.at(std::make_pair(cur, max_format))(inps[i]); | |||||
| } | } | ||||
| } | } | ||||
| auto ret = serialization::copy_opr_shallow(*opr, inps, opr->config()); | auto ret = serialization::copy_opr_shallow(*opr, inps, opr->config()); | ||||
| @@ -5131,8 +5172,7 @@ EnableNCHW64Pass::make_nchw64_converter() { | |||||
| SymbolVar new_warp; | SymbolVar new_warp; | ||||
| if (inps.size() == 3) { | if (inps.size() == 3) { | ||||
| new_warp = opr::WarpPerspectiveForward::make( | new_warp = opr::WarpPerspectiveForward::make( | ||||
| inps[0], inps[1], inps[2], param, | |||||
| warp.config()); | |||||
| inps[0], inps[1], inps[2], param, warp.config()); | |||||
| } else { | } else { | ||||
| mgb_assert(inps.size() == 4); | mgb_assert(inps.size() == 4); | ||||
| new_warp = opr::WarpPerspectiveForward::make( | new_warp = opr::WarpPerspectiveForward::make( | ||||
| @@ -5179,14 +5219,13 @@ EnableNCHW64Pass::make_nchw64_converter() { | |||||
| default: | default: | ||||
| mgb_assert(cur == Format::NCHW4); | mgb_assert(cur == Format::NCHW4); | ||||
| } | } | ||||
| auto param = warp.param(); | auto param = warp.param(); | ||||
| param.format = Format::NCHW4; | param.format = Format::NCHW4; | ||||
| SymbolVar new_warp; | SymbolVar new_warp; | ||||
| if (inps.size() == 3) { | if (inps.size() == 3) { | ||||
| new_warp = opr::WarpPerspectiveForward::make( | new_warp = opr::WarpPerspectiveForward::make( | ||||
| inps[0], inps[1], inps[2], param, | |||||
| warp.config()); | |||||
| inps[0], inps[1], inps[2], param, warp.config()); | |||||
| } else { | } else { | ||||
| mgb_assert(inps.size() == 4); | mgb_assert(inps.size() == 4); | ||||
| new_warp = opr::WarpPerspectiveForward::make( | new_warp = opr::WarpPerspectiveForward::make( | ||||
| @@ -5204,7 +5243,7 @@ EnableNCHW64Pass::make_nchw64_converter() { | |||||
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| mgb_assert(!shape_changed, | |||||
| mgb_assert(!shape_changed, | |||||
| "EnableNCHW64Pass won't change format of output tensor " | "EnableNCHW64Pass won't change format of output tensor " | ||||
| "of non quantized warp perspective operator(name:%s)", | "of non quantized warp perspective operator(name:%s)", | ||||
| opr->cname()); | opr->cname()); | ||||
| @@ -5212,9 +5251,8 @@ EnableNCHW64Pass::make_nchw64_converter() { | |||||
| opr->config()); | opr->config()); | ||||
| } | } | ||||
| }; | }; | ||||
| auto replace_pooling_opr = [&format_map]( | |||||
| OperatorNodeBase* opr, | |||||
| const VarNodeArray& new_inp) { | |||||
| auto replace_pooling_opr = [&format_map](OperatorNodeBase* opr, | |||||
| const VarNodeArray& new_inp) { | |||||
| mgb_assert(opr->input().size() == new_inp.size()); | mgb_assert(opr->input().size() == new_inp.size()); | ||||
| auto& pooling = opr->cast_final_safe<opr::PoolingForward>(); | auto& pooling = opr->cast_final_safe<opr::PoolingForward>(); | ||||
| if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 || | if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 || | ||||
| @@ -5300,7 +5338,7 @@ EnableNCHW64Pass::make_nchw64_converter() { | |||||
| mgb_assert(cur == Format::NCHW4); | mgb_assert(cur == Format::NCHW4); | ||||
| } | } | ||||
| Format out_format = use_nchw32 ? Format::NCHW32 : Format::NCHW4; | Format out_format = use_nchw32 ? Format::NCHW32 : Format::NCHW4; | ||||
| auto param = pooling.param(); | auto param = pooling.param(); | ||||
| param.format = out_format; | param.format = out_format; | ||||
| auto new_pool = | auto new_pool = | ||||
| @@ -5336,7 +5374,7 @@ EnableNCHW64Pass::make_nchw64_converter() { | |||||
| auto inps = new_inp; | auto inps = new_inp; | ||||
| for (size_t i = 0; i < opr->input().size(); ++i) { | for (size_t i = 0; i < opr->input().size(); ++i) { | ||||
| auto iter = format_map.find(new_inp[i]->owner_opr()); | auto iter = format_map.find(new_inp[i]->owner_opr()); | ||||
| auto fmt = iter != format_map.end()?iter->second:Format::NCHW; | |||||
| auto fmt = iter != format_map.end() ? iter->second : Format::NCHW; | |||||
| if (iter != format_map.end()) { | if (iter != format_map.end()) { | ||||
| switch (fmt) { | switch (fmt) { | ||||
| case Format::NHWC: | case Format::NHWC: | ||||
| @@ -10,9 +10,9 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #include "megbrain/opr/imgproc.h" | |||||
| #include "./internal/megdnn_opr_wrapper.inl" | #include "./internal/megdnn_opr_wrapper.inl" | ||||
| #include "megbrain/graph/grad_impl.h" | #include "megbrain/graph/grad_impl.h" | ||||
| #include "megbrain/opr/imgproc.h" | |||||
| #include "megbrain/opr/io.h" | #include "megbrain/opr/io.h" | ||||
| #include "megbrain/opr/utility.h" | #include "megbrain/opr/utility.h" | ||||
| @@ -340,7 +340,9 @@ void ResizeForward::outshape_by_symvar_do_get_output_shape( | |||||
| //! The index of height, e.g., [b, h, w, c], the height_idx = 1 | //! The index of height, e.g., [b, h, w, c], the height_idx = 1 | ||||
| size_t height_idx = 0; | size_t height_idx = 0; | ||||
| if (param().format == Param::Format::NCHW || | if (param().format == Param::Format::NCHW || | ||||
| param().format == Param::Format::NCHW4) { | |||||
| param().format == Param::Format::NCHW4 || | |||||
| param().format == Param::Format::NCHW44 || | |||||
| param().format == Param::Format::NCHW88) { | |||||
| height_idx = 2; | height_idx = 2; | ||||
| } else { | } else { | ||||
| height_idx = 1; | height_idx = 1; | ||||
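The height-index convention above determines where the requested spatial size is written into the inferred output shape: the NCHW-family formats (NCHW, NCHW4, NCHW44, NCHW88) keep height at index 2, NHWC keeps it at index 1, and width is assumed to follow height in both cases. A standalone sketch of that rule in plain C++ (not the member function itself; parameter names are illustrative):

#include <cstddef>
#include <vector>

// Works for 4-d NCHW/NHWC shapes and for 5-d packed shapes such as
// (N, C/8, H, W, 8), where the height index is still 2.
std::vector<size_t> resize_output_shape(std::vector<size_t> in_shape,
                                        size_t out_h, size_t out_w,
                                        bool nchw_like) {
    const size_t height_idx = nchw_like ? 2 : 1;
    in_shape[height_idx] = out_h;
    in_shape[height_idx + 1] = out_w;
    return in_shape;
}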