You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mask_conv.cpp 1.8 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. /**
  2. * \file dnn/src/common/mask_conv.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megdnn/oprs/nn.h"
  12. #include "src/common/utils.h"
  13. using namespace megdnn;
  14. void MaskConvForward::deduce_dtype(DType src, DType filter, DType, DType& dst) {
  15. check_or_deduce_dtype_fwd(src, filter, dst);
  16. }
  17. void MaskConvForward::deduce_layout(
  18. const TensorLayout& src, const TensorLayout& filter, const TensorLayout& mask,
  19. TensorLayout& dst) {
  20. deduce_layout_fwd(src, filter, dst);
  21. megdnn_assert(dst[2] == mask[0]);
  22. megdnn_assert(dst[3] == mask[1]);
  23. }
  24. MaskConvForward::CanonizedFilterMeta MaskConvForward::check_exec(
  25. const TensorLayout& src, const TensorLayout& filter, const TensorLayout& mask,
  26. const TensorLayout& dst, size_t workspace_in_bytes) {
  27. auto ret = check_layout_fwd(src, filter, dst);
  28. megdnn_assert(dst[2] == mask[0]);
  29. megdnn_assert(dst[3] == mask[1]);
  30. auto required_workspace_in_bytes = get_workspace_in_bytes(src, filter, mask, dst);
  31. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  32. return ret;
  33. }
  34. void MaskPropagate::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
  35. size_t oh, ow;
  36. auto p = param();
  37. infer_conv_shape2d(
  38. src[0], src[1], (p.kernel_h - 1) * p.dilate_h + 1,
  39. (p.kernel_w - 1) * p.dilate_w + 1, p.stride_h, p.stride_w, p.pad_h, p.pad_w,
  40. oh, ow);
  41. dst = TensorLayout{{oh, ow}, src.dtype};
  42. }
  43. // vim: syntax=cpp.doxygen