You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

gpu_kernel.h 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNEL_H_
  17. #define MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNEL_H_
  18. #include <cuda.h>
  19. #include <cudnn.h>
  20. #include <string>
  21. #include <vector>
  22. #include "kernel/kernel.h"
  23. #include "kernel/gpu/kernel_constants.h"
  24. #include "device/gpu/gpu_device_manager.h"
  25. #include "device/gpu/gpu_common.h"
  26. #include "session/anf_runtime_algorithm.h"
  27. using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm;
  28. namespace mindspore {
  29. namespace kernel {
  30. class GpuKernel : public KernelMod {
  31. public:
  32. virtual ~GpuKernel() = default;
  33. virtual bool Init(const CNodePtr &kernel_node) = 0;
  34. protected:
  35. virtual void InitResource() {}
  36. virtual void InitSizeLists() = 0;
  37. template <typename T>
  38. inline T *GetDeviceAddress(const std::vector<AddressPtr> &addr_list, size_t index) {
  39. if (index >= addr_list.size()) {
  40. MS_LOG(EXCEPTION) << "Address index(" << index << ") out of range(" << addr_list.size() << ")";
  41. }
  42. // Kernels may run normally without workspace, the addr_list[index] maybe nullptr.
  43. if ((addr_list[index] == nullptr) || (addr_list[index]->size == 0)) {
  44. return nullptr;
  45. }
  46. MS_EXCEPTION_IF_NULL(addr_list[index]->addr);
  47. return reinterpret_cast<T *>(addr_list[index]->addr);
  48. }
  49. template <typename T>
  50. inline T GetAttr(const CNodePtr &kernel_node, const std::string &key) const {
  51. const PrimitivePtr &prim = AnfAlgo::GetCNodePrimitive(kernel_node);
  52. const ValuePtr &attr = prim->GetAttr(key);
  53. if (attr == nullptr) {
  54. const std::string &prim_name = AnfAlgo::GetCNodeName(kernel_node);
  55. MS_LOG(EXCEPTION) << "The attr(" << key << ") of kernel(" << prim_name << ") not exist";
  56. }
  57. return GetValue<T>(attr);
  58. }
  59. // expand Nd Shape to 4d (N in [0,4])
  60. void ShapeNdTo4d(const std::vector<size_t> &src, std::vector<int> *dst) {
  61. if (src.size() > 4) {
  62. MS_EXCEPTION(ValueError) << src.size() << "-D data is not supported!";
  63. }
  64. dst->push_back(src.size() < 4 ? 1 : SizeToInt(src[src.size() - 4]));
  65. dst->push_back(src.size() < 3 ? 1 : SizeToInt(src[src.size() - 3]));
  66. dst->push_back(src.size() < 2 ? 1 : SizeToInt(src[src.size() - 2]));
  67. dst->push_back(src.size() == 0 ? 1 : SizeToInt(src[src.size() - 1]));
  68. }
  69. inline void CheckBroadcast4TensorOp(const std::vector<int> &A, const std::vector<int> &B,
  70. const std::vector<int> &Out) {
  71. if (A != Out && B != Out) {
  72. MS_EXCEPTION(ValueError)
  73. << "Double-sided broadcast was not supported in cudnn of cudnnOpTensor:\n"
  74. "InputA must match the corresponding dimension of the destination tensor outC, and each "
  75. "dimension of the inputB "
  76. "must match the corresponding dimension of outC or must be equal to 1.";
  77. }
  78. }
  79. // choose the suitable datatype for cudnn/cublas
  80. inline cudnnDataType_t GetCudnnDataType(const std::string &Type) {
  81. auto type = kCudnnDtypeMap.find(Type);
  82. if (type == kCudnnDtypeMap.end()) {
  83. MS_EXCEPTION(TypeError) << Type << " is not supported.";
  84. }
  85. return type->second;
  86. }
  87. inline cudaDataType_t GetCudaDataType(const std::string &Type) {
  88. auto type = kCudaDtypeMap.find(Type);
  89. if (type == kCudaDtypeMap.end()) {
  90. MS_EXCEPTION(TypeError) << Type << " is not supported.";
  91. }
  92. return type->second;
  93. }
  94. };
  95. } // namespace kernel
  96. } // namespace mindspore
  97. #endif // MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNEL_H_