GitOrigin-RevId: 95be929841
tags/v1.5.0
| @@ -96,34 +96,6 @@ void ConvDesc::set(const param::Convolution& param, const size_t nr_group, | |||||
| //! not supported | //! not supported | ||||
| } | } | ||||
| PoolingDesc::PoolingDesc() { | |||||
| miopen_check(miopenCreatePoolingDescriptor(&desc)); | |||||
| } | |||||
| PoolingDesc::~PoolingDesc() { | |||||
| miopen_check(miopenDestroyPoolingDescriptor(desc)); | |||||
| } | |||||
| void PoolingDesc::set(const param::Pooling& param) { | |||||
| miopenPoolingMode_t mode; | |||||
| switch (param.mode) { | |||||
| case param::Pooling::Mode::MAX: | |||||
| mode = miopenPoolingMax; | |||||
| break; | |||||
| case param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING: | |||||
| mode = miopenPoolingAverage; | |||||
| break; | |||||
| case param::Pooling::Mode::AVERAGE: | |||||
| mode = miopenPoolingAverageInclusive; | |||||
| break; | |||||
| default: | |||||
| megdnn_throw("Unsupported pooling mode for miopen"); | |||||
| } | |||||
| miopen_check(miopenSet2dPoolingDescriptor( | |||||
| desc, mode, param.window_h, param.window_w, param.pad_h, | |||||
| param.pad_w, param.stride_h, param.stride_w)); | |||||
| } | |||||
| LRNDesc::LRNDesc() { | LRNDesc::LRNDesc() { | ||||
| miopen_check(miopenCreateLRNDescriptor(&desc)); | miopen_check(miopenCreateLRNDescriptor(&desc)); | ||||
| } | } | ||||
| @@ -38,14 +38,6 @@ public: | |||||
| miopenConvolutionDescriptor_t desc; | miopenConvolutionDescriptor_t desc; | ||||
| }; | }; | ||||
| class PoolingDesc { | |||||
| public: | |||||
| PoolingDesc(); | |||||
| void set(const param::Pooling& param); | |||||
| ~PoolingDesc(); | |||||
| miopenPoolingDescriptor_t desc; | |||||
| }; | |||||
| class LRNDesc { | class LRNDesc { | ||||
| public: | public: | ||||
| LRNDesc(); | LRNDesc(); | ||||
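The two hunks above delete the shared PoolingDesc helper (both its implementation in the MIOpen wrapper and its declaration): after this change each pooling algorithm creates, fills, and destroys its miopenPoolingDescriptor_t locally inside exec. Below is a minimal sketch of an equivalent scope-local RAII guard; it assumes only the MIOpen calls already used in this diff, and the check() helper is a hypothetical stand-in for miopen_check.

    #include <cstdlib>
    #include <miopen/miopen.h>

    // Sketch only: a scope-local guard equivalent to the removed PoolingDesc.
    struct ScopedPoolingDesc {
        miopenPoolingDescriptor_t desc;
        ScopedPoolingDesc(miopenPoolingMode_t mode, int window_h, int window_w,
                          int pad_h, int pad_w, int stride_h, int stride_w) {
            check(miopenCreatePoolingDescriptor(&desc));
            check(miopenSet2dPoolingDescriptor(desc, mode, window_h, window_w,
                                               pad_h, pad_w, stride_h, stride_w));
        }
        ~ScopedPoolingDesc() { miopenDestroyPoolingDescriptor(desc); }
        ScopedPoolingDesc(const ScopedPoolingDesc&) = delete;
        ScopedPoolingDesc& operator=(const ScopedPoolingDesc&) = delete;

    private:
        // Hypothetical error handler; the real code wraps calls in miopen_check().
        static void check(miopenStatus_t status) {
            if (status != miopenStatusSuccess)
                std::abort();
        }
    };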
| @@ -0,0 +1,209 @@ | |||||
| /** | |||||
| * \file dnn/src/rocm/pooling/algos.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "./algo.h" | |||||
| #include "hcc_detail/hcc_defs_prologue.h" | |||||
| #include "src/rocm/utils.h" | |||||
| using namespace megdnn; | |||||
| using namespace rocm; | |||||
| PoolingForwardImpl::AlgoPack::AlgoPack() { | |||||
| all_algos.push_back(&algo_miopen); | |||||
| for (auto&& algo : all_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| } | |||||
| PoolingForwardImpl::AlgoPack PoolingForwardImpl::sm_algo_pack; | |||||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(PoolingForwardImpl) | |||||
| PoolingForwardImpl::AlgoBase::SizeArgs::SizeArgs(PoolingForwardImpl* o, | |||||
| const TensorLayout& src, | |||||
| const TensorLayout& dst) | |||||
| : handle{concrete_handle(o->handle())}, | |||||
| opr{o}, | |||||
| layout_src{&src}, | |||||
| layout_dst{&dst} {} | |||||
| PoolingForwardImpl::AlgoBase::ExecArgs::ExecArgs(PoolingForwardImpl* opr, | |||||
| _megdnn_tensor_in src, | |||||
| _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) | |||||
| : SizeArgs(opr, src.layout, dst.layout), | |||||
| src_tensor{&src}, | |||||
| dst_tensor{&dst}, | |||||
| workspace{workspace} {} | |||||
| std::string PoolingForwardImpl::AlgoBase::SizeArgs::to_string() const { | |||||
| return ssprintf("src=%s, dst=%s", layout_src->to_string().c_str(), | |||||
| layout_dst->to_string().c_str()); | |||||
| } | |||||
| bool PoolingForwardImpl::AlgoMIOpen::is_available(const SizeArgs& args) const { | |||||
| return true; | |||||
| } | |||||
| void PoolingForwardImpl::AlgoMIOpen::init_mode( | |||||
| const ExecArgs& args, miopenPoolingMode_t& mode) const { | |||||
| switch (args.opr->param().mode) { | |||||
| case param::Pooling::Mode::MAX: | |||||
| mode = miopenPoolingMax; | |||||
| break; | |||||
| case param::Pooling::Mode::AVERAGE: | |||||
| mode = miopenPoolingAverage; | |||||
| break; | |||||
| case param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING: | |||||
| mode = miopenPoolingAverageInclusive; | |||||
| break; | |||||
| default: | |||||
| megdnn_throw(ssprintf("Unspport pooling mode : {%d}", | |||||
| static_cast<int>(args.opr->param().mode))); | |||||
| } | |||||
| } | |||||
| size_t PoolingForwardImpl::AlgoMIOpen::get_workspace_in_bytes( | |||||
| const SizeArgs& args) const { | |||||
| return 0; | |||||
| } | |||||
| void PoolingForwardImpl::AlgoMIOpen::exec(const ExecArgs& args) const { | |||||
| auto handle = miopen_handle(args.handle); | |||||
| TensorDesc src_desc, dst_desc; | |||||
| args.init_desc(src_desc, dst_desc); | |||||
| miopenPoolingMode_t mode; | |||||
| init_mode(args, mode); | |||||
| miopenPoolingDescriptor_t miopen_desc; | |||||
| miopen_check(miopenCreatePoolingDescriptor(&miopen_desc)); | |||||
| miopen_check(miopenSet2dPoolingDescriptor( | |||||
| miopen_desc, mode, args.opr->param().window_h, | |||||
| args.opr->param().window_w, args.opr->param().pad_h, | |||||
| args.opr->param().pad_w, args.opr->param().stride_h, | |||||
| args.opr->param().stride_w)); | |||||
| dt_float32 alpha = 1.0f, beta = 0.0f; | |||||
| miopen_check(miopenPoolingForward( | |||||
| handle, miopen_desc, &alpha, src_desc.desc, | |||||
| args.src_tensor->raw_ptr, &beta, dst_desc.desc, | |||||
| args.dst_tensor->raw_ptr, false, nullptr, 0_z)); | |||||
| miopen_check(miopenDestroyPoolingDescriptor(miopen_desc)); | |||||
| } | |||||
| PoolingBackwardImpl::AlgoPack::AlgoPack() { | |||||
| all_algos.push_back(&algo_miopen); | |||||
| for (auto&& algo : all_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| } | |||||
| PoolingBackwardImpl::AlgoPack PoolingBackwardImpl::sm_algo_pack; | |||||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(PoolingBackwardImpl) | |||||
| PoolingBackwardImpl::AlgoBase::SizeArgs::SizeArgs(PoolingBackwardImpl* o, | |||||
| const TensorLayout& src, | |||||
| const TensorLayout& dst, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad) | |||||
| : handle{concrete_handle(o->handle())}, | |||||
| opr{o}, | |||||
| layout_src{&src}, | |||||
| layout_dst{&dst}, | |||||
| layout_diff{&diff}, | |||||
| layout_grad{&grad} {} | |||||
| PoolingBackwardImpl::AlgoBase::ExecArgs::ExecArgs(PoolingBackwardImpl* opr, | |||||
| _megdnn_tensor_in src, | |||||
| _megdnn_tensor_in dst, | |||||
| _megdnn_tensor_in diff, | |||||
| _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace) | |||||
| : SizeArgs(opr, src.layout, dst.layout, diff.layout, grad.layout), | |||||
| src_tensor{&src}, | |||||
| dst_tensor{&dst}, | |||||
| diff_tensor{&diff}, | |||||
| grad_tensor{&grad}, | |||||
| workspace{workspace} {} | |||||
| std::string PoolingBackwardImpl::AlgoBase::SizeArgs::to_string() const { | |||||
| return ssprintf( | |||||
| "src=%s, dst=%s, diff=%s, grad=%s", layout_src->to_string().c_str(), | |||||
| layout_dst->to_string().c_str(), layout_diff->to_string().c_str(), | |||||
| layout_grad->to_string().c_str()); | |||||
| } | |||||
| bool PoolingBackwardImpl::AlgoMIOpen::is_available(const SizeArgs&) const { | |||||
| return true; | |||||
| } | |||||
| size_t PoolingBackwardImpl::AlgoMIOpen::get_workspace_in_bytes( | |||||
| const SizeArgs& args) const { | |||||
| TensorDesc dst_desc; | |||||
| dst_desc.set(*args.layout_dst); | |||||
| size_t ws_size = 0_z; | |||||
| miopenPoolingGetWorkSpaceSize(dst_desc.desc, &ws_size); | |||||
| return ws_size; | |||||
| } | |||||
| void PoolingBackwardImpl::AlgoMIOpen::init_mode(const ExecArgs& args, | |||||
| miopenPoolingMode_t& mode) const { | |||||
| switch (args.opr->param().mode) { | |||||
| case param::Pooling::Mode::MAX: | |||||
| mode = miopenPoolingMax; | |||||
| break; | |||||
| case param::Pooling::Mode::AVERAGE: | |||||
| mode = miopenPoolingAverage; | |||||
| break; | |||||
| case param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING: | |||||
| mode = miopenPoolingAverageInclusive; | |||||
| break; | |||||
| default: | |||||
| megdnn_throw(ssprintf("Unspport pooling mode : {%d}", | |||||
| static_cast<int>(args.opr->param().mode))); | |||||
| } | |||||
| } | |||||
| void PoolingBackwardImpl::AlgoMIOpen::exec(const ExecArgs& args) const { | |||||
| auto handle = miopen_handle(args.handle); | |||||
| TensorDesc src_desc, dst_desc, diff_desc, grad_desc; | |||||
| args.init_desc(src_desc, dst_desc, diff_desc, grad_desc); | |||||
| miopenPoolingMode_t mode; | |||||
| init_mode(args, mode); | |||||
| miopenPoolingDescriptor_t miopen_desc; | |||||
| miopen_check(miopenCreatePoolingDescriptor(&miopen_desc)); | |||||
| miopen_check(miopenSet2dPoolingDescriptor( | |||||
| miopen_desc, mode, args.opr->param().window_h, | |||||
| args.opr->param().window_w, args.opr->param().pad_h, | |||||
| args.opr->param().pad_w, args.opr->param().stride_h, | |||||
| args.opr->param().stride_w)); | |||||
| float alpha = 1.0f, beta = 0.0f; | |||||
| if (args.opr->param().mode == param::Pooling::Mode::MAX) { | |||||
| //! FIXME: when using max pooling opr, the backward opr need the indices | |||||
| //! of the forward opr which stored in workspace. We have to recompute | |||||
| //! the indices by calling miopenPoolingForward again. | |||||
| miopen_check(miopenPoolingForward( | |||||
| handle, miopen_desc, &alpha, src_desc.desc, | |||||
| args.src_tensor->raw_ptr, &beta, dst_desc.desc, | |||||
| args.dst_tensor->raw_ptr, true, args.workspace.raw_ptr, | |||||
| args.workspace.size)); | |||||
| } | |||||
| miopen_check(miopenPoolingBackward( | |||||
| handle, miopen_desc, &alpha, dst_desc.desc, | |||||
| args.dst_tensor->raw_ptr, diff_desc.desc, args.diff_tensor->raw_ptr, | |||||
| src_desc.desc, args.src_tensor->raw_ptr, &beta, grad_desc.desc, | |||||
| args.grad_tensor->raw_ptr, args.workspace.raw_ptr)); | |||||
| } | |||||
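As the FIXME above notes, MIOpen's max-pooling backward needs the argmax indices produced by the forward pass; since they are not cached anywhere, AlgoMIOpen::exec re-runs miopenPoolingForward with do_backward = true so the indices land in the caller-provided workspace before miopenPoolingBackward consumes them. The standalone sketch below shows that call order in isolation: every argument is assumed to be a valid, pre-built MIOpen object or device pointer, the 2x2/stride-2 window is an arbitrary example, and status checking is omitted for brevity.

    #include <miopen/miopen.h>

    // Sketch of the max-pooling backward call order used above (error checks omitted).
    void max_pooling_backward_sketch(
            miopenHandle_t handle,
            miopenTensorDescriptor_t x_desc, const void* x,    // forward input
            miopenTensorDescriptor_t y_desc, void* y,          // forward output
            miopenTensorDescriptor_t dy_desc, const void* dy,  // gradient w.r.t. y
            miopenTensorDescriptor_t dx_desc, void* dx,        // gradient w.r.t. x
            void* workspace, size_t workspace_size) {
        float alpha = 1.f, beta = 0.f;
        miopenPoolingDescriptor_t pool;
        miopenCreatePoolingDescriptor(&pool);
        miopenSet2dPoolingDescriptor(pool, miopenPoolingMax, 2, 2, 0, 0, 2, 2);

        // The workspace must be at least as large as reported by
        // miopenPoolingGetWorkSpaceSize(y_desc, ...): it stores the argmax indices.
        // Re-running the forward pass with do_backward=true records them there.
        miopenPoolingForward(handle, pool, &alpha, x_desc, x, &beta, y_desc, y,
                             /*do_backward=*/true, workspace, workspace_size);

        // The backward pass then reads the recorded indices from the same workspace.
        miopenPoolingBackward(handle, pool, &alpha, y_desc, y, dy_desc, dy,
                              x_desc, x, &beta, dx_desc, dx, workspace);

        miopenDestroyPoolingDescriptor(pool);
    }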
| @@ -0,0 +1,195 @@ | |||||
| /** | |||||
| * \file dnn/src/rocm/pooling/algo.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include <unordered_map> | |||||
| #include "src/common/algo_base.h" | |||||
| #include "src/common/metahelper.h" | |||||
| #include "src/rocm/miopen_wrapper.h" | |||||
| #include "src/rocm/pooling/opr_impl.h" | |||||
| #include "src/rocm/handle.h" | |||||
| namespace megdnn { | |||||
| namespace rocm { | |||||
| class PoolingForwardImpl::AlgoBase : public Algorithm { | |||||
| public: | |||||
| enum class AlgoType : uint32_t { ROCM_MIOPEN }; | |||||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; } | |||||
| struct SizeArgs { | |||||
| HandleImpl* handle; | |||||
| PoolingForwardImpl* opr; | |||||
| const TensorLayout *layout_src, *layout_dst; | |||||
| std::string to_string() const; | |||||
| void init_desc(TensorDesc& src_desc, TensorDesc& dst_desc) const { | |||||
| src_desc.set(*layout_src, opr->param().format); | |||||
| dst_desc.set(*layout_dst, opr->param().format); | |||||
| } | |||||
| SizeArgs(PoolingForwardImpl* opr, const TensorLayout& src, | |||||
| const TensorLayout& dst); | |||||
| }; | |||||
| struct ExecArgs : public SizeArgs { | |||||
| const TensorND *src_tensor, *dst_tensor; | |||||
| Workspace workspace; | |||||
| ExecArgs(PoolingForwardImpl* opr, _megdnn_tensor_in src, | |||||
| _megdnn_tensor_out dst, _megdnn_workspace workspace); | |||||
| }; | |||||
| virtual bool is_available(const SizeArgs& args) const = 0; | |||||
| virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; | |||||
| virtual void exec(const ExecArgs& args) const = 0; | |||||
| bool is_available_attribute( | |||||
| const SizeArgs& args, | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | |||||
| return contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr) && is_available(args); | |||||
| } | |||||
| protected: | |||||
| ~AlgoBase() = default; | |||||
| }; | |||||
| class PoolingForwardImpl::AlgoMIOpen final : public AlgoBase { | |||||
| std::string m_algo_name; | |||||
| AlgoAttribute m_algo_attribute; | |||||
| public: | |||||
| AlgoMIOpen(AlgoAttribute attr) | |||||
| : m_algo_name("MIOpenPoolingForward"), m_algo_attribute(attr) {} | |||||
| bool is_available(const SizeArgs& args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
| void init_mode(const ExecArgs& args, miopenPoolingMode_t& mode) const; | |||||
| void exec(const ExecArgs& args) const override; | |||||
| const char* name() const override { return m_algo_name.c_str(); } | |||||
| AlgoAttribute attribute() const override { return m_algo_attribute; } | |||||
| MEGDNN_DECL_ALGO_TYPE(ROCM_MIOPEN) | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_algo_attribute, ret); | |||||
| return ret; | |||||
| } | |||||
| }; | |||||
| class PoolingForwardImpl::AlgoPack : NonCopyableObj { | |||||
| private: | |||||
| AlgoBase::Mapper m_all_algos_map; | |||||
| public: | |||||
| AlgoPack(); | |||||
| AlgoMIOpen algo_miopen{AlgoAttribute::REPRODUCIBLE}; | |||||
| std::vector<AlgoBase*> all_algos; | |||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||||
| }; | |||||
| class PoolingBackwardImpl::AlgoBase : public Algorithm { | |||||
| public: | |||||
| enum class AlgoType : uint32_t { ROCM_MIOPEN }; | |||||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; } | |||||
| struct SizeArgs { | |||||
| HandleImpl* handle; | |||||
| PoolingBackwardImpl* opr; | |||||
| const TensorLayout *layout_src, *layout_dst, *layout_diff, *layout_grad; | |||||
| std::string to_string() const; | |||||
| void init_desc(TensorDesc& src_desc, TensorDesc& dst_desc, | |||||
| TensorDesc& diff_desc, TensorDesc& grad_desc) const { | |||||
| src_desc.set(*layout_src); | |||||
| dst_desc.set(*layout_dst); | |||||
| diff_desc.set(*layout_diff); | |||||
| grad_desc.set(*layout_grad); | |||||
| } | |||||
| SizeArgs(PoolingBackwardImpl* opr, const TensorLayout& src, | |||||
| const TensorLayout& dst, const TensorLayout& diff, | |||||
| const TensorLayout& grad); | |||||
| }; | |||||
| struct ExecArgs : public SizeArgs { | |||||
| const TensorND *src_tensor, *dst_tensor, *diff_tensor, *grad_tensor; | |||||
| Workspace workspace; | |||||
| ExecArgs(PoolingBackwardImpl* opr, _megdnn_tensor_in src, | |||||
| _megdnn_tensor_in dst, _megdnn_tensor_in diff, | |||||
| _megdnn_tensor_out grad, _megdnn_workspace workspace); | |||||
| }; | |||||
| virtual bool is_available(const SizeArgs& args) const = 0; | |||||
| virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; | |||||
| virtual void exec(const ExecArgs& args) const = 0; | |||||
| bool is_available_attribute( | |||||
| const SizeArgs& args, | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | |||||
| return contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr) && is_available(args); | |||||
| } | |||||
| protected: | |||||
| ~AlgoBase() = default; | |||||
| }; | |||||
| class PoolingBackwardImpl::AlgoMIOpen final : public AlgoBase { | |||||
| std::string m_algo_name; | |||||
| AlgoAttribute m_algo_attribute; | |||||
| public: | |||||
| AlgoMIOpen(AlgoAttribute attr) | |||||
| : m_algo_name("MIOpenPoolingBackward"), m_algo_attribute(attr) {} | |||||
| bool is_available(const SizeArgs& args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
| void init_mode(const ExecArgs& args, miopenPoolingMode_t& mode) const; | |||||
| void exec(const ExecArgs& args) const override; | |||||
| const char* name() const override { return m_algo_name.c_str(); } | |||||
| AlgoAttribute attribute() const override { | |||||
| return m_algo_attribute; | |||||
| } | |||||
| MEGDNN_DECL_ALGO_TYPE(ROCM_MIOPEN) | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_algo_attribute, ret); | |||||
| return ret; | |||||
| } | |||||
| }; | |||||
| class PoolingBackwardImpl::AlgoPack : NonCopyableObj { | |||||
| private: | |||||
| AlgoBase::Mapper m_all_algos_map; | |||||
| public: | |||||
| AlgoPack(); | |||||
| AlgoMIOpen algo_miopen{AlgoAttribute::REPRODUCIBLE}; | |||||
| std::vector<AlgoBase*> all_algos; | |||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||||
| }; | |||||
| } // namespace rocm | |||||
| } // namespace megdnn | |||||
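The header above follows MegDNN's algorithm-pack pattern: each operator owns a static AlgoPack whose constructor pushes every algorithm object into all_algos and indexes it by descriptor in m_all_algos_map, and heuristic selection walks the list picking the first algorithm that carries the requested attributes and reports itself available. The simplified sketch below illustrates only that registration-and-filtering idea; Algo, Pack, and pick are illustrative stand-ins, not MegDNN's real types.

    #include <cstdint>
    #include <stdexcept>
    #include <string>
    #include <unordered_map>
    #include <utility>
    #include <vector>

    // Simplified stand-ins for the AlgoPack / get_algorithm_heuristic pattern.
    enum : uint32_t { ATTR_DEFAULT = 0, ATTR_REPRODUCIBLE = 1u << 0 };

    struct Algo {
        std::string name;  // plays the role of the algorithm descriptor
        uint32_t attrs;    // plays the role of AlgoAttribute
        bool available;    // stands in for is_available(args)
    };

    struct Pack {
        std::vector<Algo*> all_algos;
        std::unordered_map<std::string, Algo*> by_desc;
        explicit Pack(std::vector<Algo*> algos) : all_algos(std::move(algos)) {
            for (auto* a : all_algos)
                by_desc.emplace(a->name, a);  // mirrors m_all_algos_map
        }
    };

    // First algorithm that has every positive attribute, none of the negative
    // ones, and is available: the same policy as get_algorithm_heuristic().
    inline Algo* pick(const Pack& pack, uint32_t positive, uint32_t negative) {
        for (auto* a : pack.all_algos) {
            if ((a->attrs & positive) == positive &&
                (a->attrs & negative) == 0 && a->available)
                return a;
        }
        throw std::runtime_error("no algorithm matches the requested attributes");
    }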
| @@ -10,18 +10,47 @@ | |||||
| */ | */ | ||||
| #include "hcc_detail/hcc_defs_prologue.h" | #include "hcc_detail/hcc_defs_prologue.h" | ||||
| #include "src/rocm/pooling/opr_impl.h" | #include "src/rocm/pooling/opr_impl.h" | ||||
| #include "src/rocm/utils.h" | #include "src/rocm/utils.h" | ||||
| #include "./algo.h" | |||||
| #include "src/common/algo_chooser.h" | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace rocm { | namespace rocm { | ||||
| void PoolingForwardImpl::setup_descs(const TensorLayout &src, | |||||
| const TensorLayout &dst) | |||||
| { | |||||
| src_desc.set(src, param().format); | |||||
| dst_desc.set(dst, param().format); | |||||
| pooling_desc.set(this->param()); | |||||
| size_t PoolingForwardImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
| const TensorLayout& dst) { | |||||
| AlgoBase::SizeArgs args(this, src, dst); | |||||
| return get_algorithm(this, src, dst)->get_workspace_in_bytes(args); | |||||
| } | |||||
| const char* PoolingForwardImpl::get_algorithm_set_name() const { | |||||
| return "ROCM_POOLING_FORWARD"; | |||||
| } | |||||
| std::vector<PoolingForwardImpl::Algorithm*> | |||||
| PoolingForwardImpl::get_all_algorithms(const TensorLayout& src, | |||||
| const TensorLayout& dst) { | |||||
| return megdnn::get_all_algorithms<PoolingForwardImpl>({this, src, dst}); | |||||
| } | |||||
| PoolingForwardImpl::Algorithm* PoolingForwardImpl::get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes); | |||||
| AlgoBase::SizeArgs args(this, src, dst); | |||||
| for (auto&& iter : sm_algo_pack.all_algos) { | |||||
| if (iter->is_available_attribute(args, positive_attr, negative_attr)) { | |||||
| return iter; | |||||
| } | |||||
| } | |||||
| megdnn_throw( | |||||
| ssprintf("require algorithm with attribute(%s) and without " | |||||
| "attribute(%s), but can't get suitable algo.\n", | |||||
| Algorithm::attribute_str(positive_attr).c_str(), | |||||
| Algorithm::attribute_str(negative_attr).c_str())); | |||||
| return nullptr; | |||||
| } | } | ||||
| void PoolingForwardImpl::exec(_megdnn_tensor_in src, | void PoolingForwardImpl::exec(_megdnn_tensor_in src, | ||||
| @@ -29,24 +58,52 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, | |||||
| _megdnn_workspace workspace) | _megdnn_workspace workspace) | ||||
| { | { | ||||
| check_exec(src.layout, dst.layout, workspace.size); | check_exec(src.layout, dst.layout, workspace.size); | ||||
| auto handle = miopen_handle(this->handle()); | |||||
| setup_descs(src.layout, dst.layout); | |||||
| dt_float32 alpha = 1.0f, beta = 0.0f; | |||||
| miopen_check(miopenPoolingForward(handle, pooling_desc.desc, &alpha, | |||||
| src_desc.desc, src.raw_ptr, &beta, | |||||
| dst_desc.desc, dst.raw_ptr, false, | |||||
| nullptr, 0_z)); | |||||
| { | |||||
| AlgoBase::ExecArgs args(this, src, dst, workspace); | |||||
| auto algo = get_algorithm(this, src.layout, dst.layout); | |||||
| algo->exec(args); | |||||
| } | |||||
| } | } | ||||
| void PoolingBackwardImpl::setup_descs(const TensorLayout& src, | |||||
| const TensorLayout& dst, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad) { | |||||
| src_desc.set(src); | |||||
| dst_desc.set(dst); | |||||
| diff_desc.set(diff); | |||||
| grad_desc.set(grad); | |||||
| pooling_desc.set(this->param()); | |||||
| size_t PoolingBackwardImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
| const TensorLayout& dst, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad) { | |||||
| AlgoBase::SizeArgs args(this, src, dst, diff, grad); | |||||
| return get_algorithm(this, src, dst, diff, grad) | |||||
| ->get_workspace_in_bytes(args); | |||||
| }; | |||||
| const char* PoolingBackwardImpl::get_algorithm_set_name() const { | |||||
| return "ROCM_POOLING_BACKWARD"; | |||||
| } | |||||
| std::vector<Algorithm*> PoolingBackwardImpl::get_all_algorithms( | |||||
| const TensorLayout& src, const TensorLayout& dst, | |||||
| const TensorLayout& diff, const TensorLayout& grad) { | |||||
| return megdnn::get_all_algorithms<PoolingBackwardImpl>( | |||||
| {this, src, dst, diff, grad}); | |||||
| } | |||||
| Algorithm* PoolingBackwardImpl::get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& dst, | |||||
| const TensorLayout& diff, const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes); | |||||
| AlgoBase::SizeArgs args(this, src, dst, diff, grad); | |||||
| for (auto iter : sm_algo_pack.all_algos) { | |||||
| if (iter->is_available_attribute(args, positive_attr, negative_attr)) { | |||||
| return iter; | |||||
| } | |||||
| } | |||||
| megdnn_throw( | |||||
| ssprintf("require algorithm with attribute(%s) and without " | |||||
| "attribute(%s), but can't get suitable algo.\n", | |||||
| Algorithm::attribute_str(positive_attr).c_str(), | |||||
| Algorithm::attribute_str(negative_attr).c_str())); | |||||
| return nullptr; | |||||
| } | } | ||||
| void PoolingBackwardImpl::exec(_megdnn_tensor_in src, | void PoolingBackwardImpl::exec(_megdnn_tensor_in src, | ||||
| @@ -55,35 +112,16 @@ void PoolingBackwardImpl::exec(_megdnn_tensor_in src, | |||||
| _megdnn_tensor_out grad, | _megdnn_tensor_out grad, | ||||
| _megdnn_workspace workspace) | _megdnn_workspace workspace) | ||||
| { | { | ||||
| check_exec(src.layout, dst.layout, diff.layout, grad.layout, workspace.size); | |||||
| auto handle = miopen_handle(this->handle()); | |||||
| setup_descs(src.layout, dst.layout, diff.layout, grad.layout); | |||||
| float alpha = 1.0f, beta = 0.0f; | |||||
| if (param().mode == param::Pooling::Mode::MAX) { | |||||
| //! FIXME: when using max pooling opr, the backward opr need the indices | |||||
| //! of the forward opr which stored in workspace. We have to recompute | |||||
| //! the indices by calling miopenPoolingForward again. | |||||
| miopen_check(miopenPoolingForward(handle, pooling_desc.desc, &alpha, | |||||
| src_desc.desc, src.raw_ptr, &beta, | |||||
| dst_desc.desc, dst.raw_ptr, true, | |||||
| workspace.raw_ptr, workspace.size)); | |||||
| check_exec(src.layout, dst.layout, diff.layout, grad.layout, | |||||
| workspace.size); | |||||
| { | |||||
| AlgoBase::ExecArgs args(this, src, dst, diff, grad, workspace); | |||||
| auto algo = get_algorithm(this, src.layout, dst.layout, diff.layout, | |||||
| grad.layout); | |||||
| algo->exec(args); | |||||
| } | } | ||||
| miopen_check(miopenPoolingBackward( | |||||
| handle, pooling_desc.desc, &alpha, dst_desc.desc, dst.raw_ptr, | |||||
| diff_desc.desc, diff.raw_ptr, src_desc.desc, src.raw_ptr, &beta, | |||||
| grad_desc.desc, grad.raw_ptr, workspace.raw_ptr)); | |||||
| } | } | ||||
| size_t PoolingBackwardImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
| const TensorLayout& dst, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad) { | |||||
| setup_descs(src, dst, diff, grad); | |||||
| size_t ws_size = 0_z; | |||||
| miopenPoolingGetWorkSpaceSize(dst_desc.desc, &ws_size); | |||||
| return ws_size; | |||||
| }; | |||||
| } // namespace rocm | } // namespace rocm | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
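With this rewrite, get_workspace_in_bytes no longer sizes descriptors cached on the operator; it forwards the query to the algorithm selected for the given layouts (for backward, MIOpen's index buffer sized from the dst descriptor), and exec rebuilds ExecArgs, re-selects the algorithm, and dispatches to it. The hypothetical caller below illustrates the resulting contract, namely sizing the workspace with exactly the layouts exec will see; the MegDNN types are assumed from its public headers and alloc_device is a placeholder device allocator.

    #include "megdnn/oprs.h"

    // Hypothetical caller-side sketch of the workspace contract implied above.
    void run_pooling_backward(megdnn::PoolingBackward* opr,
                              const megdnn::TensorND& src,
                              const megdnn::TensorND& dst,
                              const megdnn::TensorND& diff,
                              const megdnn::TensorND& grad,
                              void* (*alloc_device)(size_t)) {
        // Forwarded to the selected algorithm, which asks MIOpen how many bytes
        // the max-pooling index buffer needs for this dst layout.
        size_t ws_bytes = opr->get_workspace_in_bytes(src.layout, dst.layout,
                                                      diff.layout, grad.layout);
        megdnn::Workspace ws{
                static_cast<megdnn::dt_byte*>(alloc_device(ws_bytes)), ws_bytes};
        // exec() re-checks the layouts, picks the algorithm again, and runs it.
        opr->exec(src, dst, diff, grad, ws);
    }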
| @@ -22,13 +22,37 @@ class PoolingForwardImpl final: public PoolingForward { | |||||
| void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | ||||
| _megdnn_workspace workspace) override; | _megdnn_workspace workspace) override; | ||||
| size_t get_workspace_in_bytes(const TensorLayout &, | size_t get_workspace_in_bytes(const TensorLayout &, | ||||
| const TensorLayout &) override { | |||||
| return 0; | |||||
| const TensorLayout &) override; | |||||
| const char* get_algorithm_set_name() const override; | |||||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||||
| AlgorithmInfo get_algorithm_info_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| return get_algorithm_heuristic(src, dst, workspace_limit_in_bytes, | |||||
| positive_attr, negative_attr) | |||||
| ->info(); | |||||
| } | } | ||||
| class AlgoBase; | |||||
| class AlgoMIOpen; | |||||
| class AlgoPack; | |||||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||||
| protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& src, const TensorLayout& dst) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| private: | private: | ||||
| TensorDesc src_desc, dst_desc; | |||||
| PoolingDesc pooling_desc; | |||||
| void setup_descs(const TensorLayout &src, const TensorLayout &dst); | |||||
| static AlgoPack sm_algo_pack; | |||||
| }; | }; | ||||
| class PoolingBackwardImpl final: public PoolingBackward { | class PoolingBackwardImpl final: public PoolingBackward { | ||||
| @@ -43,14 +67,41 @@ class PoolingBackwardImpl final: public PoolingBackward { | |||||
| const TensorLayout& dst, | const TensorLayout& dst, | ||||
| const TensorLayout& diff, | const TensorLayout& diff, | ||||
| const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
| private: | |||||
| TensorDesc src_desc, dst_desc, diff_desc, grad_desc; | |||||
| PoolingDesc pooling_desc; | |||||
| void setup_descs(const TensorLayout &src, | |||||
| const TensorLayout &dst, | |||||
| const TensorLayout &diff, | |||||
| const TensorLayout &grad); | |||||
| const char* get_algorithm_set_name() const override; | |||||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||||
| AlgorithmInfo get_algorithm_info_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& dst, | |||||
| const TensorLayout& diff, const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| return get_algorithm_heuristic(src, dst, diff, grad, | |||||
| workspace_limit_in_bytes, | |||||
| positive_attr, negative_attr) | |||||
| ->info(); | |||||
| } | |||||
| class AlgoBase; | |||||
| class AlgoMIOpen; | |||||
| class AlgoPack; | |||||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||||
| protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& src, const TensorLayout& dst, | |||||
| const TensorLayout& diff, const TensorLayout& grad) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& dst, | |||||
| const TensorLayout& diff, const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| private: | |||||
| static AlgoPack sm_algo_pack; | |||||
| }; | }; | ||||
| } // namespace rocm | } // namespace rocm | ||||