/**
 * \file dnn/src/cuda/matrix_mul/bfloat16.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */

#include "src/cuda/handle.h"
#include "src/cuda/matrix_mul/algos.h"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;

/**
 * Wrap \p algorithm so that it can serve BFloat16 matrix multiplication.
 *
 * The wrapped algorithm must be non-null; the name of this algorithm is
 * derived from it as "MATMUL_BFLOAT16:<wrapped name>".
 */
MatrixMulForwardImpl::AlgoBFloat16::AlgoBFloat16(
        MatrixMulForwardImpl::AlgoBase* algorithm)
        : m_algorithm(algorithm) {
    megdnn_assert_internal(algorithm);
    m_name = ssprintf("MATMUL_BFLOAT16:%s", m_algorithm->name());
}

/**
 * Return a copy of \p args in which every BFloat16 layout (A, B and C) has
 * its dtype replaced by Float32. Layouts with any other dtype are left
 * untouched.
 */
MatrixMulForwardImpl::AlgoBase::SizeArgs
MatrixMulForwardImpl::AlgoBFloat16::float_args(const SizeArgs& args) const {
    auto new_args = args;
    auto change_dtype = [](TensorLayout& layout) {
        if (layout.dtype == dtype::BFloat16()) {
            layout.dtype = dtype::Float32();
        }
    };
    change_dtype(new_args.layout_a);
    change_dtype(new_args.layout_b);
    change_dtype(new_args.layout_c);
    return new_args;
}

/**
 * This algorithm is usable when A is BFloat16 and the wrapped algorithm
 * accepts the Float32 version of the same problem.
 */
bool MatrixMulForwardImpl::AlgoBFloat16::is_available(
        const SizeArgs& args) const {
    // Check the dtype first so that the float-typed copy of the arguments is
    // only built when it is actually needed.
    if (!(args.layout_a.dtype == dtype::BFloat16())) {
        return false;
    }
    return m_algorithm->is_available(float_args(args));
}

/**
 * Describe the workspace needed by exec().
 *
 * One entry is emitted, in A/B/C order, for every BFloat16 operand; its size
 * is the byte span of that operand's layout with the dtype changed to
 * Float32. The last entry is the workspace requested by the wrapped
 * algorithm for the Float32 problem.
 *
 * \param ptr  base pointer of the workspace (may be nullptr when only the
 *             sizes are of interest)
 * \param args problem description with the original (BFloat16) dtypes
 */
WorkspaceBundle MatrixMulForwardImpl::AlgoBFloat16::get_workspace_bundle(
        void* ptr, const SizeArgs& args) const {
    auto fargs = float_args(args);
    SmallVector<size_t> sizes;
    auto get_workspace = [&sizes](const TensorLayout& src) {
        TensorLayout dst = src;
        if (dst.dtype == dtype::BFloat16()) {
            dst.dtype = dtype::Float32();
            sizes.push_back(dst.span().dist_byte());
        }
    };
    // NOTE(review): the order of these entries presumably has to match the
    // order of the src_to_comp_type() calls in exec() -- confirm against
    // CompTypeCvter before reordering.
    get_workspace(args.layout_a);
    get_workspace(args.layout_b);
    get_workspace(args.layout_c);
    sizes.push_back(m_algorithm->get_workspace_in_bytes(fargs));
    return {ptr, std::move(sizes)};
}

/**
 * Total number of workspace bytes required for \p args, i.e. the total size
 * of the bundle built by get_workspace_bundle().
 */
size_t MatrixMulForwardImpl::AlgoBFloat16::get_workspace_in_bytes(
        const SizeArgs& args) const {
    return get_workspace_bundle(nullptr, args).total_size_in_bytes();
}

/**
 * Run the multiplication: convert the operands to the compute type, execute
 * the wrapped algorithm on the converted tensors and convert the result back
 * into the caller's C tensor.
 */
void MatrixMulForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const {
    TensorND a = args.tensor_a;
    TensorND b = args.tensor_b;
    TensorND c = args.tensor_c;
    auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
    // NOTE(review): the <BFloat16, Float32> type pair is assumed from the
    // conversions performed in float_args()/get_workspace_bundle() -- confirm
    // against the CompTypeCvter declaration.
    auto ctypecvt = CompTypeCvter<dtype::BFloat16, dtype::Float32>(
            args.opr->handle(), &bundle);
    ctypecvt.src_to_comp_type(args.tensor_a, a)
            .src_to_comp_type(args.tensor_b, b)
            .src_to_comp_type(args.tensor_c, c);
    {
        // The inner operator is pinned to the wrapped algorithm and runs on
        // the converted tensors with the default compute mode.
        auto matmul_opr =
                args.opr->handle()->create_operator<MatrixMulForward>();
        matmul_opr->param() = args.opr->param();
        matmul_opr->param().compute_mode = Param::ComputeMode::DEFAULT;
        matmul_opr->execution_policy() = {m_algorithm};
        matmul_opr->exec(a, b, c, ctypecvt.workspace());
    }
    ctypecvt.comp_to_dst_type(c, args.tensor_c);
}

// vim: syntax=cpp.doxygen