| @@ -45,36 +45,36 @@ class OneHotGpuFwdKernel : public GpuKernel { | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool Init(const CNodePtr &kernel_node) override { | bool Init(const CNodePtr &kernel_node) override { | ||||
| int axis = static_cast<int>(GetAttr<int64_t>(kernel_node, "axis")); | |||||
| auto input = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| auto output = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||||
| int input_size = SizeToInt(input.size()); | |||||
| const int default_axis = -1; | |||||
| int64_t axis = GetAttr<int64_t>(kernel_node, "axis"); | |||||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||||
| int64_t input_dims = static_cast<int64_t>(input_shape.size()); | |||||
| if (axis >= input_dims) { | |||||
| MS_LOG(ERROR) << "invalid one hot axis value: " << axis << " for input dims size: " << input_shape.size(); | |||||
| return false; | |||||
| } | |||||
| const int64_t default_axis = -1; | |||||
| // Compress arbitrary tensor dimensions into three dimensions (left_dims, depth, right_dims). | // Compress arbitrary tensor dimensions into three dimensions (left_dims, depth, right_dims). | ||||
| for (int i = 0; i < input_size; i++) { | |||||
| auto dim_size = input[IntToSize(i)]; | |||||
| if (axis == default_axis || i < axis) { | |||||
| for (size_t i = 0; i < input_shape.size(); i++) { | |||||
| auto dim_size = input_shape[i]; | |||||
| if (axis == default_axis || i < IntToSize(axis)) { | |||||
| left_dim_size_ *= dim_size; | left_dim_size_ *= dim_size; | ||||
| } | } | ||||
| if (axis != default_axis && i >= axis) { | |||||
| if (axis != default_axis && i >= IntToSize(axis)) { | |||||
| right_dim_size_ *= dim_size; | right_dim_size_ *= dim_size; | ||||
| } | } | ||||
| } | } | ||||
| for (auto size : input) { | |||||
| for (auto size : input_shape) { | |||||
| input_size_ *= size; | input_size_ *= size; | ||||
| } | } | ||||
| for (auto size : output) { | |||||
| for (auto size : output_shape) { | |||||
| output_size_ *= size; | output_size_ *= size; | ||||
| } | } | ||||
| if (axis >= input_size) { | |||||
| MS_LOG(ERROR) << "invalid one hot axis value: " << axis << " for input dims size: " << input.size(); | |||||
| return false; | |||||
| } | |||||
| if (axis == default_axis) { | if (axis == default_axis) { | ||||
| depth_ = output[output.size() - 1]; | |||||
| depth_ = output_shape[output_shape.size() - 1]; | |||||
| } else { | } else { | ||||
| depth_ = output[IntToSize(axis)]; | |||||
| depth_ = output_shape[IntToSize(axis)]; | |||||
| } | } | ||||
| InitSizeLists(); | InitSizeLists(); | ||||
| return true; | return true; | ||||
| @@ -55,11 +55,10 @@ class OnesLikeGpuKernel : public GpuKernel { | |||||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | ||||
| size_t shape_size = input_shape.size(); | size_t shape_size = input_shape.size(); | ||||
| input_size_ = 1; | |||||
| input_size_ = sizeof(T); | |||||
| for (size_t i = 0; i < shape_size; i++) { | for (size_t i = 0; i < shape_size; i++) { | ||||
| input_size_ *= input_shape[i]; | input_size_ *= input_shape[i]; | ||||
| } | } | ||||
| input_size_ *= sizeof(T); | |||||
| output_size_ = input_size_; | output_size_ = input_size_; | ||||
| InitSizeLists(); | InitSizeLists(); | ||||
| return true; | return true; | ||||
| @@ -21,14 +21,14 @@ __global__ void OneHotKernel(size_t size, const S *indices, size_t depth, const | |||||
| size_t left_dim_size, size_t right_dim_size, T *output) { | size_t left_dim_size, size_t right_dim_size, T *output) { | ||||
| T on_v = *on_value; | T on_v = *on_value; | ||||
| T off_v = *off_value; | T off_v = *off_value; | ||||
| for (int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; thread_idx < size; | |||||
| for (size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; thread_idx < size; | |||||
| thread_idx += blockDim.x * gridDim.x) { | thread_idx += blockDim.x * gridDim.x) { | ||||
| if (thread_idx < size) { | if (thread_idx < size) { | ||||
| int left_idx = (thread_idx / (depth * right_dim_size)) % left_dim_size; | |||||
| int d_idx = thread_idx / right_dim_size % depth; | |||||
| int right_idx = thread_idx % right_dim_size; | |||||
| int input_idx = left_idx * right_dim_size + right_idx; | |||||
| int output_idx = left_idx * depth * right_dim_size + d_idx * right_dim_size + right_idx; | |||||
| size_t left_idx = (thread_idx / (depth * right_dim_size)) % left_dim_size; | |||||
| size_t d_idx = thread_idx / right_dim_size % depth; | |||||
| size_t right_idx = thread_idx % right_dim_size; | |||||
| size_t input_idx = left_idx * right_dim_size + right_idx; | |||||
| size_t output_idx = left_idx * depth * right_dim_size + d_idx * right_dim_size + right_idx; | |||||
| if (indices[input_idx] == d_idx) { | if (indices[input_idx] == d_idx) { | ||||
| output[output_idx] = on_v; | output[output_idx] = on_v; | ||||
| } else { | } else { | ||||
| @@ -18,20 +18,20 @@ | |||||
| #include "oneslike_impl.cuh" | #include "oneslike_impl.cuh" | ||||
| #include "runtime/device/gpu/cuda_common.h" | #include "runtime/device/gpu/cuda_common.h" | ||||
| template <typename T> | template <typename T> | ||||
| __global__ void OnesLike(const int size, const T* input, T* output) { | |||||
| __global__ void OnesLike(const size_t size, const T* input, T* output) { | |||||
| int one = 1; | int one = 1; | ||||
| T val = static_cast<T>(one); | T val = static_cast<T>(one); | ||||
| for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { | |||||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { | |||||
| output[pos] = val; | output[pos] = val; | ||||
| } | } | ||||
| return; | return; | ||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| void CalOnesLike(const int size, const T* input, T* output, cudaStream_t cuda_stream) { | |||||
| void CalOnesLike(const size_t size, const T* input, T* output, cudaStream_t cuda_stream) { | |||||
| OnesLike<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, output); | OnesLike<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, output); | ||||
| return; | return; | ||||
| } | } | ||||
| template void CalOnesLike<float>(const int size, const float* input, float* output, cudaStream_t cuda_stream); | |||||
| template void CalOnesLike<half>(const int size, const half* input, half* output, cudaStream_t cuda_stream); | |||||
| template void CalOnesLike<int>(const int size, const int* input, int* output, cudaStream_t cuda_stream); | |||||
| template void CalOnesLike<float>(const size_t size, const float* input, float* output, cudaStream_t cuda_stream); | |||||
| template void CalOnesLike<half>(const size_t size, const half* input, half* output, cudaStream_t cuda_stream); | |||||
| template void CalOnesLike<int>(const size_t size, const int* input, int* output, cudaStream_t cuda_stream); | |||||
| @@ -18,6 +18,6 @@ | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ONESLIKE_H_ | #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ONESLIKE_H_ | ||||
| template <typename T> | template <typename T> | ||||
| void CalOnesLike(const int size, const T* input, T* output, cudaStream_t cuda_stream); | |||||
| void CalOnesLike(const size_t size, const T* input, T* output, cudaStream_t cuda_stream); | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ONESLIKE_H_ | #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ONESLIKE_H_ | ||||