| @@ -1579,3 +1579,9 @@ def batched_nms( | |||||
| indices = indices[0][: count.item()] | indices = indices[0][: count.item()] | ||||
| keep_inds = sorted_idx[indices] | keep_inds = sorted_idx[indices] | ||||
| return keep_inds | return keep_inds | ||||
| from .loss import * # isort:skip | |||||
| from .quantized import conv_bias_activation # isort:skip | |||||
| @@ -551,3 +551,5 @@ def test_nms_is_same(): | |||||
| assert op1 != op3 | assert op1 != op3 | ||||
| assert op1 != op4 | assert op1 != op4 | ||||
| assert op3 != op4 | assert op3 != op4 | ||||
| @@ -159,6 +159,7 @@ void Cumsum::init_output_static_infer_desc() { | |||||
| {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_workspace}); | {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_workspace}); | ||||
| } | } | ||||
| /* ================= CondTake ================= */ | /* ================= CondTake ================= */ | ||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondTake); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondTake); | ||||
| @@ -63,4 +63,5 @@ decl_opr('TopK', | |||||
| inputs=['data', 'k'], params='TopK', | inputs=['data', 'k'], params='TopK', | ||||
| desc='Select the top k values from sorted result.') | desc='Select the top k values from sorted result.') | ||||
| # vim: ft=python | # vim: ft=python | ||||
| @@ -70,6 +70,7 @@ namespace opr { | |||||
| using CumsumV1 = opr::Cumsum; | using CumsumV1 = opr::Cumsum; | ||||
| MGB_SEREG_OPR(CumsumV1, 1); | MGB_SEREG_OPR(CumsumV1, 1); | ||||
| } // namespace opr | } // namespace opr | ||||
| } // namespace mgb | } // namespace mgb | ||||
| @@ -94,6 +94,7 @@ MGB_DEFINE_OPR_CLASS(Cumsum, cg::SingleCNOperatorNodeBaseT< | |||||
| void init_output_static_infer_desc() override; | void init_output_static_infer_desc() override; | ||||
| }; | }; | ||||
| namespace intl { | namespace intl { | ||||
| using CondTakeBase = | using CondTakeBase = | ||||
| cg::SingleCNOperatorNode<cg::OperatorNodeBase, | cg::SingleCNOperatorNode<cg::OperatorNodeBase, | ||||
| @@ -28,6 +28,7 @@ table Blob { | |||||
| } | } | ||||
| table Reserved0 {} | table Reserved0 {} | ||||
| table Reserved1 {} | |||||
| union OperatorParam { | union OperatorParam { | ||||
| param.Empty = 1, | param.Empty = 1, | ||||
| @@ -100,6 +101,7 @@ union OperatorParam { | |||||
| param.Remap = 68, | param.Remap = 68, | ||||
| param.NMSKeep = 69, | param.NMSKeep = 69, | ||||
| param.AdaptivePooling = 70, | param.AdaptivePooling = 70, | ||||
| Reserved1 = 71, | |||||
| } | } | ||||
| table Operator { | table Operator { | ||||
| @@ -143,3 +143,4 @@ pdef('PersistentOutputStorage').add_fields( | |||||
| ' no branch is taken') | ' no branch is taken') | ||||
| ) | ) | ||||
| ) | ) | ||||