/**
 * \file dnn/src/common/local_share/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include "megdnn/oprs.h"
#include "src/common/utils.h"

namespace megdnn {

/**
 * \brief Deduce the forward output layout of a local-share convolution.
 *
 * Expected layouts (NCHW format only):
 *   src:    {N, IC, IH, IW}
 *   filter: DENSE  -> {sgh, sgw, IC, FH, FW, OC}
 *           GROUP  -> {groups, sgh, sgw, IC/groups, FH, FW, OC/groups}
 * (layout convention inferred from the index arithmetic below via
 * weights_shp_pos; confirm against the kernel implementations)
 *
 * \param[in]  src    input tensor layout; must be contiguous float32
 * \param[in]  filter weight tensor layout; must be contiguous
 * \param[out] dst    deduced output layout {N, OC, OH, OW}
 */
void LocalShareBase::deduce_layout_fwd(const TensorLayout& src,
                                       const TensorLayout& filter,
                                       TensorLayout& dst) {
    using Mode = LocalShare::Param::Mode;
    // Diagnostic string shared by all asserts below; note dst may still be
    // uninitialized here, its printed layout is only informational.
    auto errmsg =
            megdnn_layout_msg(src) + ", " + megdnn_layout_msg(filter) + ", " +
            megdnn_layout_msg(dst) + ", " + "is_xcorr=" +
            std::to_string((param().mode == Mode::CROSS_CORRELATION)) + ", " +
            "pad_h=" + std::to_string(param().pad_h) + ", " +
            "pad_w=" + std::to_string(param().pad_w) + ", " +
            "stride_h=" + std::to_string(param().stride_h) + ", " +
            "stride_w=" + std::to_string(param().stride_w) + ", " +
            "dilate_h=" + std::to_string(param().dilate_h) + ", " +
            "dilate_w=" + std::to_string(param().dilate_w) + ", " +
            "spatial_groups_h=" + std::to_string(param().spatial_groups_h) +
            ", " +
            "spatial_groups_w=" + std::to_string(param().spatial_groups_w);
    auto errmsg_c = errmsg.c_str();
    MEGDNN_MARK_USED_VAR(errmsg_c);
    megdnn_assert_contiguous(src);
    megdnn_assert_contiguous(filter);
    using Param = LocalShare::Param;
    using Sparse = Param::Sparse;
    using Format = Param::Format;
    using ComputeMode = Param::ComputeMode;
    megdnn_assert(param().format == Format::NCHW,
                  "local shared only support NCHW format");
    megdnn_assert(src.ndim == 4_z, "%s", errmsg_c);
    // DENSE filters are 6-dim, GROUP filters carry an extra leading
    // `groups` dimension and are 7-dim.
    megdnn_assert(
            (filter.ndim == 6_z && param().sparse == Sparse::DENSE) ||
                    (filter.ndim == 7_z && param().sparse == Sparse::GROUP),
            "%s", errmsg_c);
    megdnn_assert(param().dilate_h == 1 && param().dilate_w == 1,
                  "dilated local shared is not supported");
    megdnn_assert(src.dtype == dtype::Float32() &&
                          param().computeMode == ComputeMode::DEFAULT,
                  "local shared only support float32");
    size_t n = src[0], ci = src[1], hi = src[2], wi = src[3];
    size_t sgh = param().spatial_groups_h, sgw = param().spatial_groups_w;
    size_t groups = 1;
    // Offset into the filter shape: GROUP filters start with a `groups`
    // dimension, shifting all subsequent axes by one.
    size_t weights_shp_pos = 0;
    if (param().sparse == Sparse::GROUP) {
        groups = filter[0];
        weights_shp_pos = 1;
    }
    megdnn_assert(sgh == filter[weights_shp_pos] &&
                          sgw == filter[weights_shp_pos + 1],
                  "spatial groups in filter tensor mismatch with those "
                  "provided in parameter %s",
                  errmsg_c);
    size_t fh = filter[weights_shp_pos + 3], fw = filter[weights_shp_pos + 4],
           co = filter[weights_shp_pos + 5] * groups;
    megdnn_assert(filter[weights_shp_pos + 2] * groups == ci,
                  "input channels of src and filter mismatch %s", errmsg_c);
    size_t sh = param().stride_h;
    size_t sw = param().stride_w;
    size_t ph = param().pad_h;
    size_t pw = param().pad_w;
    size_t ho = infer_conv_shape(hi, fh, sh, ph),
           wo = infer_conv_shape(wi, fw, sw, pw);
    // Each spatial group owns an equal tile of the output plane, so the
    // output extent must be divisible by the group counts.
    megdnn_assert(
            ho % sgh == 0 && wo % sgw == 0,
            "height and width of output cannot be divided by spatial groups %s",
            errmsg_c);
    dst = TensorLayout{{n, co, ho, wo}, src.dtype};
}

