| @@ -413,7 +413,7 @@ TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const { | |||
| megdnn_throw_if( | |||
| cur_shape != 1 && cur_stride != 0, tensor_reshape_error, | |||
| megdnn_mangle(ssprintf( | |||
| "brodcast on dim with shape not equal to 1: " | |||
| "broadcast on dim with shape not equal to 1: " | |||
| "src_shape=%s dst_shape=%s", | |||
| to_string().c_str(), tshape.to_string().c_str()))); | |||
| result.shape[target_idx] = tshape.shape[target_idx]; | |||
| @@ -47,7 +47,9 @@ def _(op: OpDef, inputs, outputs, input_requires_grad): | |||
| grad_fn = reduce_sum_grad_fn | |||
| else: | |||
| grad_fn = default_grad_fn | |||
| elif isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD: | |||
| elif isinstance(op, Broadcast) or ( | |||
| isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD | |||
| ): | |||
| grad_fn = elemwise_add_grad_fn | |||
| else: | |||
| grad_fn = default_grad_fn | |||
| @@ -212,5 +214,4 @@ _oprAttr_grad_fn = { | |||
| Reshape.name: reshape_grad_fn, | |||
| Subtensor.name: subtensor_grad_fn, | |||
| IndexingMultiAxisVec.name: indexingMultiAxisVec_grad_fn, | |||
| Broadcast.name: elemwise_add_grad_fn, | |||
| } | |||
| @@ -59,29 +59,7 @@ def _transpose(data, axes): | |||
| def _broadcast(inp, shape): | |||
| def valid_broadcast(src, tar): | |||
| def failed(): | |||
| raise ValueError( | |||
| "the input shape {} can not be broadcasted to target shape {}".format( | |||
| src, tar | |||
| ) | |||
| ) | |||
| if isinstance(src, (TensorBase, TensorWrapperBase)): | |||
| src = src.numpy() | |||
| if isinstance(tar, (TensorBase, TensorWrapperBase)): | |||
| tar = tar.numpy() | |||
| if len(src) > len(tar): | |||
| failed() | |||
| for i in range(min(len(src), len(tar))): | |||
| if src[-i - 1] != 1 and src[-i - 1] != tar[-i - 1]: | |||
| failed() | |||
| shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device) | |||
| valid_broadcast(inp.shape, shape) | |||
| (result,) = apply(builtin.Broadcast(), inp, shape) | |||
| return result | |||
| @@ -21,6 +21,7 @@ | |||
| #include "megbrain/imperative/ops/nms.h" | |||
| #include "megbrain/imperative/ops/elemwise.h" | |||
| #include "megbrain/imperative/ops/batch_norm.h" | |||
| #include "megbrain/imperative/ops/broadcast.h" | |||
| namespace py = pybind11; | |||
| @@ -206,4 +207,7 @@ void init_ops(py::module m) { | |||
| V(INFERENCE); | |||
| #undef V | |||
| py::class_<Broadcast, std::shared_ptr<Broadcast>, OpDef>(m, "Broadcast") | |||
| .def(py::init<>()); | |||
| } | |||
| @@ -262,13 +262,13 @@ def test_broadcast(): | |||
| opr_test(cases, F.broadcast_to, compare_fn=compare_fn) | |||
| x = F.ones((2, 1, 3)) | |||
| with pytest.raises(ValueError): | |||
| with pytest.raises(RuntimeError): | |||
| F.broadcast_to(x, (2, 3, 4)) | |||
| with pytest.raises(ValueError): | |||
| with pytest.raises(RuntimeError): | |||
| F.broadcast_to(x, (4, 1, 3)) | |||
| with pytest.raises(ValueError): | |||
| with pytest.raises(RuntimeError): | |||
| F.broadcast_to(x, (1, 3)) | |||
| @@ -0,0 +1,95 @@ | |||
| /** | |||
| * \file imperative/src/impl/ops/broadcast.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "megbrain/imperative/ops/broadcast.h" | |||
| #include "../op_trait.h" | |||
| namespace mgb { | |||
| namespace imperative { | |||
| namespace { | |||
| std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||
| node_->cast_final_safe<opr::Broadcast>(); | |||
| return Broadcast::make(); | |||
| } | |||
| cg::OperatorNodeBase* apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| def.cast_final_safe<Broadcast>(); | |||
| size_t nr_inp = inputs.size(); | |||
| mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp); | |||
| return opr::Broadcast::make(inputs[0], inputs[1]).node()->owner_opr(); | |||
| } | |||
| bool valid_broadcast(const TensorShape& src_shape, | |||
| const TensorShape& tar_shape) { | |||
| size_t src_ndim = src_shape.ndim, tar_ndim = tar_shape.ndim; | |||
| if (src_ndim > tar_ndim) { | |||
| return false; | |||
| } | |||
| size_t min_ndim = src_ndim < tar_ndim ? src_ndim : tar_ndim; | |||
| for (size_t i = 0; i < min_ndim; ++i) { | |||
| if (src_shape[src_ndim - i - 1] != 1 && | |||
| src_shape[src_ndim - i - 1] != tar_shape[tar_ndim - i - 1]) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| SmallVector<LogicalTensorDesc> infer_output_attrs_fallible( | |||
| const OpDef& def, | |||
| const SmallVector<LogicalTensorDesc>& inputs) { | |||
| def.cast_final_safe<Broadcast>(); | |||
| size_t nr_inp = inputs.size(); | |||
| mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp); | |||
| auto&& src = inputs[0]; | |||
| auto&& tshp = inputs[1]; | |||
| TensorLayout out_layout = src.layout; | |||
| if (tshp.layout.ndim == 0 || tshp.value.empty()) { | |||
| out_layout.ndim = 0; | |||
| return {{out_layout, src.comp_node}}; | |||
| } | |||
| mgb_assert( | |||
| tshp.layout.ndim == 1, | |||
| "target shape of Broadcast expects ndim=1; got ndim=%lu actually", | |||
| tshp.layout.ndim); | |||
| size_t target_ndim = tshp.layout.shape[0]; | |||
| out_layout.ndim = target_ndim; | |||
| auto* ptr = tshp.value.ptr<dt_int32>(); | |||
| for(size_t i=0; i<target_ndim; ++i) { | |||
| out_layout.shape[i] = ptr[i]; | |||
| } | |||
| mgb_assert(valid_broadcast(src.layout, out_layout), | |||
| "the input shape %s can not be broadcasted to target shape %s", | |||
| src.layout.TensorShape::to_string().c_str(), | |||
| out_layout.TensorShape::to_string().c_str()); | |||
| return {{out_layout, src.comp_node}}; | |||
| } | |||
| OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast) | |||
| .make_from_op_node(make_from_op_node) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .infer_output_attrs_fallible(infer_output_attrs_fallible) | |||
| .fallback(); | |||
| } // anonymous namespace | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(Broadcast); | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -0,0 +1,35 @@ | |||
| /** | |||
| * \file imperative/src/include/megbrain/imperative/ops/broadcast.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "megbrain/opr/tensor_manip.h" | |||
| #include "megbrain/imperative/ops/opr_attr.h" | |||
| #include "megbrain/imperative/op_def.h" | |||
| namespace mgb::imperative { | |||
| class Broadcast : public OpDefImplBase<Broadcast> { | |||
| MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
| public: | |||
| Broadcast() = default; | |||
| size_t hash() const override { | |||
| return reinterpret_cast<std::uintptr_t>(dyn_typeinfo()); | |||
| } | |||
| bool is_same_st(const Hashable& rhs) const override { | |||
| return true; | |||
| } | |||
| }; | |||
| } // namespace mgb::imperative | |||
| @@ -32,8 +32,7 @@ public: | |||
| bool is_same_st(const Hashable& rhs_) const override { | |||
| auto&& rhs = static_cast<const NMSKeep&>(rhs_); | |||
| return rhs.dyn_typeinfo() == dyn_typeinfo() | |||
| && rhs.iou_thresh == iou_thresh | |||
| return rhs.iou_thresh == iou_thresh | |||
| && rhs.max_output == max_output; | |||
| } | |||