You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

handle.cpp 7.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. /**
  2. * \file dnn/src/rocm/handle.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "hcc_detail/hcc_defs_prologue.h"
  12. #include "src/common/handle_impl.h"
  13. #include "src/common/version_symbol.h"
  14. #include "src/rocm/handle.h"
  15. #include "src/rocm/miopen_with_check.h"
  16. #include "src/rocm/utils.h"
  17. #include "src/rocm/adaptive_pooling/opr_impl.h"
  18. #include "src/rocm/add_update/opr_impl.h"
  19. #include "src/rocm/argmxx/opr_impl.h"
  20. #include "src/rocm/argsort/opr_impl.h"
  21. #include "src/rocm/batch_normalization/opr_impl.h"
  22. #include "src/rocm/batched_matrix_mul/opr_impl.h"
  23. #include "src/rocm/checksum/opr_impl.h"
  24. #include "src/rocm/convolution/opr_impl.h"
  25. #include "src/rocm/elemwise/opr_impl.h"
  26. #include "src/rocm/eye/opr_impl.h"
  27. #include "src/rocm/fill/opr_impl.h"
  28. #include "src/rocm/indexing_multi_axis_vec/opr_impl.h"
  29. #include "src/rocm/indexing_one_hot/opr_impl.h"
  30. #include "src/rocm/linspace/opr_impl.h"
  31. #include "src/rocm/matrix_mul/opr_impl.h"
  32. #include "src/rocm/param_pack/opr_impl.h"
  33. #include "src/rocm/pooling/opr_impl.h"
  34. #include "src/rocm/powc/opr_impl.h"
  35. #include "src/rocm/reduce/opr_impl.h"
  36. #include "src/rocm/relayout/opr_impl.h"
  37. #include "src/rocm/rng/opr_impl.h"
  38. #include "src/rocm/sleep/opr_impl.h"
  39. #include "src/rocm/topk/opr_impl.h"
  40. #include "src/rocm/type_cvt/opr_impl.h"
  41. #include <hip/hip_version.h>
  42. #include <miopen/version.h>
  43. #include <cstring>
  // Two-level stringification: the outer macro lets its argument be
  // macro-expanded before the helper applies the # operator to it.
  #define STR_HELPER(token) #token
  #define STR(token) STR_HELPER(token)
  // "major.minor.patch" text assembled from the MIOPEN_VERSION_* macros by
  // adjacent string-literal concatenation.
  #define MIOPEN_VERSION_STR        \
      STR(MIOPEN_VERSION_MAJOR)     \
      "." STR(MIOPEN_VERSION_MINOR) \
      "." STR(MIOPEN_VERSION_PATCH)
  // Print the MIOpen version this translation unit is compiled against.
  #pragma message "compile with MIOpen " MIOPEN_VERSION_STR " "
  // The helpers are dropped again right after the message; note that
  // MIOPEN_VERSION_STR itself stays defined but expands to STR(...) calls.
  #undef STR
  #undef STR_HELPER
  52. namespace megdnn {
  53. std::unique_ptr<Handle> Handle::make_rocm_handle(
  54. megcoreComputingHandle_t computing_handle) {
  55. return std::make_unique<rocm::HandleImpl>(computing_handle);
  56. }
  57. template <typename Opr>
  58. std::unique_ptr<Opr> Handle::create_rocm_operator() {
  59. return static_cast<rocm::HandleImpl*>(this)->create_operator<Opr>();
  60. }
  61. #define INST(opr) template std::unique_ptr<opr> Handle::create_rocm_operator();
  62. MEGDNN_FOREACH_OPR_CLASS(INST)
  63. #undef INST
  64. } // namespace megdnn
  65. namespace megdnn {
  66. namespace rocm {
  67. HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle)
  68. : HandleImplHelper(comp_handle, HandleType::ROCM) {
  69. // Get megcore device handle
  70. megcoreDeviceHandle_t dev_handle;
  71. megcoreGetDeviceHandle(comp_handle, &dev_handle);
  72. int dev_id;
  73. megcoreGetDeviceID(dev_handle, &dev_id);
  74. if (dev_id < 0) {
  75. hip_check(hipGetDevice(&dev_id));
  76. }
  77. m_device_id = dev_id;
  78. hip_check(hipGetDeviceProperties(&m_device_prop, dev_id));
  79. // Get stream from MegCore computing handle.
  80. //! no version check
  81. megcore::getROCMContext(comp_handle, &m_megcore_context);
  82. rocblas_check(rocblas_create_handle(&m_rocblas_handle));
  83. //! must call miopenCreateWithStream() to create miopen handle, then the
  84. //! rocblas_handle of miopen will set to be the same stream , otherwise
  85. //! miopen create rocblas_handle with default stream
  86. miopen_check(miopenCreateWithStream(&m_miopen_handle, stream()));
  87. // Set stream for miopen and rocblas handles.
  88. rocblas_check(rocblas_set_stream(m_rocblas_handle, stream()));
  89. // Note that all rocblas scalars (alpha, beta) and scalar results such as
  90. // dot output resides at device side.
  91. rocblas_check(
  92. rocblas_set_pointer_mode(m_rocblas_handle, rocblas_pointer_mode_device));
  93. // init const scalars
  94. hip_check(hipMalloc(&m_const_scalars, sizeof(ConstScalars)));
  95. ConstScalars const_scalars_val;
  96. const_scalars_val.init();
  97. hip_check(hipMemcpyAsync(
  98. m_const_scalars, &const_scalars_val, sizeof(ConstScalars),
  99. hipMemcpyHostToDevice, stream()));
  100. hip_check(hipStreamSynchronize(stream()));
  101. }
  102. HandleImpl::~HandleImpl() noexcept {
  103. miopen_check(miopenDestroy(m_miopen_handle));
  104. rocblas_check(rocblas_destroy_handle(m_rocblas_handle));
  105. hip_check(hipFree(m_const_scalars));
  106. }
  107. void HandleImpl::ConstScalars::init() {
  108. #if !MEGDNN_DISABLE_FLOAT16
  109. f16[0].megdnn_x = 0;
  110. f16[1].megdnn_x = 1;
  111. #endif
  112. f32[0] = 0;
  113. f32[1] = 1;
  114. i32[0] = 0;
  115. i32[1] = 1;
  116. }
  117. template <typename Opr>
  118. std::unique_ptr<Opr> HandleImpl::create_operator() {
  119. megdnn_throw("unsupported rocm opr");
  120. return nullptr;
  121. }
  122. size_t HandleImpl::alignment_requirement() const {
  123. auto&& prop = m_device_prop;
  124. MEGDNN_MARK_USED_VAR(prop);
  125. //! for now, texture functions are not supported.
  126. return 1u;
  127. }
  128. bool HandleImpl::check_cross_dev_copy_constraint(const TensorLayout& src) {
  129. // is contiguous or can be hold by
  130. // relayout::param::try_copy_2d/try_copy_last_contig
  131. return src.is_contiguous() || src.stride[src.ndim - 1] == 1;
  132. }
  133. MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgsortForward);
  134. MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgsortBackward);
  135. MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionForward);
  136. MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardData);
  137. MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardFilter);
  138. MEGDNN_SPECIALIZE_CREATE_OPERATOR(ElemwiseForward);
  139. MEGDNN_SPECIALIZE_CREATE_OPERATOR(Eye);
  140. MEGDNN_SPECIALIZE_CREATE_OPERATOR(ChecksumForward);
  141. MEGDNN_SPECIALIZE_CREATE_OPERATOR(PoolingForward);
  142. MEGDNN_SPECIALIZE_CREATE_OPERATOR(PoolingBackward);
  143. MEGDNN_SPECIALIZE_CREATE_OPERATOR(AdaptivePoolingForward);
  144. MEGDNN_SPECIALIZE_CREATE_OPERATOR(AdaptivePoolingBackward);
  145. MEGDNN_SPECIALIZE_CREATE_OPERATOR(ReduceForward);
  146. MEGDNN_SPECIALIZE_CREATE_OPERATOR(TypeCvt);
  147. MEGDNN_SPECIALIZE_CREATE_OPERATOR(TopK);
  148. MEGDNN_SPECIALIZE_CREATE_OPERATOR(AddUpdateForward);
  149. MEGDNN_SPECIALIZE_CREATE_OPERATOR(MatrixMulForward);
  150. MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedMatrixMulForward);
  151. MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingOneHotForward);
  152. MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingSetOneHotForward);
  153. MEGDNN_SPECIALIZE_CREATE_OPERATOR(UniformRNG);
  154. MEGDNN_SPECIALIZE_CREATE_OPERATOR(GaussianRNG);
  155. MEGDNN_SPECIALIZE_CREATE_OPERATOR(RelayoutForward);
  156. MEGDNN_SPECIALIZE_CREATE_OPERATOR(PowC);
  157. MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingMultiAxisVec);
  158. MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingSetMultiAxisVec);
  159. MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingIncrMultiAxisVec);
  160. MEGDNN_SPECIALIZE_CREATE_OPERATOR(Linspace);
  161. MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgmaxForward);
  162. MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgminForward);
  163. MEGDNN_SPECIALIZE_CREATE_OPERATOR(SleepForward);
  164. MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNForward);
  165. MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNBackward);
  166. MEGDNN_SPECIALIZE_CREATE_OPERATOR(ParamPackConcat);
  167. MEGDNN_SPECIALIZE_CREATE_OPERATOR(Fill);
  168. #pragma GCC diagnostic push
  169. #pragma GCC diagnostic ignored "-Wpragmas"
  170. #pragma GCC diagnostic ignored "-Winstantiation-after-specialization"
  171. MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR)
  172. #pragma GCC diagnostic pop
  173. } // namespace rocm
  174. } // namespace megdnn
  175. MEGDNN_VERSION_SYMBOL3(HIP, HIP_VERSION_MAJOR, HIP_VERSION_MINOR, HIP_VERSION_PATCH);
  176. MEGDNN_VERSION_SYMBOL3(
  177. MIOPEN, MIOPEN_VERSION_MAJOR, MIOPEN_VERSION_MINOR, MIOPEN_VERSION_PATCH);
  178. // vim: syntax=cpp.doxygen