| @@ -37,3 +37,4 @@ template void CalReLUGrad(int size, int8_t *dy, int8_t *y, int8_t *dx, cudaStrea | |||||
| template void CalReLUGrad(int size, int16_t *dy, int16_t *y, int16_t *dx, cudaStream_t cuda_stream); | template void CalReLUGrad(int size, int16_t *dy, int16_t *y, int16_t *dx, cudaStream_t cuda_stream); | ||||
| template void CalReLUGrad(int size, int32_t *dy, int32_t *y, int32_t *dx, cudaStream_t cuda_stream); | template void CalReLUGrad(int size, int32_t *dy, int32_t *y, int32_t *dx, cudaStream_t cuda_stream); | ||||
| template void CalReLUGrad(int size, int64_t *dy, int64_t *y, int64_t *dx, cudaStream_t cuda_stream); | template void CalReLUGrad(int size, int64_t *dy, int64_t *y, int64_t *dx, cudaStream_t cuda_stream); | ||||
| template void CalReLUGrad(int size, uint8_t *dy, uint8_t *y, uint8_t *dx, cudaStream_t cuda_stream); | |||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -38,6 +38,7 @@ template void CalReLU(int size, int8_t *input_addr, int8_t *output_addr, cudaStr | |||||
| template void CalReLU(int size, int16_t *input_addr, int16_t *output_addr, cudaStream_t cuda_stream); | template void CalReLU(int size, int16_t *input_addr, int16_t *output_addr, cudaStream_t cuda_stream); | ||||
| template void CalReLU(int size, int32_t *input_addr, int32_t *output_addr, cudaStream_t cuda_stream); | template void CalReLU(int size, int32_t *input_addr, int32_t *output_addr, cudaStream_t cuda_stream); | ||||
| template void CalReLU(int size, int64_t *input_addr, int64_t *output_addr, cudaStream_t cuda_stream); | template void CalReLU(int size, int64_t *input_addr, int64_t *output_addr, cudaStream_t cuda_stream); | ||||
| template void CalReLU(int size, uint8_t *input_addr, uint8_t *output_addr, cudaStream_t cuda_stream); | |||||
| template <typename T> | template <typename T> | ||||
| __global__ void ReluV2Kernel(const size_t num, const T *x, T *y, uint32_t *mask) { | __global__ void ReluV2Kernel(const size_t num, const T *x, T *y, uint32_t *mask) { | ||||
| @@ -78,6 +79,7 @@ template void ReluV2(const size_t num, const int8_t *x, int8_t *y, uint32_t *mas | |||||
| template void ReluV2(const size_t num, const int16_t *x, int16_t *y, uint32_t *mask, cudaStream_t cuda_stream); | template void ReluV2(const size_t num, const int16_t *x, int16_t *y, uint32_t *mask, cudaStream_t cuda_stream); | ||||
| template void ReluV2(const size_t num, const int32_t *x, int32_t *y, uint32_t *mask, cudaStream_t cuda_stream); | template void ReluV2(const size_t num, const int32_t *x, int32_t *y, uint32_t *mask, cudaStream_t cuda_stream); | ||||
| template void ReluV2(const size_t num, const int64_t *x, int64_t *y, uint32_t *mask, cudaStream_t cuda_stream); | template void ReluV2(const size_t num, const int64_t *x, int64_t *y, uint32_t *mask, cudaStream_t cuda_stream); | ||||
| template void ReluV2(const size_t num, const uint8_t *x, uint8_t *y, uint32_t *mask, cudaStream_t cuda_stream); | |||||
| template void ReluGradV2(const size_t num, const double *dy, const uint32_t *mask, double *dx, | template void ReluGradV2(const size_t num, const double *dy, const uint32_t *mask, double *dx, | ||||
| cudaStream_t cuda_stream); | cudaStream_t cuda_stream); | ||||
| @@ -91,3 +93,5 @@ template void ReluGradV2(const size_t num, const int32_t *dy, const uint32_t *ma | |||||
| cudaStream_t cuda_stream); | cudaStream_t cuda_stream); | ||||
| template void ReluGradV2(const size_t num, const int64_t *dy, const uint32_t *mask, int64_t *dx, | template void ReluGradV2(const size_t num, const int64_t *dy, const uint32_t *mask, int64_t *dx, | ||||
| cudaStream_t cuda_stream); | cudaStream_t cuda_stream); | ||||
| template void ReluGradV2(const size_t num, const uint8_t *dy, const uint32_t *mask, uint8_t *dx, | |||||
| cudaStream_t cuda_stream); | |||||
| @@ -32,5 +32,7 @@ MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutpu | |||||
| ReLUGpuFwdKernel, int16_t) | ReLUGpuFwdKernel, int16_t) | ||||
| MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), ReLUGpuFwdKernel, | MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), ReLUGpuFwdKernel, | ||||
| int8_t) | int8_t) | ||||
| MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), | |||||
| ReLUGpuFwdKernel, uint8_t) | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -42,5 +42,8 @@ MS_REG_GPU_KERNEL_ONE( | |||||
| MS_REG_GPU_KERNEL_ONE( | MS_REG_GPU_KERNEL_ONE( | ||||
| ReluGrad, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), | ReluGrad, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), | ||||
| ReluGradGpuFwdKernel, int8_t) | ReluGradGpuFwdKernel, int8_t) | ||||
| MS_REG_GPU_KERNEL_ONE( | |||||
| ReluGrad, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), | |||||
| ReluGradGpuFwdKernel, uint8_t) | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -45,5 +45,9 @@ MS_REG_GPU_KERNEL_ONE( | |||||
| ReluGradV2, | ReluGradV2, | ||||
| KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64), | KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64), | ||||
| ReluGradV2GpuKernel, int64_t) | ReluGradV2GpuKernel, int64_t) | ||||
| MS_REG_GPU_KERNEL_ONE( | |||||
| ReluGradV2, | |||||
| KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt8), | |||||
| ReluGradV2GpuKernel, uint8_t) | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -42,5 +42,8 @@ MS_REG_GPU_KERNEL_ONE( | |||||
| MS_REG_GPU_KERNEL_ONE( | MS_REG_GPU_KERNEL_ONE( | ||||
| ReLUV2, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32), | ReLUV2, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32), | ||||
| ReluV2GpuKernel, int64_t) | ReluV2GpuKernel, int64_t) | ||||
| MS_REG_GPU_KERNEL_ONE( | |||||
| ReLUV2, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt32), | |||||
| ReluV2GpuKernel, uint8_t) | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -416,13 +416,13 @@ class ReLUV2(PrimitiveWithInfer): | |||||
| f"but got a {len(input_shape)}-D tensor whose shape is {input_shape}") | f"but got a {len(input_shape)}-D tensor whose shape is {input_shape}") | ||||
| for i in enumerate(input_shape): | for i in enumerate(input_shape): | ||||
| if i[0] == 1: | if i[0] == 1: | ||||
| if input_dtype == mstype.uint8 and input_dtype == mstype.int8: | |||||
| if input_dtype in (mstype.uint8, mstype.int8): | |||||
| mask_shape.append((input_shape[1] + 31) // 32) | mask_shape.append((input_shape[1] + 31) // 32) | ||||
| else: | else: | ||||
| mask_shape.append((input_shape[1] + 15) // 16) | mask_shape.append((input_shape[1] + 15) // 16) | ||||
| else: | else: | ||||
| mask_shape.append(i[1]) | mask_shape.append(i[1]) | ||||
| if input_dtype == mstype.uint8 and input_dtype == mstype.int8: | |||||
| if input_dtype in (mstype.uint8, mstype.int8): | |||||
| mask_shape.append(4) | mask_shape.append(4) | ||||
| else: | else: | ||||
| mask_shape.append(2) | mask_shape.append(2) | ||||