 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/graph/symbol_var.h" |
|
|
|
#include "megbrain/imperative/ops/autogen.h" |
|
|
|
#include "megbrain/imperative/proxy_graph_detail.h" |
|
|
|
#include "megbrain/opr/basic_arith.h" |
|
|
|
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" |
|
|
|
#include "megbrain/opr/io.h" |
|
|
|
#include "megbrain/opr/tensor_manip.h" |
|
|
|
#include "megdnn/dtype.h" |
|
|
|
|
|
|
|
#include "../blob_manager_impl.h" |
|
|
|
#include "../dnn_op_helper.h" |
|
|
|
#include "../op_trait.h" |
|
|
|
|
|
|
|
namespace {
namespace reduce {
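// apply_on_var_node: graph-mode lowering of the imperative Reduce OpDef.
// The notable case is keepdim=false, which the graph operator cannot
// express directly and is emulated with AxisAddRemove below.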
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& reduce = static_cast<const Reduce&>(def);
    auto comp_node = inputs[0]->comp_node();
    OperatorNodeConfig config{reduce.make_name(), comp_node, inputs[0]->dtype()};

    if (inputs.size() > 1) {
        return opr::Reduce::make(inputs[0], reduce.param(), inputs[1], config);
    }

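    // Single-input case: resolve negative axes and the reduce-to-scalar
    // convention up front, then build the graph operator by hand.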
    using Param = megdnn::param::Reduce;
    auto param = reduce.param();
    if (param.axis < 0) {
        param.axis = inputs[0]->shape().ndim + param.axis;
    }

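    // axis == INT_MAX means "reduce over all elements"; the operator then
    // reduces toward an explicit scalar target shape {1}.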
    SymbolVar target_shape = (cg::VarNode*)nullptr;
    if (param.axis == INT_MAX) {
        DTypeScalar vi{1};
        auto graph = inputs[0]->owner_graph();
        target_shape = opr::ImmutableTensor::make(*graph, vi, config);
    }
    auto res = opr::Reduce::make(inputs[0], param, target_shape, config);
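    // keepdim=false: the graph-level Reduce keeps the reduced axis, so
    // strip that axis from the result with an AxisAddRemove operator.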
    if (!reduce.keepdim && param.axis != INT_MAX) {
        using Desc = opr::AxisAddRemove::AxisDesc;
        std::vector<Desc> remove_param;
        remove_param.push_back(Desc::make_remove(param.axis));
        OperatorNodeConfig remove_config{
                def.make_name(), comp_node, inputs[0]->dtype()};
        return opr::AxisAddRemove::make(res, remove_param, remove_config);
    }
    return res;
}
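
// For example (hypothetical shapes): summing a (2, 3, 4) input along
// axis=1 yields (2, 1, 4) with keepdim=true; with keepdim=false the
// AxisAddRemove node above squeezes the result down to (2, 4).
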
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Reduce>();
    // Graph-level Reduce keeps the reduced axis, hence keepdim = true here.
    return Reduce::make(node->param(), true);
}

// TODO: using this for apply_on_physical_tensor
// NOTE: declaration reconstructed from the call site below (assumed; the
// definition is elided from this excerpt).
bool memory_forward_success(const OpDef& def, SmallVector<TensorPtr> inputs);

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    if (memory_forward_success(def, inputs)) {
        // Fast path: forward the input blob without copying.
        return {Tensor::make(
                inputs[0]->blob(), inputs[0]->offset(), inputs[0]->layout())};
    }

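    // An explicit target-shape input still goes through the proxy graph;
    // the single-input case below is executed directly via megdnn.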
    auto size = inputs.size();
    if (size > 1) {
        return proxy_graph_detail::apply_on_physical_tensor(
                def, inputs, output_descs, validated);
    }

    auto comp_node = inputs[0]->comp_node();
    using TensorND = megdnn::TensorND;
    auto&& op_def = def.cast_final_safe<Reduce>();
    SmallVector<TensorND> inp_tensornds;
    inp_tensornds.reserve(inputs.size());
    auto src = inputs[0]->layout();

    DnnOprCaller<megdnn::Reduce> dnn_op(comp_node);
    dnn_op.op->param() = op_def.param();
    auto axis = op_def.param().axis;
    auto keepdim = op_def.keepdim;

    if (axis < 0) {
        axis = inputs[0]->layout().ndim + axis;
    }

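    // For reduce-to-scalar (axis == INT_MAX), run the kernel along axis 0
    // of a flattened 1-D view of the input.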
    dnn_op.op->param().axis = axis == INT_MAX ? 0 : axis;

    if (axis == INT_MAX) {
        src.shape[0] = src.total_nr_elems();
        src.ndim = 1;
        src.init_contiguous_stride();
    }
    TensorLayout layout{src.dtype};
    dnn_op.op->deduce_layout(src, layout);

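    // Empty input: a megdnn kernel cannot run on zero-sized tensors, so
    // fill the output with the reduction's identity element instead, and
    // reject modes that have no identity.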
    if (inputs[0]->layout().is_empty()) {
        inputs[0]->dev_tensor().reset(inputs[0]->dev_tensor().storage(), src);

        auto mode = op_def.param().mode;
        DnnOprCaller<megdnn::Fill> fill_op(comp_node);

        if (!keepdim && src.ndim > 1) {
            layout.remove_axis_inplace(axis);
            layout.init_contiguous_stride();
        }
        DeviceTensorND out =
                BlobManager::inst()->alloc_workspace_with_defrag(comp_node, layout);
        std::string err_msg;
        switch (mode) {
            case Reduce::Mode::SUM:
                if (!out.empty()) {
                    fill_op.op->param() = 0;
                    fill_op.op->exec(out.as_megdnn(), {});
                }
                break;
            case Reduce::Mode::PRODUCT:
                if (!out.empty()) {
                    fill_op.op->param() = 1;
                    fill_op.op->exec(out.as_megdnn(), {});
                }
                break;
            case Reduce::Mode::MEAN:
                err_msg = "mean";
                break;
            case Reduce::Mode::MIN:
                err_msg = "min";
                break;
            case Reduce::Mode::MAX:
                err_msg = "max";
                break;
            case Reduce::Mode::SUM_SQR:
                err_msg = "sum_sqr";
                break;
            default:
                mgb_throw(MegBrainError, "bad reduce mode");
        }
        if (!err_msg.empty()) {
            mgb_throw(
                    MegBrainError, "empty input is not allowed for reduce mode: %s",
                    err_msg.c_str());
        }
        return {Tensor::make(out)};
    }

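    // Non-empty input: execute the megdnn reduce kernel directly, with
    // workspace allocated on demand.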
    auto dnn_ten = inputs[0]->dnn_tensor();
    dnn_ten.layout = src;
    inp_tensornds.push_back(dnn_ten);

    megdnn::Workspace dnn_wk;

    auto wk_size = dnn_op.op->get_workspace_in_bytes(src, layout);
    if (wk_size != 0) {
        auto wk = Blob::make(comp_node, wk_size);
        dnn_wk.raw_ptr = wk->storage().get();
        dnn_wk.size = wk_size;
    }

    DeviceTensorND out =
            BlobManager::inst()->alloc_workspace_with_defrag(comp_node, layout);

    dnn_op.op->exec(inp_tensornds[0], out.as_megdnn(), dnn_wk);

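    // keepdim=false: the kernel output keeps the reduced axis as 1, so
    // squeeze it out of the result layout.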
    if (!keepdim && src.ndim > 1) {
        auto out_layout = out.layout();
        out_layout.remove_axis_inplace(axis);
        out_layout.init_contiguous_stride();
        out.resize(out_layout);
    }

    return {Tensor::make(out)};
}
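
// Shape inference without execution; mirrors the axis/keepdim handling in
// apply_on_physical_tensor above.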
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<Reduce>();
    auto axis = op_def.param().axis;
    auto keepdim = op_def.keepdim;

    size_t size = inputs.size();
    SmallVector<LogicalTensorDesc> dests(size);

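    // With a target-shape input, defer to the proxy graph, refining the
    // layout when the shape value is already known.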
    if (size > 1) {
        auto [output_descs, validated] =
                proxy_graph_detail::infer_output_attrs_fallible(def, inputs);
        if (!inputs[1].value.empty()) {
            cg::copy_tensor_value_to_shape(output_descs[0].layout, inputs[1].value);
            output_descs[0].layout.init_contiguous_stride();
        }
        return {output_descs, validated};
    }

    if (axis < 0) {
        axis = inputs[0].layout.ndim + axis;
    }

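    // Reduce-to-scalar (or a 1-D input) infers shape {1}; otherwise the
    // reduced axis is dropped or kept as 1 according to keepdim, e.g.
    // (2, 3, 4) along axis=1 becomes (2, 4) or (2, 1, 4).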
    if (axis == INT_MAX || inputs[0].layout.ndim == 1) {
        TensorLayout layout{inputs[0].layout.dtype};
        layout.shape[0] = 1;
        layout.ndim = 1;
        dests[0].layout = layout;
        dests[0].comp_node = inputs[0].comp_node;
    } else {
        for (size_t i = 0; i < size; ++i) {
            dests[i].comp_node = inputs[i].comp_node;
            dests[i].layout = inputs[i].layout;
            if (not keepdim && dests[i].layout.ndim > 1) {
                dests[i].layout.remove_axis_inplace(axis);
            } else {
                dests[i].layout.shape[axis] = 1;
            }
            dests[i].layout.init_contiguous_stride();
        }
    }
    return {dests, true};
}

SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(