You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

gpu_kernel.h 2.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNEL_H_
  17. #define MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNEL_H_
  18. #include <cuda.h>
  19. #include <cudnn.h>
  20. #include <string>
  21. #include <vector>
  22. #include "kernel/kernel.h"
  23. #include "device/gpu/gpu_device_manager.h"
  24. #include "device/gpu/gpu_common.h"
  25. #include "session/anf_runtime_algorithm.h"
  26. using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm;
  27. namespace mindspore {
  28. namespace kernel {
  29. class GpuKernel : public KernelMod {
  30. public:
  31. virtual ~GpuKernel() = default;
  32. virtual bool Init(const CNodePtr &kernel_node) = 0;
  33. protected:
  34. virtual void InitResource() {}
  35. virtual void InitSizeLists() = 0;
  36. template <typename T>
  37. inline T *GetDeviceAddress(const std::vector<AddressPtr> &addr_list, size_t index) {
  38. if (index >= addr_list.size()) {
  39. MS_LOG(EXCEPTION) << "Address index(" << index << ") out of range(" << addr_list.size() << ")";
  40. }
  41. // Kernels may run normally without workspace, the addr_list[index] maybe nullptr.
  42. if ((addr_list[index] == nullptr) || (addr_list[index]->size == 0)) {
  43. return nullptr;
  44. }
  45. MS_EXCEPTION_IF_NULL(addr_list[index]->addr);
  46. return reinterpret_cast<T *>(addr_list[index]->addr);
  47. }
  48. template <typename T>
  49. inline T GetAttr(const CNodePtr &kernel_node, const std::string &key) const {
  50. const PrimitivePtr &prim = AnfAlgo::GetCNodePrimitive(kernel_node);
  51. const ValuePtr &attr = prim->GetAttr(key);
  52. if (attr == nullptr) {
  53. const std::string &prim_name = AnfAlgo::GetCNodeName(kernel_node);
  54. MS_LOG(EXCEPTION) << "The attr(" << key << ") of kernel(" << prim_name << ") not exist";
  55. }
  56. return GetValue<T>(attr);
  57. }
  58. };
  59. } // namespace kernel
  60. } // namespace mindspore
  61. #endif // MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNEL_H_