|
|
|
@@ -199,6 +199,10 @@ template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const half *x |
|
|
|
cudaStream_t stream); |
|
|
|
// Explicit instantiation of ElewiseCmp taking const int* inputs (x0, x1) and a bool* output (y).
template void ElewiseCmp(
    const int &nums,
    enum BroadcastOpType op,
    const int *x0,
    const int *x1,
    bool *y,
    cudaStream_t stream);
|
|
|
// Explicit instantiation of ElewiseCmp taking const int8_t* inputs (x0, x1) and a bool* output (y).
template void ElewiseCmp(
    const int &nums,
    enum BroadcastOpType op,
    const int8_t *x0,
    const int8_t *x1,
    bool *y,
    cudaStream_t stream);
|
|
|
// Explicit instantiation of ElewiseCmp taking const uint8_t* inputs (x0, x1) and a bool* output (y).
template void ElewiseCmp(
    const int &nums,
    enum BroadcastOpType op,
    const uint8_t *x0,
    const uint8_t *x1,
    bool *y,
    cudaStream_t stream);
|
|
|
|
|
|
|
// Element-wise Arithmetic
|
|
|
template <typename T, typename Func> |
|
|
|
@@ -261,6 +265,10 @@ template void ElewiseArith(const int &nums, enum BroadcastOpType op, const half |
|
|
|
cudaStream_t stream); |
|
|
|
// Explicit instantiation of ElewiseArith taking const int* inputs (x0, x1) and an int* output (y).
template void ElewiseArith(
    const int &nums,
    enum BroadcastOpType op,
    const int *x0,
    const int *x1,
    int *y,
    cudaStream_t stream);
|
|
|
// Explicit instantiation of ElewiseArith taking const int8_t* inputs (x0, x1) and an int8_t* output (y).
template void ElewiseArith(
    const int &nums,
    enum BroadcastOpType op,
    const int8_t *x0,
    const int8_t *x1,
    int8_t *y,
    cudaStream_t stream);
|
|
|
// Explicit instantiation of ElewiseArith taking const uint8_t* inputs (x0, x1) and a uint8_t* output (y).
template void ElewiseArith(
    const int &nums,
    enum BroadcastOpType op,
    const uint8_t *x0,
    const uint8_t *x1,
    uint8_t *y,
    cudaStream_t stream);
|
|
|
|
|
|
|
// Broadcast comparison
|
|
|
// Picks the effective index for one dimension: yields 0 when `dim` equals 1,
// otherwise passes `index` through unchanged.
// NOTE(review): presumably this lets size-1 dimensions be broadcast against
// larger ones -- confirm against the callers.
__device__ __forceinline__ size_t Index(const size_t &index, const size_t &dim) {
  if (dim == 1) {
    return 0;
  }
  return index;
}
|
|
|
@@ -333,6 +341,12 @@ template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector |
|
|
|
// Explicit instantiation of BroadcastCmp taking const int* inputs (x0, x1) and a bool* output (y).
template void BroadcastCmp(
    const std::vector<size_t> &x0_dims,
    const std::vector<size_t> &x1_dims,
    const std::vector<size_t> &y_dims,
    enum BroadcastOpType op,
    const int *x0,
    const int *x1,
    bool *y,
    cudaStream_t stream);
|
|
|
// Explicit instantiation of BroadcastCmp taking const int8_t* inputs (x0, x1) and a bool* output (y).
template void BroadcastCmp(
    const std::vector<size_t> &x0_dims,
    const std::vector<size_t> &x1_dims,
    const std::vector<size_t> &y_dims,
    enum BroadcastOpType op,
    const int8_t *x0,
    const int8_t *x1,
    bool *y,
    cudaStream_t stream);
|
|
|
// Explicit instantiation of BroadcastCmp taking const uint8_t* inputs (x0, x1) and a bool* output (y).
template void BroadcastCmp(
    const std::vector<size_t> &x0_dims,
    const std::vector<size_t> &x1_dims,
    const std::vector<size_t> &y_dims,
    enum BroadcastOpType op,
    const uint8_t *x0,
    const uint8_t *x1,
    bool *y,
    cudaStream_t stream);
|
|
|
|
|
|
|
// Broadcast Arithmetic |
|
|
|
template <typename T, typename Func> |
|
|
|
@@ -448,6 +462,12 @@ template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vect |
|
|
|
// Explicit instantiation of BroadcastArith taking const int* inputs (x0, x1) and an int* output (y).
template void BroadcastArith(
    const std::vector<size_t> &x0_dims,
    const std::vector<size_t> &x1_dims,
    const std::vector<size_t> &y_dims,
    enum BroadcastOpType op,
    const int *x0,
    const int *x1,
    int *y,
    cudaStream_t stream);
|
|
|
// Explicit instantiation of BroadcastArith taking const int8_t* inputs (x0, x1) and an int8_t* output (y).
template void BroadcastArith(
    const std::vector<size_t> &x0_dims,
    const std::vector<size_t> &x1_dims,
    const std::vector<size_t> &y_dims,
    enum BroadcastOpType op,
    const int8_t *x0,
    const int8_t *x1,
    int8_t *y,
    cudaStream_t stream);
|
|
|
// Explicit instantiation of BroadcastArith taking const uint8_t* inputs (x0, x1) and a uint8_t* output (y).
template void BroadcastArith(
    const std::vector<size_t> &x0_dims,
    const std::vector<size_t> &x1_dims,
    const std::vector<size_t> &y_dims,
    enum BroadcastOpType op,
    const uint8_t *x0,
    const uint8_t *x1,
    uint8_t *y,
    cudaStream_t stream);
|
|
|
|
|
|
|
// BroadcastTo |
|
|
|
template <typename T> |
|
|
|
|