|
|
|
@@ -67,6 +67,39 @@ pnnx.Output output 1 0 out |
|
|
|
|
|
|
|
REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_mean, 20) |
|
|
|
|
|
|
|
class torch_mean_1 : public GraphRewriterPass |
|
|
|
{ |
|
|
|
public: |
|
|
|
const char* match_pattern_graph() const |
|
|
|
{ |
|
|
|
return R"PNNXIR(7767517 |
|
|
|
3 2 |
|
|
|
pnnx.Input input 0 1 input |
|
|
|
torch.mean op_0 1 1 input out |
|
|
|
pnnx.Output output 1 0 out |
|
|
|
)PNNXIR"; |
|
|
|
} |
|
|
|
|
|
|
|
const char* type_str() const |
|
|
|
{ |
|
|
|
return "Reduction"; |
|
|
|
} |
|
|
|
|
|
|
|
const char* name_str() const |
|
|
|
{ |
|
|
|
return "mean"; |
|
|
|
} |
|
|
|
|
|
|
|
void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const |
|
|
|
{ |
|
|
|
op->params["0"] = 3; |
|
|
|
op->params["1"] = 1; |
|
|
|
op->params["4"] = 0; |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_mean_1, 20) |
|
|
|
|
|
|
|
} // namespace ncnn |
|
|
|
|
|
|
|
} // namespace pnnx |