|
|
|
@@ -33,7 +33,7 @@ const BaseRef InsertPadForNMSWithMask::DefinePattern() const { |
|
|
|
return VectorRef({prim::kPrimNMSWithMask, Xs}); |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr INsertPadToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const TypeId &origin_type, |
|
|
|
AnfNodePtr InsertPadToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const TypeId &origin_type, |
|
|
|
const std::vector<size_t> &origin_shape) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
std::vector<AnfNodePtr> new_pad_inputs; |
|
|
|
@@ -66,7 +66,7 @@ const AnfNodePtr InsertPadForNMSWithMask::Process(const FuncGraphPtr &func_graph |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
origin_shape[1] = 8; |
|
|
|
auto pad = INsertPadToGraph(func_graph, cur_input, origin_type, origin_shape); |
|
|
|
auto pad = InsertPadToGraph(func_graph, cur_input, origin_type, origin_shape); |
|
|
|
MS_EXCEPTION_IF_NULL(pad); |
|
|
|
pad->set_scope(cnode->scope()); |
|
|
|
AnfAlgo::SetNodeAttr("paddings", MakeValue(std::vector<std::vector<int>>{{0, 0}, {0, 3}}), pad); |
|
|
|
|