GitOrigin-RevId: 56918db014
tags/v1.0.0-rc1
| @@ -11,6 +11,7 @@ import itertools | |||||
| import numpy as np | import numpy as np | ||||
| import pytest | import pytest | ||||
| import megengine.core.ops.builtin as builtin | |||||
| import megengine.core.tensor.dtype as dtype | import megengine.core.tensor.dtype as dtype | ||||
| import megengine.functional as F | import megengine.functional as F | ||||
| from megengine import Buffer, Parameter, is_cuda_available, tensor | from megengine import Buffer, Parameter, is_cuda_available, tensor | ||||
| @@ -631,3 +632,20 @@ def test_condtake(): | |||||
| val, idx = F.cond_take(yy, xx) | val, idx = F.cond_take(yy, xx) | ||||
| np.testing.assert_equal(val.numpy(), x[y]) | np.testing.assert_equal(val.numpy(), x[y]) | ||||
| np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0]) | np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0]) | ||||
| def test_condtake_is_same(): | |||||
| op1 = builtin.CondTake() | |||||
| op2 = builtin.CondTake() | |||||
| assert op1 == op2 | |||||
| def test_nms_is_same(): | |||||
| op1 = builtin.NMSKeep(0.7, 100) | |||||
| op2 = builtin.NMSKeep(0.7, 100) | |||||
| op3 = builtin.NMSKeep(0.8, 100) | |||||
| op4 = builtin.NMSKeep(0.7, 200) | |||||
| assert op1 == op2 | |||||
| assert op1 != op3 | |||||
| assert op1 != op4 | |||||
| assert op3 != op4 | |||||
| @@ -19,6 +19,15 @@ class CondTake : public OpDefImplBase<CondTake> { | |||||
| MGB_DYN_TYPE_OBJ_FINAL_DECL; | MGB_DYN_TYPE_OBJ_FINAL_DECL; | ||||
| public: | public: | ||||
| CondTake() = default; | CondTake() = default; | ||||
| size_t hash() const override { | |||||
| return reinterpret_cast<std::uintptr_t>(dyn_typeinfo()); | |||||
| } | |||||
| bool is_same_st(const Hashable& rhs) const override { | |||||
| return rhs.dyn_typeinfo() == dyn_typeinfo(); | |||||
| } | |||||
| }; | }; | ||||
| } // namespace mgb::imperative | } // namespace mgb::imperative | ||||
| @@ -23,6 +23,20 @@ public: | |||||
| NMSKeep() = default; | NMSKeep() = default; | ||||
| NMSKeep(float iou_thresh_, uint32_t max_output_): | NMSKeep(float iou_thresh_, uint32_t max_output_): | ||||
| iou_thresh(iou_thresh_), max_output(max_output_) {} | iou_thresh(iou_thresh_), max_output(max_output_) {} | ||||
| size_t hash() const override { | |||||
| return hash_pair_combine( | |||||
| hash_pair_combine(mgb::hash(iou_thresh), mgb::hash(max_output)), | |||||
| reinterpret_cast<std::uintptr_t>(dyn_typeinfo())); | |||||
| } | |||||
| bool is_same_st(const Hashable& rhs_) const override { | |||||
| auto&& rhs = static_cast<const NMSKeep&>(rhs_); | |||||
| return rhs.dyn_typeinfo() == dyn_typeinfo() | |||||
| && rhs.iou_thresh == iou_thresh | |||||
| && rhs.max_output == max_output; | |||||
| } | |||||
| }; | }; | ||||
| } // namespace mgb::imperative | } // namespace mgb::imperative | ||||