Browse Source

convert torch mean/sum/prod reduction with no args

tags/20220216
nihuini 4 years ago
parent
commit
1be043aad5
No known key found for this signature in database GPG Key ID: 98FD8F4EBC3E5DB8
3 changed files with 99 additions and 0 deletions
  1. +33
    -0
      tools/pnnx/src/pass_ncnn/torch_mean.cpp
  2. +33
    -0
      tools/pnnx/src/pass_ncnn/torch_prod.cpp
  3. +33
    -0
      tools/pnnx/src/pass_ncnn/torch_sum.cpp

+ 33
- 0
tools/pnnx/src/pass_ncnn/torch_mean.cpp View File

@@ -67,6 +67,39 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_mean, 20)

class torch_mean_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
torch.mean op_0 1 1 input out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Reduction";
}

const char* name_str() const
{
return "mean";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
op->params["0"] = 3;
op->params["1"] = 1;
op->params["4"] = 0;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_mean_1, 20)

} // namespace ncnn

} // namespace pnnx

+ 33
- 0
tools/pnnx/src/pass_ncnn/torch_prod.cpp View File

@@ -64,6 +64,39 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_prod, 20)

class torch_prod_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
torch.prod op_0 1 1 input out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Reduction";
}

const char* name_str() const
{
return "prod";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
op->params["0"] = 6;
op->params["1"] = 1;
op->params["4"] = 0;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_prod_1, 20)

} // namespace ncnn

} // namespace pnnx

+ 33
- 0
tools/pnnx/src/pass_ncnn/torch_sum.cpp View File

@@ -67,6 +67,39 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_sum, 20)

class torch_sum_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
torch.sum op_0 1 1 input out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Reduction";
}

const char* name_str() const
{
return "sum";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
op->params["0"] = 0;
op->params["1"] = 1;
op->params["4"] = 0;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_sum_1, 20)

} // namespace ncnn

} // namespace pnnx

Loading…
Cancel
Save