Browse Source

perf(imperative/src): improve dot performance

GitOrigin-RevId: 35b5bd164f
tags/v1.9.0
Megvii Engine Team 3 years ago
parent
commit
2b80806f21
4 changed files with 69 additions and 4 deletions
  1. +2
    -1
      dnn/include/megdnn/oprs/linalg.h
  2. +1
    -1
      dnn/src/common/dot.cpp
  3. +65
    -1
      imperative/src/impl/ops/specializations.cpp
  4. +1
    -1
      src/opr/include/megbrain/opr/blas.h

+ 2
- 1
dnn/include/megdnn/oprs/linalg.h View File

@@ -150,7 +150,8 @@ public:
virtual void exec(
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
MGE_WIN_DECLSPEC_FUC void deduce_layout(
const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
virtual size_t get_workspace_in_bytes(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;



+ 1
- 1
dnn/src/common/dot.cpp View File

@@ -33,7 +33,7 @@ void DotForward::check_exec(
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

void DotForward::deduce_layout(
// Dot product always reduces to a single scalar element, so the output
// layout is a 1-element tensor carrying A's dtype. The second operand's
// layout is intentionally unnamed: it never influences the result layout.
// MGE_WIN_DECLSPEC_FUC exports the symbol on Windows so callers outside
// the DLL (e.g. imperative eager execution) can link against it.
MGE_WIN_DECLSPEC_FUC void DotForward::deduce_layout(
        const TensorLayout& A, const TensorLayout&, TensorLayout& C) {
    C = TensorLayout(TensorShape{1}, A.dtype);
}


+ 65
- 1
imperative/src/impl/ops/specializations.cpp View File

@@ -39,6 +39,7 @@
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"

#include "../blob_manager_impl.h"
#include "../op_trait.h"

namespace mgb::imperative {
@@ -319,7 +320,70 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
OperatorNodeConfig config{op.make_name()};
return opr::Dot::make(inputs[0], inputs[1], config);
}
OP_TRAIT_REG(Dot, Dot).apply_on_var_node(apply_on_var_node).fallback();

// std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
// auto* node = &node_->cast_final_safe<opr::Dot>();
// return Dot::make(node->param());
// }

// Eagerly execute Dot on concrete device tensors.
//
// Steps: deduce the (scalar) output layout from the input layouts, special-case
// empty inputs by returning a zero-filled scalar, otherwise allocate output +
// workspace and run the megdnn Dot kernel.
//
// \param def       op definition (unused: Dot has no parameters).
// \param inputs    exactly two input tensors on the same comp node.
// \param output_descs  inferred output descriptors; deliberately NOT trusted
//                  here (see workspace query below).
// \param validated whether output_descs were validated (unused for the same
//                  reason).
// \return a single-element vector holding the scalar dot-product tensor.
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto comp_node = inputs[0]->comp_node();
    using TensorND = megdnn::TensorND;
    SmallVector<TensorND> inp_tensornds;
    inp_tensornds.reserve(inputs.size());
    auto dnn_opr = opr::intl::create_megdnn_opr<megdnn::Dot>(comp_node);
    for (size_t i = 0; i < inputs.size(); ++i) {
        inp_tensornds.push_back(inputs[i]->dnn_tensor());
    }

    // Deduce the output layout from the dnn tensors we already fetched above
    // (the original code called inputs[i]->dnn_tensor() a second time here).
    TensorLayout oup_layout{inputs[0]->dtype()};
    dnn_opr->deduce_layout(inp_tensornds[0].layout, inp_tensornds[1].layout, oup_layout);

    // Dot of an empty tensor is defined as 0: allocate the scalar output and
    // fill it without invoking the kernel (the kernel may not accept empty
    // inputs).
    if (inputs[0]->layout().is_empty() || inputs[1]->layout().is_empty()) {
        auto fill_opr = opr::intl::create_megdnn_opr<megdnn::Fill>(comp_node);
        DeviceTensorND out =
                BlobManager::inst()->alloc_workspace_with_defrag(comp_node, oup_layout);
        fill_opr->param() = 0;
        fill_opr->exec(out.as_megdnn(), {});
        return {Tensor::make(out)};
    }

    // BUGFIX: query the workspace size with the freshly deduced oup_layout.
    // The original used output_descs[0].layout, which may be a placeholder
    // when `validated` is false, yielding a workspace sized for the wrong
    // output layout.
    auto wk_size = dnn_opr->get_workspace_in_bytes(
            inp_tensornds[0].layout, inp_tensornds[1].layout, oup_layout);

    DeviceTensorND out_devtensor =
            BlobManager::inst()->alloc_workspace_with_defrag(comp_node, oup_layout);
    // Workspace is raw bytes; the dtype on wk_layout only conveys element size
    // to the allocator.
    TensorLayout wk_layout{TensorShape{wk_size}, inputs[0]->dtype()};
    DeviceTensorND workspace =
            BlobManager::inst()->alloc_workspace_with_defrag(comp_node, wk_layout);
    megdnn::Workspace dnn_wk(workspace.raw_ptr(), wk_size);

    dnn_opr->exec(
            inp_tensornds[0], inp_tensornds[1], out_devtensor.as_megdnn(), dnn_wk);

    return {Tensor::make(out_devtensor)};
}

// Infer Dot's output descriptor without running the kernel.
//
// Dot always produces a 1-element tensor of the first input's dtype on the
// first input's comp node, regardless of input shapes, so inference never
// fails — hence the unconditional `true` in the returned tuple.
//
// \param def    op definition (unused: the output layout is fixed).
// \param inputs logical descriptors of the two operands.
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    // NOTE: the original declared `auto&& op_def = def.cast_final_safe<Dot>();`
    // and never used it, triggering -Wunused-variable; removed.
    SmallVector<LogicalTensorDesc> dests(1);
    dests[0].layout = TensorLayout(TensorShape{1}, inputs[0].layout.dtype);
    dests[0].comp_node = inputs[0].comp_node;
    return {dests, true};
}

OP_TRAIT_REG(Dot, Dot, opr::Dot)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();

} // namespace dot
} // namespace



+ 1
- 1
src/opr/include/megbrain/opr/blas.h View File

@@ -88,7 +88,7 @@ private:
/*!
* \brief dot product of two tensors
*/
MGB_DEFINE_OPR_CLASS(
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
Dot, cg::SingleCNOperatorNodeBaseT<mixin::MegDNNOprHolderImpl<megdnn::Dot>>) // {
public:
MGE_WIN_DECLSPEC_FUC Dot(


Loading…
Cancel
Save