|
|
|
@@ -1,5 +1,5 @@ |
|
|
|
/** |
|
|
|
* Copyright 2020 Huawei Technologies Co., Ltd |
|
|
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd |
|
|
|
* |
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
* you may not use this file except in compliance with the License. |
|
|
|
@@ -38,6 +38,7 @@ template void CalReLU(int size, int8_t *input_addr, int8_t *output_addr, cudaStr |
|
|
|
// Explicit instantiations of the CalReLU function template for the int16_t,
// int32_t, int64_t and uint8_t element types. In each one the input and output
// pointers share the element type and are accompanied by an int `size` and a
// cudaStream_t.
// NOTE(review): the trailing `|` characters and the bare `|` lines look like
// diff/table rendering artifacts rather than CUDA source — confirm against the
// real file before applying.
template void CalReLU(int size, int16_t *input_addr, int16_t *output_addr, cudaStream_t cuda_stream); |
|
|
|
template void CalReLU(int size, int32_t *input_addr, int32_t *output_addr, cudaStream_t cuda_stream); |
|
|
|
template void CalReLU(int size, int64_t *input_addr, int64_t *output_addr, cudaStream_t cuda_stream); |
|
|
|
template void CalReLU(int size, uint8_t *input_addr, uint8_t *output_addr, cudaStream_t cuda_stream); |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
__global__ void ReluV2Kernel(const size_t num, const T *x, T *y, uint32_t *mask) { |
|
|
|
@@ -78,6 +79,7 @@ template void ReluV2(const size_t num, const int8_t *x, int8_t *y, uint32_t *mas |
|
|
|
// Explicit instantiations of the ReluV2 function template for the int16_t,
// int32_t, int64_t and uint8_t element types. Each takes a size_t `num`, a
// const-qualified pointer `x`, a non-const pointer `y` of the same element
// type, a uint32_t `mask` pointer (uint32_t regardless of the element type)
// and a cudaStream_t.
// NOTE(review): the trailing `|` characters and the bare `|` lines look like
// diff/table rendering artifacts rather than CUDA source — confirm against the
// real file before applying.
template void ReluV2(const size_t num, const int16_t *x, int16_t *y, uint32_t *mask, cudaStream_t cuda_stream); |
|
|
|
template void ReluV2(const size_t num, const int32_t *x, int32_t *y, uint32_t *mask, cudaStream_t cuda_stream); |
|
|
|
template void ReluV2(const size_t num, const int64_t *x, int64_t *y, uint32_t *mask, cudaStream_t cuda_stream); |
|
|
|
template void ReluV2(const size_t num, const uint8_t *x, uint8_t *y, uint32_t *mask, cudaStream_t cuda_stream); |
|
|
|
|
|
|
|
// Explicit instantiation of the ReluGradV2 function template for double: it
// takes a size_t `num`, const-qualified `dy` and `mask` pointers (the mask is
// uint32_t), a non-const `dx` pointer and a cudaStream_t. The statement is
// wrapped over two lines.
// NOTE(review): the trailing `|` characters and the bare `|` lines look like
// diff/table rendering artifacts rather than CUDA source — confirm against the
// real file before applying.
template void ReluGradV2(const size_t num, const double *dy, const uint32_t *mask, double *dx, |
|
|
|
cudaStream_t cuda_stream); |
|
|
|
@@ -91,3 +93,5 @@ template void ReluGradV2(const size_t num, const int32_t *dy, const uint32_t *ma |
|
|
|
cudaStream_t cuda_stream); |
|
|
|
// Explicit instantiations of the ReluGradV2 function template for the int64_t
// and uint8_t element types. Each takes a size_t `num`, const-qualified `dy`
// and `mask` pointers (the mask is uint32_t regardless of the element type), a
// non-const `dx` pointer of the element type and a cudaStream_t. Each statement
// is wrapped over two lines.
// NOTE(review): the trailing `|` characters and the bare `|` lines look like
// diff/table rendering artifacts rather than CUDA source — confirm against the
// real file before applying.
template void ReluGradV2(const size_t num, const int64_t *dy, const uint32_t *mask, int64_t *dx, |
|
|
|
cudaStream_t cuda_stream); |
|
|
|
template void ReluGradV2(const size_t num, const uint8_t *dy, const uint32_t *mask, uint8_t *dx, |
|
|
|
cudaStream_t cuda_stream); |