GitOrigin-RevId: 3d2c315a36
tags/v1.6.0
| @@ -174,7 +174,7 @@ template void argsort::forward<dtype>(const dtype*, dtype*, int*, void*, \ | |||||
| ARGSORT_FOREACH_CTYPE(INST_FORWARD) | ARGSORT_FOREACH_CTYPE(INST_FORWARD) | ||||
| INST_CUB_SORT(uint32_t) | INST_CUB_SORT(uint32_t) | ||||
| // INST_CUB_SORT(uint64_t) | |||||
| INST_CUB_SORT(uint64_t) | |||||
| #undef INST_CUB_SORT | #undef INST_CUB_SORT | ||||
| #undef INST_FORWARD | #undef INST_FORWARD | ||||
| } | } | ||||
| @@ -40,6 +40,7 @@ void forward(const dtype* sptr, dtype* dptr, int* iptr, void* workspace, | |||||
| const int* iptr_src = NULL); | const int* iptr_src = NULL); | ||||
| //! iterate over all supported data types | //! iterate over all supported data types | ||||
| // device_radix_sort does not support dt_float16 dtype(half_float::half in rocm) | |||||
| #define ARGSORT_FOREACH_CTYPE(cb) \ | #define ARGSORT_FOREACH_CTYPE(cb) \ | ||||
| cb(float) cb(int32_t) // DNN_INC_FLOAT16(cb(dt_float16)) | cb(float) cb(int32_t) // DNN_INC_FLOAT16(cb(dt_float16)) | ||||
| @@ -14,8 +14,6 @@ | |||||
| #include "./argsort.h.hip" | #include "./argsort.h.hip" | ||||
| #include "./backward.h.hip" | #include "./backward.h.hip" | ||||
| // #include "src/rocm/utils.h" | |||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace rocm; | using namespace rocm; | ||||
| using namespace argsort; | using namespace argsort; | ||||
| @@ -11,13 +11,9 @@ | |||||
| #include "hcc_detail/hcc_defs_prologue.h" | #include "hcc_detail/hcc_defs_prologue.h" | ||||
| #include "./bitonic_sort.h.hip" | #include "./bitonic_sort.h.hip" | ||||
| // #include "src/cuda/query_blocksize.cuh" | |||||
| // #include "megdnn/dtype.h" | |||||
| #include "megdnn/dtype.h" | |||||
| // #if __CUDACC_VER_MAJOR__ < 9 | |||||
| // #pragma message "warp sync disabled due to insufficient cuda version" | |||||
| #define __syncwarp __syncthreads | #define __syncwarp __syncthreads | ||||
| // #endif | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <cmath> | #include <cmath> | ||||
| @@ -84,17 +80,17 @@ struct NumTrait<int32_t> { | |||||
| static __device__ __forceinline__ int32_t min() { return INT_MIN; } | static __device__ __forceinline__ int32_t min() { return INT_MIN; } | ||||
| }; | }; | ||||
| // #if !MEGDNN_DISABLE_FLOAT16 | |||||
| // template <> | |||||
| // struct NumTrait<dt_float16> { | |||||
| // static __device__ __forceinline__ dt_float16 max() { | |||||
| // return std::numeric_limits<dt_float16>::max(); | |||||
| // } | |||||
| // static __device__ __forceinline__ dt_float16 min() { | |||||
| // return std::numeric_limits<dt_float16>::lowest(); | |||||
| // } | |||||
| // }; | |||||
| // #endif | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| template <> | |||||
| struct NumTrait<dt_float16> { | |||||
| static __device__ __forceinline__ dt_float16 max() { | |||||
| return std::numeric_limits<dt_float16>::max(); | |||||
| } | |||||
| static __device__ __forceinline__ dt_float16 min() { | |||||
| return std::numeric_limits<dt_float16>::lowest(); | |||||
| } | |||||
| }; | |||||
| #endif | |||||
| struct LessThan { | struct LessThan { | ||||
| template <typename Key, typename Value> | template <typename Key, typename Value> | ||||
| @@ -310,7 +306,7 @@ namespace rocm { | |||||
| INST(float, int); | INST(float, int); | ||||
| INST(int32_t, int); | INST(int32_t, int); | ||||
| // DNN_INC_FLOAT16(INST(dt_float16, int)); | |||||
| DNN_INC_FLOAT16(INST(dt_float16, int)); | |||||
| #undef INST | #undef INST | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -18,13 +18,7 @@ | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <cmath> | #include <cmath> | ||||
| #if __CUDACC_VER_MAJOR__ < 9 | |||||
| #pragma message "topk is a little slower on cuda earlier than 9.0" | |||||
| // on cuda 9.0 and later, due to thread-divergent branches we should use | |||||
| // __syncwarp; and I am too lazy to implement a correct legacy version, so just | |||||
| // use __syncthreads instead for older cuda | |||||
| #define __syncwarp __syncthreads | #define __syncwarp __syncthreads | ||||
| #endif | |||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace rocm; | using namespace rocm; | ||||
| @@ -256,12 +250,12 @@ static __global__ void update_prefix_and_k(const uint32_t* bucket_cnt, | |||||
| } | } | ||||
| } | } | ||||
| //if ((cumsum_bucket_cnt[NR_BUCKET] < kv) | | |||||
| // (cumsum_bucket_cnt[i] != cumsum_bucket_cnt[i - 1] + sum)) { | |||||
| // // impossible | |||||
| // int* bad = 0x0; | |||||
| // *bad = 23; | |||||
| //} | |||||
| if ((cumsum_bucket_cnt[NR_BUCKET] < kv) | | |||||
| (cumsum_bucket_cnt[i] != cumsum_bucket_cnt[i - 1] + sum)) { | |||||
| // impossible | |||||
| int* bad = 0x0; | |||||
| *bad = 23; | |||||
| } | |||||
| } | } | ||||
| static uint32_t get_grid_dim_x(uint32_t length) { | static uint32_t get_grid_dim_x(uint32_t length) { | ||||