You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

bfloat16.cpp 3.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. /**
  2. * \file dnn/src/cuda/matrix_mul/bfloat16.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/cuda/handle.h"
  12. #include "src/cuda/matrix_mul/algos.h"
  13. #include "src/cuda/utils.h"
  14. using namespace megdnn;
  15. using namespace cuda;
  16. MatrixMulForwardImpl::AlgoBFloat16::AlgoBFloat16(
  17. MatrixMulForwardImpl::AlgoBase* algorithm)
  18. : m_algorithm(algorithm) {
  19. megdnn_assert_internal(algorithm);
  20. m_name = ssprintf("MATMUL_BFLOAT16:%s", m_algorithm->name());
  21. }
  22. MatrixMulForwardImpl::AlgoBase::SizeArgs
  23. MatrixMulForwardImpl::AlgoBFloat16::float_args(const SizeArgs& args) const {
  24. auto new_args = args;
  25. auto change_dtype = [](TensorLayout& layout) {
  26. if (layout.dtype == dtype::BFloat16()) {
  27. layout.dtype = dtype::Float32();
  28. }
  29. };
  30. change_dtype(new_args.layout_a);
  31. change_dtype(new_args.layout_b);
  32. change_dtype(new_args.layout_c);
  33. return new_args;
  34. }
  35. bool MatrixMulForwardImpl::AlgoBFloat16::is_available(
  36. const SizeArgs& args) const {
  37. auto fargs = float_args(args);
  38. return args.layout_a.dtype == dtype::BFloat16() &&
  39. m_algorithm->is_available(fargs);
  40. }
  41. WorkspaceBundle MatrixMulForwardImpl::AlgoBFloat16::get_workspace_bundle(
  42. void* ptr, const SizeArgs& args) const {
  43. auto fargs = float_args(args);
  44. SmallVector<size_t> sizes;
  45. auto get_workspace = [&sizes](const TensorLayout& src) {
  46. TensorLayout dst = src;
  47. if (dst.dtype == dtype::BFloat16()) {
  48. dst.dtype = dtype::Float32();
  49. sizes.push_back(dst.span().dist_byte());
  50. }
  51. };
  52. get_workspace(args.layout_a);
  53. get_workspace(args.layout_b);
  54. get_workspace(args.layout_c);
  55. sizes.push_back(m_algorithm->get_workspace_in_bytes(fargs));
  56. return {ptr, std::move(sizes)};
  57. }
  58. size_t MatrixMulForwardImpl::AlgoBFloat16::get_workspace_in_bytes(
  59. const SizeArgs& args) const {
  60. return get_workspace_bundle(nullptr, args).total_size_in_bytes();
  61. }
  62. void MatrixMulForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const {
  63. TensorND a = args.tensor_a;
  64. TensorND b = args.tensor_b;
  65. TensorND c = args.tensor_c;
  66. auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
  67. auto ctypecvt = CompTypeCvter<dtype::BFloat16, dtype::Float32>(
  68. args.opr->handle(), &bundle);
  69. ctypecvt.src_to_comp_type(args.tensor_a, a)
  70. .src_to_comp_type(args.tensor_b, b)
  71. .src_to_comp_type(args.tensor_c, c);
  72. {
  73. auto matmul_opr =
  74. args.opr->handle()->create_operator<MatrixMulForward>();
  75. matmul_opr->param() = args.opr->param();
  76. matmul_opr->param().compute_mode = Param::ComputeMode::DEFAULT;
  77. matmul_opr->execution_policy() = {m_algorithm};
  78. matmul_opr->exec(a, b, c, ctypecvt.workspace());
  79. }
  80. ctypecvt.comp_to_dst_type(c, args.tensor_c);
  81. }
  82. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。