|
|
|
@@ -1,5 +1,5 @@ |
|
|
|
/**
|
|
|
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
|
|
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
|
|
*
|
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
* you may not use this file except in compliance with the License.
|
|
|
|
@@ -18,7 +18,7 @@ |
|
|
|
#include "runtime/device/gpu/cuda_common.h"
|
|
|
|
#include "include/cuda_fp16.h"
|
|
|
|
template <typename T>
|
|
|
|
__global__ void Argmax1D(const T* input, const int channel_size, int* output) {
|
|
|
|
__global__ void Argmax1D(const T *input, const int channel_size, int *output) {
|
|
|
|
int max_index = 0;
|
|
|
|
T max = input[0];
|
|
|
|
for (int pos = 1; pos < channel_size; pos++) {
|
|
|
|
@@ -31,7 +31,7 @@ __global__ void Argmax1D(const T* input, const int channel_size, int* output) { |
|
|
|
return;
|
|
|
|
}
|
|
|
|
template <typename T>
|
|
|
|
__global__ void ArgmaxDefault2D(const T* input, const int batch_size, const int channel_size, int* output) {
|
|
|
|
__global__ void ArgmaxDefault2D(const T *input, const int batch_size, const int channel_size, int *output) {
|
|
|
|
int pos;
|
|
|
|
int max_index;
|
|
|
|
T max;
|
|
|
|
@@ -51,7 +51,7 @@ __global__ void ArgmaxDefault2D(const T* input, const int batch_size, const int |
|
|
|
return;
|
|
|
|
}
|
|
|
|
template <typename T>
|
|
|
|
__global__ void ArgmaxAxis2D(const T* input, const int batch_size, const int channel_size, int* output) {
|
|
|
|
__global__ void ArgmaxAxis2D(const T *input, const int batch_size, const int channel_size, int *output) {
|
|
|
|
int pos;
|
|
|
|
int max_index;
|
|
|
|
T max;
|
|
|
|
@@ -70,7 +70,7 @@ __global__ void ArgmaxAxis2D(const T* input, const int batch_size, const int cha |
|
|
|
return;
|
|
|
|
}
|
|
|
|
template <typename T>
|
|
|
|
void CalArgmax(const T* input, const int batch_size, const int channel_size, const int axis, int* output,
|
|
|
|
void CalArgmax(const T *input, const int batch_size, const int channel_size, const int64_t axis, int *output,
|
|
|
|
cudaStream_t cuda_stream) {
|
|
|
|
if (batch_size == 0) {
|
|
|
|
Argmax1D<<<1, 1, 0, cuda_stream>>>(input, channel_size, output);
|
|
|
|
@@ -82,7 +82,7 @@ void CalArgmax(const T* input, const int batch_size, const int channel_size, con |
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Explicit instantiations of the CalArgmax host launcher for float and half inputs,
// using the signature whose `axis` parameter is `const int`.
// NOTE(review): this excerpt also carries instantiations whose `axis` is `const int64_t`;
// these `int` variants look like the pre-change side of a diff — confirm which signature
// matches the CalArgmax template definition before keeping both.
template void CalArgmax<float>(const float* input, const int batch_size, const int channel_size, const int axis,
|
|
|
|
int* output, cudaStream_t cuda_stream);
|
|
|
|
template void CalArgmax<half>(const half* input, const int batch_size, const int channel_size, const int axis,
|
|
|
|
int* output, cudaStream_t cuda_stream);
|
|
|
|
// Explicit instantiations of the CalArgmax host launcher for float and half inputs,
// using the signature whose `axis` parameter is `const int64_t`. The result indices are
// written through `int *output`, and the work is issued on the caller-supplied `cuda_stream`.
// NOTE(review): only these two element types are instantiated here; other T would need
// their own instantiation in this translation unit — confirm against callers.
template void CalArgmax<float>(const float *input, const int batch_size, const int channel_size, const int64_t axis,
|
|
|
|
int *output, cudaStream_t cuda_stream);
|
|
|
|
template void CalArgmax<half>(const half *input, const int batch_size, const int channel_size, const int64_t axis,
|
|
|
|
int *output, cudaStream_t cuda_stream);
|