/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GPU_KERNEL_H_ #define MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GPU_KERNEL_H_ #include #include "kernel/gpu/gpu_kernel.h" #include "kernel/gpu/gpu_kernel_factory.h" #include "kernel/gpu/kernel_constants.h" #include "kernel/gpu/cuda_impl/gelu_impl.cuh" namespace mindspore { namespace kernel { template class GeluGpuKernel : public GpuKernel { public: GeluGpuKernel() : input_size_(0) {} ~GeluGpuKernel() override = default; const std::vector &GetInputSizeList() const override { return input_size_list_; } const std::vector &GetOutputSizeList() const override { return output_size_list_; } const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } bool Launch(const std::vector &inputs, const std::vector &, const std::vector &outputs, void *stream_ptr) override { T *input_addr = GetDeviceAddress(inputs, 0); T *output_addr = GetDeviceAddress(outputs, 0); Gelu(input_size_ / sizeof(T), input_addr, output_addr, reinterpret_cast(stream_ptr)); return true; } bool Init(const CNodePtr &kernel_node) override { InitResource(); input_size_ = sizeof(T); auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); for (auto dim : input_shape) { input_size_ *= dim; } InitSizeLists(); return true; } protected: void InitSizeLists() override { input_size_list_.push_back(input_size_); output_size_list_.push_back(input_size_); } private: std::vector input_size_list_; std::vector output_size_list_; std::vector workspace_size_list_; size_t input_size_; }; } // namespace kernel } // namespace mindspore #endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GPU_KERNEL_H_