You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

mask_conv.cpp 1.4 kB

123456789101112131415161718192021222324252627282930313233343536373839
  1. #include "megdnn/oprs/nn.h"
  2. #include "src/common/utils.h"
  3. using namespace megdnn;
  4. void MaskConvForward::deduce_dtype(DType src, DType filter, DType, DType& dst) {
  5. check_or_deduce_dtype_fwd(src, filter, dst);
  6. }
  7. void MaskConvForward::deduce_layout(
  8. const TensorLayout& src, const TensorLayout& filter, const TensorLayout& mask,
  9. TensorLayout& dst) {
  10. deduce_layout_fwd(src, filter, dst);
  11. megdnn_assert(dst[2] == mask[0]);
  12. megdnn_assert(dst[3] == mask[1]);
  13. }
  14. MaskConvForward::CanonizedFilterMeta MaskConvForward::check_exec(
  15. const TensorLayout& src, const TensorLayout& filter, const TensorLayout& mask,
  16. const TensorLayout& dst, size_t workspace_in_bytes) {
  17. auto ret = check_layout_fwd(src, filter, dst);
  18. megdnn_assert(dst[2] == mask[0]);
  19. megdnn_assert(dst[3] == mask[1]);
  20. auto required_workspace_in_bytes = get_workspace_in_bytes(src, filter, mask, dst);
  21. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  22. return ret;
  23. }
  24. void MaskPropagate::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
  25. size_t oh, ow;
  26. auto p = param();
  27. infer_conv_shape2d(
  28. src[0], src[1], (p.kernel_h - 1) * p.dilate_h + 1,
  29. (p.kernel_w - 1) * p.dilate_w + 1, p.stride_h, p.stride_w, p.pad_h, p.pad_w,
  30. oh, ow);
  31. dst = TensorLayout{{oh, ow}, src.dtype};
  32. }
  33. // vim: syntax=cpp.doxygen