You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

reduce.cpp 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. /**
  2. * \file dnn/src/common/reduce.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megdnn/oprs.h"
  12. #include <numeric>
  13. #include "src/common/utils.h"
  14. namespace {
  15. using namespace megdnn;
  16. using megdnn::Reduce;
  17. DType get_out_dtype(const Reduce::DataType data_type, const DType inp_dtype) {
  18. if (data_type == Reduce::DataType::FLOAT_O16xC32) {
  19. #if !MEGDNN_DISABLE_FLOAT16
  20. return dtype::Float16();
  21. #else
  22. megdnn_assert_internal(0);
  23. #endif
  24. }
  25. if (data_type == Reduce::DataType::FLOAT_O32xC32) {
  26. return dtype::Float32();
  27. }
  28. if (data_type == Reduce::DataType::QUINT_I8xO32) {
  29. megdnn_assert(inp_dtype.enumv() == DTypeEnum::Quantized8Asymm);
  30. return dtype::QuantizedS32(
  31. inp_dtype.param<dtype::Quantized8Asymm>().scale);
  32. }
  33. if (data_type == Reduce::DataType::QINT_I8xO32) {
  34. megdnn_assert(inp_dtype.enumv() == DTypeEnum::QuantizedS8);
  35. return dtype::QuantizedS32(
  36. inp_dtype.param<dtype::QuantizedS8>().scale);
  37. }
  38. megdnn_assert(data_type == Reduce::DataType::DEFAULT);
  39. return inp_dtype;
  40. }
  41. } // namespace
  42. namespace megdnn {
  43. void ReduceForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
  44. megdnn_assert(
  45. param().axis >= 0 && static_cast<uint32_t>(param().axis) < src.ndim,
  46. "axis: %d ndim: %zu", param().axis, src.ndim);
  47. dst = src;
  48. dst.shape[param().axis] = 1;
  49. dst.dtype = get_out_dtype(param().data_type, src.dtype);
  50. dst.format = src.format;
  51. dst.init_contiguous_stride();
  52. }
  53. void ReduceForward::check_exec(const TensorLayout& src, const TensorLayout& dst,
  54. size_t workspace_in_bytes) {
  55. auto errmsg = [&]() {
  56. return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(dst);
  57. };
  58. megdnn_assert(param().data_type != Reduce::DataType::FLOAT_IO16xC32,
  59. "FLOAT_IO16xC32 is deprecated");
  60. MEGDNN_MARK_USED_VAR(errmsg);
  61. megdnn_assert_contiguous(src);
  62. megdnn_assert_contiguous(dst);
  63. megdnn_assert(src.ndim == dst.ndim, "%s", errmsg().c_str());
  64. megdnn_assert(param().axis >= 0);
  65. uint32_t axis = param().axis;
  66. megdnn_assert(axis < src.ndim, "%s", errmsg().c_str());
  67. rep(i, src.ndim) {
  68. if (i != axis) {
  69. megdnn_assert(src.shape[i] == dst.shape[i], "%s", errmsg().c_str());
  70. } else {
  71. megdnn_assert(dst.shape[i] == 1_z, "%s", errmsg().c_str());
  72. }
  73. }
  74. megdnn_assert(src.dtype.category() == dst.dtype.category() ||
  75. param().data_type == Reduce::DataType::FLOAT_O32xC32,
  76. "the category of reduce output and input must be the same,"
  77. " or the data_type is FLOAT_O32xC32");
  78. if (param().data_type == DataType::DEFAULT) {
  79. megdnn_assert(src.dtype == dst.dtype &&
  80. (src.dtype.category() == DTypeCategory::FLOAT ||
  81. src.dtype.category() == DTypeCategory::INT ||
  82. src.dtype.category() == DTypeCategory::QUANTIZED));
  83. } else if (param().data_type == DataType::QUINT_I8xO32) {
  84. megdnn_assert(src.dtype.enumv() == DTypeEnum::Quantized8Asymm);
  85. } else if (param().data_type == DataType::QINT_I8xO32) {
  86. megdnn_assert(src.dtype.enumv() == DTypeEnum::QuantizedS8);
  87. } else if (param().data_type == DataType::FLOAT_IO16xC32 ||
  88. param().data_type == DataType::FLOAT_O16xC32) {
  89. megdnn_assert(src.dtype.category() == DTypeCategory::FLOAT);
  90. } else {
  91. megdnn_assert(param().data_type == DataType::FLOAT_O32xC32);
  92. }
  93. auto expected = get_out_dtype(param().data_type, src.dtype);
  94. megdnn_assert(expected == dst.dtype);
  95. auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
  96. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  97. }
  98. } // namespace megdnn
  99. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台。