@@ -12,6 +12,8 @@
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/graph/helper.h"

#include "../op_trait.h"

namespace mgb {
@@ -83,10 +85,46 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
    return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
        const OpDef& def,
        const SmallVector<TensorPtr>& inputs_tensors,
        const SmallVector<MemoryDesc>& inputs_mems) {
    auto& input = inputs_tensors[0];
    // The second input carries the target shape; read its value back on the
    // default CPU comp node and convert it into a TensorShape.
    TensorShape target_shape;
    cg::copy_tensor_value_to_shape(
            target_shape, inputs_tensors[1]->get_value().proxy_to_default_cpu());
    // TODO: memory forward
    // if (input->shape().eq_shape(target_shape)) {
    //     return {{{input->layout(), 0, input->comp_node(),
    //               StorageIdentifier::make(&inputs_mems[0])}}, {}};
    // }
    return {{{{target_shape, input->dtype()}, 0, input->comp_node(),
              StorageIdentifier::make(0)}}, {}};
}
void execute(
        const OpDef& def,
        SmallVector<TensorPtr> inputs,
        SmallVector<TensorPtr> outputs,
        SmallVector<TensorPtr> workspace) {
    if (outputs[0]->layout().is_empty()) {
        return;
    }
    if (inputs[0]->shape().eq_shape(outputs[0]->shape())) {
        // Shapes already match: a plain copy suffices.
        mgb_assert(inputs[0]->layout().eq_layout(outputs[0]->layout()));
        // TODO: memory forward
        // mgb_assert(inputs[0]->offset() == outputs[0]->offset());
        // mgb_assert(inputs[0]->blob() == outputs[0]->blob());
        outputs[0]->dev_tensor().copy_from_fixlayout(inputs[0]->dev_tensor());
    } else {
        // Broadcast by copying from a zero-strided view of the input that
        // already has the output shape.
        TensorLayout input_layout = inputs[0]->layout().broadcast(outputs[0]->shape());
        outputs[0]->dev_tensor().copy_from_fixlayout(
                inputs[0]->dev_tensor().sub(SubTensorSpec::make_from_layout(input_layout)));
    }
}
OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .infer_output_mem_desc(infer_output_mem_desc)
        .execute(execute)
        .fallback();
} // namespace broadcast
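
As a rough standalone illustration (not MegEngine code; `broadcast_rows` and its fixed strides are hypothetical) of the zero-stride trick that `layout().broadcast()` plus `copy_from_fixlayout` rely on in the else-branch above:

// Sketch only: expand a [1, N] source to [M, N] by giving the broadcast axis
// a stride of 0, so every destination row reads the same source data.
#include <cstddef>
#include <vector>

std::vector<float> broadcast_rows(const std::vector<float>& src, size_t M, size_t N) {
    const size_t stride0 = 0;  // broadcast axis: do not advance in src
    const size_t stride1 = 1;  // contiguous inner axis
    std::vector<float> dst(M * N);
    for (size_t i = 0; i < M; ++i)
        for (size_t j = 0; j < N; ++j)
            dst[i * N + j] = src[i * stride0 + j * stride1];
    return dst;
}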
@@ -0,0 +1,47 @@
/**
 * \file imperative/src/impl/ops/reduce.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/opr/basic_arith.h" | |||
| #include "../op_trait.h" | |||
| #include "../dnn_op_helper.h" | |||
| namespace mgb { | |||
| namespace imperative { | |||
| namespace { | |||
| namespace reduce { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& reduce = static_cast<const Reduce&>(def);
    OperatorNodeConfig config{reduce.make_name()};
    if (inputs.size() > 1) {
        // The optional second input supplies the target shape of the reduction.
        return opr::Reduce::make(inputs[0], reduce.param(), inputs[1], config);
    } else {
        return opr::Reduce::make(inputs[0], reduce.param(),
                                 (cg::VarNode*)nullptr, config);
    }
}

std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Reduce>();
    return Reduce::make(node->param());
}
OP_TRAIT_REG(Reduce, Reduce, opr::Reduce)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
} // namespace reduce
} // namespace

} // namespace imperative
} // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
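
For orientation, a minimal standalone sketch (not MegEngine API; `reduce_sum_axis0` is an illustrative name) of the simplest computation a SUM-mode Reduce performs; the trait above only builds the graph operator and leaves the remaining handlers to `.fallback()`:

// Sketch only: sum an [M, N] tensor over axis 0, yielding an [N] result.
#include <cstddef>
#include <vector>

std::vector<float> reduce_sum_axis0(const std::vector<float>& src, size_t M, size_t N) {
    std::vector<float> dst(N, 0.f);
    for (size_t i = 0; i < M; ++i)
        for (size_t j = 0; j < N; ++j)
            dst[j] += src[i * N + j];
    return dst;
}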
@@ -116,31 +116,6 @@ OP_TRAIT_REG(TopK, TopK).apply_on_var_node(apply_on_var_node).fallback();
} // namespace top_k
} // namespace

namespace {
namespace reduce {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& reduce = static_cast<const Reduce&>(def);
    OperatorNodeConfig config{reduce.make_name()};
    if (inputs.size() > 1) {
        return opr::Reduce::make(inputs[0], reduce.param(), inputs[1], config);
    } else {
        return opr::Reduce::make(inputs[0], reduce.param(),
                                 (cg::VarNode*)nullptr, config);
    }
}

std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Reduce>();
    return Reduce::make(node->param());
}

OP_TRAIT_REG(Reduce, Reduce, opr::Reduce)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
} // namespace reduce
} // namespace
namespace {
namespace adaptive_pooling {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {