GitOrigin-RevId: 07f2ee6c5b
tags/v0.6.0
| @@ -434,7 +434,7 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0) | |||||
| 'layout is (K/4, M/4, 4(k), 4(m)) x (K/4, N, 4(k))'), | 'layout is (K/4, M/4, 4(k), 4(m)) x (K/4, N, 4(k))'), | ||||
| Doc('MK8', 'Split 8 from M and K, better for neon compute:' | Doc('MK8', 'Split 8 from M and K, better for neon compute:' | ||||
| '(M/8, K/8, 8(k), 8(m)) x (K/8, N, 8(k)). if transposeA the ' | '(M/8, K/8, 8(k), 8(m)) x (K/8, N, 8(k)). if transposeA the ' | ||||
| 'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'), | |||||
| 'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'), | |||||
| Doc('MK4_DOT', 'Split 4 from M and K, better for neon dotprod:' | Doc('MK4_DOT', 'Split 4 from M and K, better for neon dotprod:' | ||||
| 'M/4, K/4, 4(m), 4(k)) x (K/4, N, 4(k)). if transposeA the ' | 'M/4, K/4, 4(m), 4(k)) x (K/4, N, 4(k)). if transposeA the ' | ||||
| 'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))')) | 'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))')) | ||||
| @@ -858,7 +858,10 @@ when the ``I`` suffix is present. | |||||
| 'NCHW_NCHW88_CONV_CHAN_WEIGHT', | 'NCHW_NCHW88_CONV_CHAN_WEIGHT', | ||||
| 'NCHW_NCHW88_CONV_GROUP_WEIGHT', | 'NCHW_NCHW88_CONV_GROUP_WEIGHT', | ||||
| 'NCHW_NCHW88', | 'NCHW_NCHW88', | ||||
| 'NCHW88_NCHW') | |||||
| 'NCHW88_NCHW', | |||||
| 'NCHW_NCHW4_IC_SMALL', | |||||
| 'NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT', | |||||
| ) | |||||
| ) | ) | ||||
| @@ -28,6 +28,26 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src, | |||||
| dst[3] = src[3]; | dst[3] = src[3]; | ||||
| dst[4] = 4; | dst[4] = 4; | ||||
| break; | break; | ||||
| case Param::Mode::NCHW_NCHW4_IC_SMALL: | |||||
| dst.ndim = 5; | |||||
| megdnn_assert(src[1] <= 4_z, "ic should be less equal 4"); | |||||
| dst[0] = src[0]; | |||||
| dst[1] = div_ceil(src[1], 4_z); | |||||
| dst[2] = src[2]; | |||||
| dst[3] = src[3]; | |||||
| dst[4] = 4; | |||||
| break; | |||||
| case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT: | |||||
| megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 4"); | |||||
| megdnn_assert(src[1] <= 4_z, "ic should be less equal 4"); | |||||
| dst.ndim = 5; | |||||
| dst[0] = src[0]; | |||||
| dst[1] = div_ceil(src[1], 4_z); | |||||
| dst[2] = src[2]; | |||||
| dst[3] = src[3]; | |||||
| dst[4] = 4; | |||||
| break; | |||||
| case Param::Mode::NCHW_NCHW88: | case Param::Mode::NCHW_NCHW88: | ||||
| dst.ndim = 5; | dst.ndim = 5; | ||||
| dst[0] = src[0]; | dst[0] = src[0]; | ||||
| @@ -276,6 +296,8 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) { | |||||
| case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT: | case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT: | ||||
| case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT: | case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT: | ||||
| case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT: | case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT: | ||||
| case Param::Mode::NCHW_NCHW4_IC_SMALL: | |||||
| case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT: | |||||
| CHECK_SRC(DefaultTensorFormat::make()); | CHECK_SRC(DefaultTensorFormat::make()); | ||||
| dst = src; | dst = src; | ||||
| break; | break; | ||||
| @@ -374,6 +396,23 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, | |||||
| exec_dst = dst; | exec_dst = dst; | ||||
| } | } | ||||
| break; | break; | ||||
| case Param::Mode::NCHW_NCHW4_IC_SMALL: | |||||
| case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT: | |||||
| // nchw to nchw4c or oihw to oihw4i | |||||
| { | |||||
| TensorLayout work_space_layout( | |||||
| {src[0], round_up(src[1], 4_z), src[2], src[3]}, | |||||
| src.dtype, src.format); | |||||
| exec_src = work_space_layout | |||||
| .reshape({src[0], div_ceil(src[1], 4_z), 4, | |||||
| src[2], src[3]}) | |||||
| .dimshuffle({0, 1, 3, 4, 2}); | |||||
| exec_dst = dst; | |||||
| } | |||||
| break; | |||||
| case Param::Mode::NCHW_NHWCD4: | case Param::Mode::NCHW_NHWCD4: | ||||
| case Param::Mode::NCHW_NHWCD4I: | case Param::Mode::NCHW_NHWCD4I: | ||||
| // src is {N, C, H, W} | // src is {N, C, H, W} | ||||
| @@ -11,6 +11,7 @@ | |||||
| #include "src/cuda/relayout_format/opr_impl.h" | #include "src/cuda/relayout_format/opr_impl.h" | ||||
| #include "src/cuda/handle.h" | #include "src/cuda/handle.h" | ||||
| #include "src/cuda/utils.h" | |||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace cuda; | using namespace cuda; | ||||
| @@ -20,15 +21,22 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| auto src_dtype = src.layout.dtype; | auto src_dtype = src.layout.dtype; | ||||
| megdnn_assert( | megdnn_assert( | ||||
| param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 || | param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 || | ||||
| param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4, | |||||
| param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4 || | |||||
| param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL || | |||||
| param().mode == | |||||
| Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT, | |||||
| "relayout format of cuda only support NCHW4->CHWN4 or " | "relayout format of cuda only support NCHW4->CHWN4 or " | ||||
| "CHWN4->NCHW4"); | |||||
| if (src_dtype.enumv() == DTypeEnum::QuantizedS8) { | |||||
| "CHWN4->NCHW4 or NCHW->NCHW4"); | |||||
| if ((param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 || | |||||
| param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4) && | |||||
| src_dtype.enumv() == DTypeEnum::QuantizedS8) { | |||||
| size_t row = 0, col = 0; | size_t row = 0, col = 0; | ||||
| if (param().mode == Param::RelayoutFormat::Mode::NCHW4_CHWN4) { | if (param().mode == Param::RelayoutFormat::Mode::NCHW4_CHWN4) { | ||||
| row = src.layout[0], | row = src.layout[0], | ||||
| col = src.layout[1] * src.layout[2] * src.layout[3]; | col = src.layout[1] * src.layout[2] * src.layout[3]; | ||||
| } else { | } else { | ||||
| megdnn_assert(param().mode == | |||||
| param::RelayoutFormat::Mode::CHWN4_NCHW4); | |||||
| row = src.layout[0] * src.layout[1] * src.layout[2], | row = src.layout[0] * src.layout[1] * src.layout[2], | ||||
| col = src.layout[3]; | col = src.layout[3]; | ||||
| } | } | ||||
| @@ -43,6 +51,27 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| return handle()->create_operator<RelayoutForward>()->exec(trans_in, | return handle()->create_operator<RelayoutForward>()->exec(trans_in, | ||||
| trans_out); | trans_out); | ||||
| } | } | ||||
| if ((param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL || | |||||
| param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) && | |||||
| src.layout[1] % 4 != 0) { | |||||
| megdnn_assert(src.raw_ptr != dst.raw_ptr && src.layout.ndim == 4, | |||||
| "The mode of NCHW_NCHW4 and NCHW_NCHW4_CONV_DENSE_WEIGHT " | |||||
| "of RelayoutFormat opr(cuda backend) does not support " | |||||
| "src.ptr == dst.ptr"); | |||||
| megdnn_assert(src.layout[1] <= 4); | |||||
| cuda_check(cudaMemsetAsync(dst.raw_ptr, 0, | |||||
| dst.layout.span().dist_byte(), | |||||
| cuda_stream(this->handle()))); | |||||
| TensorLayout exec_dst_layout = dst.layout; | |||||
| exec_dst_layout[4] = src.layout[1]; | |||||
| TensorLayout exec_src_layout = | |||||
| src.layout | |||||
| .reshape({src.layout[0], src.layout[1], 1, | |||||
| src.layout[2], src.layout[3]}) | |||||
| .dimshuffle({0, 2, 3, 4, 1}); | |||||
| return handle()->create_operator<RelayoutForward>()->exec( | |||||
| {src.raw_ptr, exec_src_layout}, {dst.raw_ptr, exec_dst_layout}); | |||||
| } | |||||
| TensorLayout exec_src, exec_dst; | TensorLayout exec_src, exec_dst; | ||||
| deduce_exec_layout(src.layout, dst.layout, exec_src, exec_dst); | deduce_exec_layout(src.layout, dst.layout, exec_src, exec_dst); | ||||
| TensorND exec_src_nd{src.raw_ptr, exec_src}; | TensorND exec_src_nd{src.raw_ptr, exec_src}; | ||||
| @@ -79,6 +79,7 @@ void padding_to_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src, | |||||
| } | } | ||||
| cb(Float32, dt_float32); | cb(Float32, dt_float32); | ||||
| cb(QuantizedS8, dt_qint8); | |||||
| default: | default: | ||||
| megdnn_assert(0); | megdnn_assert(0); | ||||
| #undef cb | #undef cb | ||||
| @@ -138,7 +139,7 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
| return n * c * h * w * src.dtype.size(); | return n * c * h * w * src.dtype.size(); | ||||
| } | } | ||||
| case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT: { | case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT: { | ||||
| megdnn_assert(src.ndim == 4, "src must be oihw ,nmdim == 5"); | |||||
| megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 5"); | |||||
| megdnn_assert(src[0] % 8 == 0, | megdnn_assert(src[0] % 8 == 0, | ||||
| "NCHW_NCHW88_CONV_DENSE_WEIGHT oc must align to 8"); | "NCHW_NCHW88_CONV_DENSE_WEIGHT oc must align to 8"); | ||||
| if (src[1] % 8 == 0) | if (src[1] % 8 == 0) | ||||
| @@ -150,7 +151,7 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
| return oc * ic * h * w * src.dtype.size(); | return oc * ic * h * w * src.dtype.size(); | ||||
| } | } | ||||
| case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT: { | case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT: { | ||||
| megdnn_assert(src.ndim == 5, "src must be goihw ,nmdim == 5"); | |||||
| megdnn_assert(src.ndim == 5, "src must be goihw, ndim == 5"); | |||||
| megdnn_assert(src[1] % 8 == 0, | megdnn_assert(src[1] % 8 == 0, | ||||
| "NCHW_NCHW88_CONV_CHAN_WEIGHT oc per group must " | "NCHW_NCHW88_CONV_CHAN_WEIGHT oc per group must " | ||||
| "align to 8"); | "align to 8"); | ||||
| @@ -164,7 +165,7 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
| return group * ocpg * icpg * h * w * src.dtype.size(); | return group * ocpg * icpg * h * w * src.dtype.size(); | ||||
| } | } | ||||
| case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT: { | case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT: { | ||||
| megdnn_assert(src.ndim == 5, "src must be goihw ,nmdim == 5"); | |||||
| megdnn_assert(src.ndim == 5, "src must be goihw, ndim == 5"); | |||||
| if (src[0] % 8 == 0) | if (src[0] % 8 == 0) | ||||
| return 0; | return 0; | ||||
| size_t group = round_up(src[0], 8_z); | size_t group = round_up(src[0], 8_z); | ||||
| @@ -174,6 +175,27 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
| size_t w = src[4]; | size_t w = src[4]; | ||||
| return group * ocpg * icpg * h * w * src.dtype.size(); | return group * ocpg * icpg * h * w * src.dtype.size(); | ||||
| } | } | ||||
| case Param::Mode::NCHW_NCHW4_IC_SMALL: { | |||||
| if (src[1] % 4 == 0) | |||||
| return 0; | |||||
| size_t n = src[0]; | |||||
| size_t c = round_up(src[1], 4_z); | |||||
| size_t h = src[2]; | |||||
| size_t w = src[3]; | |||||
| return n * c * h * w * src.dtype.size(); | |||||
| } | |||||
| case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT: { | |||||
|     // The source is asserted to have exactly 4 dimensions; the message | |||||
|     // names that same rank so a failure reports the actual expectation. | |||||
|     megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 4"); | |||||
|     // No workspace is requested when dimension 1 is a multiple of 4. | |||||
|     if (src[1] % 4 == 0) | |||||
|         return 0; | |||||
|     // Otherwise size the workspace for the tensor with dimension 1 | |||||
|     // rounded up to the next multiple of 4. | |||||
|     size_t oc = src[0]; | |||||
|     size_t ic = round_up(src[1], 4_z); | |||||
|     size_t h = src[2]; | |||||
|     size_t w = src[3]; | |||||
|     return oc * ic * h * w * src.dtype.size(); | |||||
| } | |||||
| default: | default: | ||||
| return 0; | return 0; | ||||
| } | } | ||||
| @@ -244,31 +266,28 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| exec_src_nd.raw_ptr = workspace.raw_ptr; | exec_src_nd.raw_ptr = workspace.raw_ptr; | ||||
| } | } | ||||
| } else if (param().mode == Param::Mode::NCHW_NCHW88) { | } else if (param().mode == Param::Mode::NCHW_NCHW88) { | ||||
| size_t ic = src.layout[1]; | |||||
| if (ic % 8 != 0) { | |||||
| padding_to_workspace({workspace.raw_ptr, exec_src}, src, 1, 8); | |||||
| exec_src_nd.raw_ptr = workspace.raw_ptr; | |||||
| } | |||||
| #define cb(_idx, _pack_size) \ | |||||
| size_t val = src.layout[_idx]; \ | |||||
| if (val % _pack_size != 0) { \ | |||||
| padding_to_workspace({workspace.raw_ptr, exec_src}, src, _idx, \ | |||||
| _pack_size); \ | |||||
| exec_src_nd.raw_ptr = workspace.raw_ptr; \ | |||||
| } | |||||
| cb(1, 8); | |||||
| } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT) { | } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT) { | ||||
| megdnn_assert(src.layout[0] % 8 == 0); | megdnn_assert(src.layout[0] % 8 == 0); | ||||
| size_t ic = src.layout[1]; | |||||
| if (ic % 8 != 0) { | |||||
| padding_to_workspace({workspace.raw_ptr, exec_src}, src, 1, 8_z); | |||||
| exec_src_nd.raw_ptr = workspace.raw_ptr; | |||||
| } | |||||
| cb(1, 8); | |||||
| } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT) { | } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT) { | ||||
| size_t group = src.layout[0]; | |||||
| if (group % 8 != 0) { | |||||
| padding_to_workspace({workspace.raw_ptr, exec_src}, src, 0, 8_z); | |||||
| exec_src_nd.raw_ptr = workspace.raw_ptr; | |||||
| } | |||||
| cb(0, 8); | |||||
| } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT) { | } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT) { | ||||
| megdnn_assert(src.layout[1] % 8 == 0); | megdnn_assert(src.layout[1] % 8 == 0); | ||||
| size_t ic = src.layout[2]; | |||||
| if (ic % 8 != 0) { | |||||
| padding_to_workspace({workspace.raw_ptr, exec_src}, src, 2, 8_z); | |||||
| exec_src_nd.raw_ptr = workspace.raw_ptr; | |||||
| } | |||||
| cb(2, 8); | |||||
| } else if (param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL) { | |||||
| cb(1, 4); | |||||
| } else if (param().mode == | |||||
| Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) { | |||||
| cb(1, 4); | |||||
| } | } | ||||
| m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle()); | m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle()); | ||||
| } | } | ||||
| @@ -8,6 +8,7 @@ | |||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #include "megdnn/dtype.h" | |||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "test/common/checker.h" | #include "test/common/checker.h" | ||||
| #include "test/common/rng.h" | #include "test/common/rng.h" | ||||
| @@ -30,4 +31,25 @@ TEST_F(CUDA, RELAYOUT_FORMAT) { | |||||
| checker.execs({{22, 23, 24, 25, 4}, {}}); | checker.execs({{22, 23, 24, 25, 4}, {}}); | ||||
| } | } | ||||
| // Runs the CUDA RelayoutFormat operator through the checker for the two | |||||
| // NCHW -> NCHW4 small-input-channel modes, once per tested dtype | |||||
| // (QuantizedS8 and Float32). | |||||
| TEST_F(CUDA, RELAYOUT_FORMAT_NCHW4) { | |||||
|     Checker<RelayoutFormat> checker(handle_cuda()); | |||||
|     UniformIntRNG rng{-50, 50}; | |||||
|     param::RelayoutFormat param; | |||||
|     for (DType dtype : | |||||
|          std::vector<DType>({dtype::QuantizedS8{0.1f}, dtype::Float32{}})) { | |||||
|         checker.set_dtype(0, dtype).set_rng(0, &rng); | |||||
|         // Select the activation mode at the start of every iteration. The | |||||
|         // mode is switched to the weight mode further down, so without | |||||
|         // this reset only the first dtype would ever be checked in | |||||
|         // NCHW_NCHW4_IC_SMALL mode. | |||||
|         param.mode = param::RelayoutFormat::Mode::NCHW_NCHW4_IC_SMALL; | |||||
|         // Dimension 1 of 4, 3 and 1: one multiple of 4, two that are not. | |||||
|         checker.set_param(param).execs({{2, 4, 35, 36}, {}}); | |||||
|         checker.set_param(param).execs({{2, 3, 35, 36}, {}}); | |||||
|         checker.set_param(param).execs({{2, 1, 35, 36}, {}}); | |||||
|         // Same relayout applied in the dense conv weight mode. | |||||
|         param.mode = param::RelayoutFormat::Mode:: | |||||
|                 NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT; | |||||
|         checker.set_param(param).execs({{4, 3, 3, 3}, {}}); | |||||
|         checker.set_param(param).execs({{4, 4, 3, 3}, {}}); | |||||
|         checker.set_param(param).execs({{1, 4, 3, 3}, {}}); | |||||
|     } | |||||
| } | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||