You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dct.cpp 2.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. #include "megdnn/oprs.h"
  2. #include "src/common/utils.h"
  3. namespace megdnn {
  4. void DctChannelSelectForward::deduce_layout_fwd(
  5. const TensorLayout& src, const TensorLayout& mask_offset,
  6. const TensorLayout& mask_val, TensorLayout& dst) {
  7. const size_t dct_block = param().dct_block_size;
  8. const size_t in = src.shape[0];
  9. const size_t ic = src.shape[1];
  10. const size_t ih = src.shape[2];
  11. const size_t iw = src.shape[3];
  12. check_layout_fwd(src, mask_offset, mask_val, dst);
  13. const size_t oh = ih / dct_block;
  14. const size_t ow = iw / dct_block;
  15. //! mask will be empty or (ic + 1) elements
  16. size_t oc = mask_offset.ndim > 0 && mask_offset[0] >= 2
  17. ? mask_val.shape[0]
  18. : ic * dct_block * dct_block;
  19. if (param().fastImpl == Param::FastImpl::FIX_32_MASK) {
  20. megdnn_assert(
  21. oc == 32, "Param::FastImpl::FIX_32_MASK oc must be 32, but %zu", oc);
  22. }
  23. if (param().format == Param::Format::NCHW) {
  24. dst = TensorLayout(TensorShape({in, oc, oh, ow}), dst.dtype);
  25. } else {
  26. megdnn_assert(
  27. param().format == Param::Format::NCHW4,
  28. "dct format must be nchw or nchw4");
  29. megdnn_assert(oc % 4 == 0, "oc mod 4 == 0 in nchw4");
  30. dst = TensorLayout(TensorShape({in, oc / 4, oh, ow, 4}), dst.dtype);
  31. }
  32. }
  33. void DctChannelSelectForward::deduce_layout(
  34. const TensorLayout& src, const TensorLayout& mask_offset,
  35. const TensorLayout& mask_val, TensorLayout& dst) {
  36. deduce_layout_fwd(src, mask_offset, mask_val, dst);
  37. }
  38. void DctChannelSelectForward::check_layout_fwd(
  39. const TensorLayout& src, const TensorLayout& mask_offset,
  40. const TensorLayout& mask_val, const TensorLayout& dst) {
  41. const size_t dct_block = param().dct_block_size;
  42. const size_t ih = src.shape[2];
  43. const size_t iw = src.shape[3];
  44. megdnn_assert(
  45. mask_offset.ndim == 0 ||
  46. (mask_offset.ndim == 1 &&
  47. (mask_offset.shape[0] == 0 || mask_offset.shape[0] >= 2) &&
  48. mask_val.ndim == 1),
  49. "mask only support one valid dim");
  50. megdnn_assert(mask_val.ndim <= 1, "only support one dim");
  51. megdnn_assert(src.dtype.enumv() == DTypeEnum::Uint8, "src.dtype == dtype::Uint8");
  52. megdnn_assert(
  53. dst.dtype.enumv() == DTypeEnum::Float32 ||
  54. dst.dtype.enumv() == DTypeEnum::QuantizedS8,
  55. "dst.dtype == dtype::Float32 || dst.dtype.enumv() == "
  56. "DTypeEnum::QuantizedS8");
  57. megdnn_assert(ih % dct_block == 0, "ih mod dctblock == 0");
  58. megdnn_assert(iw % dct_block == 0, "iw mod dctblock == 0");
  59. }
  60. } // namespace megdnn
  61. // vim: syntax=cpp.doxygen