/**
 * \brief Validate src/filter/dst against the deduced forward layout.
 */
void LocalShareBase::check_layout_fwd(const TensorLayout& src,
                                      const TensorLayout& filter,
                                      const TensorLayout& dst) {
    TensorLayout dst_expected;
    megdnn_assert_eq_dtype(src, filter);
    megdnn_assert_eq_dtype(src, dst);
    deduce_layout_fwd(src, filter, dst_expected);
    megdnn_assert_eq_layout(dst_expected, dst);
    megdnn_assert(src.dtype == dtype::Float32());
}

void LocalShareForward::deduce_layout(const TensorLayout& src,
                                      const TensorLayout& filter,
                                      TensorLayout& dst) {
    deduce_layout_fwd(src, filter, dst);
}

/**
 * \brief Check layouts and that the caller-provided workspace is big enough.
 */
void LocalShareForward::check_exec(const TensorLayout& src,
                                   const TensorLayout& filter,
                                   const TensorLayout& dst,
                                   size_t workspace_in_bytes) {
    check_layout_fwd(src, filter, dst);
    auto required_workspace_in_bytes = get_workspace_in_bytes(src, filter, dst);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

/**
 * \brief Deduce the input-gradient layout from filter and output gradient.
 *
 * Mirrors deduce_layout_fwd: grad spatial size is reconstructed as the
 * minimal input size consistent with the forward shape formula,
 * i.e. (out - 1) * stride + filter - 2 * pad.
 *
 * \param[in]  filter weight tensor layout (same convention as forward)
 * \param[in]  diff   output gradient layout {N, OC, OH, OW}
 * \param[out] grad   deduced input gradient layout {N, IC, IH, IW}
 */
void LocalShareBackwardData::deduce_layout(const TensorLayout& filter,
                                           const TensorLayout& diff,
                                           TensorLayout& grad) {
    using Mode = LocalShare::Param::Mode;
    auto errmsg =
            megdnn_layout_msg(filter) + ", " + megdnn_layout_msg(diff) + ", " +
            megdnn_layout_msg(grad) + ", " + "is_xcorr=" +
            std::to_string((param().mode == Mode::CROSS_CORRELATION)) + ", " +
            "pad_h=" + std::to_string(param().pad_h) + ", " +
            "pad_w=" + std::to_string(param().pad_w) + ", " +
            "stride_h=" + std::to_string(param().stride_h) + ", " +
            "stride_w=" + std::to_string(param().stride_w) + ", " +
            "dilate_h=" + std::to_string(param().dilate_h) + ", " +
            "dilate_w=" + std::to_string(param().dilate_w) + ", " +
            "spatial_groups_h=" + std::to_string(param().spatial_groups_h) +
            ", " +
            "spatial_groups_w=" + std::to_string(param().spatial_groups_w);
    auto errmsg_c = errmsg.c_str();
    MEGDNN_MARK_USED_VAR(errmsg_c);
    megdnn_assert_contiguous(filter);
    megdnn_assert_contiguous(diff);
    using Param = LocalShare::Param;
    using Sparse = Param::Sparse;
    using Format = Param::Format;
    using ComputeMode = Param::ComputeMode;
    megdnn_assert(param().format == Format::NCHW,
                  "local shared only support NCHW format");
    megdnn_assert(
            (filter.ndim == 6_z && param().sparse == Sparse::DENSE) ||
                    (filter.ndim == 7_z && param().sparse == Sparse::GROUP),
            "%s", errmsg_c);
    megdnn_assert(diff.ndim == 4_z, "%s", errmsg_c);
    megdnn_assert(param().dilate_h == 1 && param().dilate_w == 1,
                  "dilated local shared is not supported");
    megdnn_assert(diff.dtype == dtype::Float32() &&
                          param().computeMode == ComputeMode::DEFAULT,
                  "local shared only support float32");
    size_t n = diff[0], co = diff[1], ho = diff[2], wo = diff[3];
    size_t sgh = param().spatial_groups_h, sgw = param().spatial_groups_w;
    megdnn_assert(
            ho % sgh == 0 && wo % sgw == 0,
            "height and width of output cannot be divided by spatial groups %s",
            errmsg_c);
    size_t groups = 1;
    size_t weights_shp_pos = 0;
    if (param().sparse == Sparse::GROUP) {
        groups = filter[0];
        weights_shp_pos = 1;
    }
    megdnn_assert(sgh == filter[weights_shp_pos] &&
                          sgw == filter[weights_shp_pos + 1],
                  "spatial groups in filter tensor mismatch with those "
                  "provided in parameter %s",
                  errmsg_c);
    size_t ci = filter[weights_shp_pos + 2] * groups,
           fh = filter[weights_shp_pos + 3], fw = filter[weights_shp_pos + 4];
    // This compares the filter's output-channel count against diff's
    // channel dimension (the forward message said "input channels", which
    // was a copy-paste error).
    megdnn_assert(filter[weights_shp_pos + 5] * groups == co,
                  "output channels of diff and filter mismatch %s", errmsg_c);
    size_t sh = param().stride_h;
    size_t sw = param().stride_w;
    size_t ph = param().pad_h;
    size_t pw = param().pad_w;
    // Invert infer_conv_shape: smallest input extent producing `out`.
    auto deduce = [&errmsg_c](size_t out, size_t filter, size_t stride,
                              size_t pad) {
        MEGDNN_MARK_USED_VAR(errmsg_c);
        auto i = (out - 1) * stride + filter;
        megdnn_assert(i > pad * 2, "%s", errmsg_c);
        return i - pad * 2;
    };
    grad.ndim = diff.ndim;
    grad[0] = n;
    grad[1] = ci;
    grad[2] = deduce(ho, fh, sh, ph);
    grad[3] = deduce(wo, fw, sw, pw);
    grad.init_contiguous_stride();
    grad.dtype = diff.dtype;
}

/**
 * \brief Check dtypes/layouts for backward-data and the workspace size.
 *
 * Layout consistency is checked by re-running the forward check with the
 * roles swapped: grad plays src, diff plays dst.
 */
void LocalShareBackwardData::check_exec(const TensorLayout& filter,
                                        const TensorLayout& diff,
                                        const TensorLayout& grad,
                                        size_t workspace_in_bytes) {
    auto filter_dtype = filter.dtype, diff_dtype = diff.dtype,
         grad_dtype = grad.dtype;
    megdnn_assert(filter_dtype == dtype::Float32() &&
                  filter_dtype == diff_dtype && filter_dtype == grad_dtype);
    check_layout_fwd(grad, filter, diff);
    auto required_workspace_in_bytes = get_workspace_in_bytes(filter, diff,
                                                              grad);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

/**
 * \brief Check dtypes/layouts for backward-filter and the workspace size.
 *
 * Layout consistency is checked by re-running the forward check with the
 * roles swapped: grad plays filter, diff plays dst.
 */
void LocalShareBackwardFilter::check_exec(const TensorLayout& src,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad,
                                          size_t workspace_in_bytes) {
    auto src_dtype = src.dtype, diff_dtype = diff.dtype,
         grad_dtype = grad.dtype;
    megdnn_assert(src_dtype == dtype::Float32() && src_dtype == diff_dtype &&
                  src_dtype == grad_dtype);
    check_layout_fwd(src, grad, diff);
    auto required_workspace_in_bytes = get_workspace_in_bytes(src, diff, grad);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

}  // namespace megdnn

// vim: syntax=cpp.doxygen