You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

topk.cpp 1.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. #include "megdnn/oprs/general.h"
  2. #include "src/common/utils.h"
  3. #include <cmath>
  4. using namespace megdnn;
  5. void TopK::deduce_layout(
  6. int k, const TensorLayout& data, TensorLayout& values, TensorLayout& indices) {
  7. megdnn_assert(
  8. k && data.ndim == 2 && data.stride[1] == 1, "invalid k=%d data=%s", k,
  9. data.to_string().c_str());
  10. values.dtype = data.dtype;
  11. indices.dtype = dtype::Int32{};
  12. switch (param().mode) {
  13. case Param::Mode::KTH_ONLY:
  14. values.init_contiguous_stride({data[0]});
  15. indices.ndim = 0;
  16. break;
  17. case Param::Mode::VALUE_IDX_NOSORT:
  18. case Param::Mode::VALUE_IDX_SORTED:
  19. values.init_contiguous_stride(
  20. {data[0], std::min<size_t>(std::abs(k), data.shape[1])});
  21. indices.init_contiguous_stride(values);
  22. break;
  23. default:
  24. megdnn_throw("invalid TopK mode");
  25. }
  26. }
  27. void TopK::exec(
  28. int k, _megdnn_tensor_in data, _megdnn_tensor_out values,
  29. _megdnn_tensor_out indices, _megdnn_workspace workspace) {
  30. TensorLayout oval, oidx;
  31. deduce_layout(k, data.layout, oval, oidx);
  32. megdnn_assert_eq_layout(oval, values.layout);
  33. int32_t* iptr = nullptr;
  34. if (param().mode == Param::Mode::KTH_ONLY) {
  35. megdnn_assert_eq_shape(indices.layout, TensorShape{});
  36. } else {
  37. iptr = indices.ptr<int32_t>();
  38. megdnn_assert_eq_layout(oidx, indices.layout);
  39. }
  40. megdnn_assert(
  41. workspace.size >=
  42. get_workspace_in_bytes(k, data.layout, values.layout, indices.layout));
  43. if (static_cast<size_t>(std::abs(k)) > data.layout[1]) {
  44. if (k > 0) {
  45. k = data.layout[1];
  46. } else {
  47. k = -static_cast<int>(data.layout[1]);
  48. }
  49. }
  50. do_exec(k, data, values, iptr, workspace);
  51. }
  52. // vim: syntax=cpp.doxygen