Browse Source

fix onnx bn silu (#5483)

tags/20240820
nihui GitHub 2 years ago
parent
commit
e009c36155
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
5 changed files with 21 additions and 18 deletions
  1. +1
    -1
      tools/pnnx/src/CMakeLists.txt
  2. +9
    -6
      tools/pnnx/src/pass_level2/F_batch_norm.cpp
  3. +2
    -2
      tools/pnnx/src/pass_level5.cpp
  4. +8
    -8
      tools/pnnx/src/pass_level5/fuse_silu.cpp
  5. +1
    -1
      tools/pnnx/src/pass_level5/fuse_silu.h

+ 1
- 1
tools/pnnx/src/CMakeLists.txt View File

@@ -360,6 +360,7 @@ set(pnnx_pass_level5_SRCS
pass_level5/fuse_multiheadattention.cpp
pass_level5/fuse_scaled_dot_product_attention.cpp
pass_level5/fuse_select_to_unbind.cpp
pass_level5/fuse_silu.cpp
pass_level5/fuse_slice_copy.cpp
pass_level5/fuse_slice_indices.cpp
pass_level5/fuse_slice_to_tensor_split.cpp
@@ -372,7 +373,6 @@ set(pnnx_pass_level5_SRCS
pass_level5/fuse_static_instancenorm.cpp
pass_level5/fuse_static_layernorm.cpp
pass_level5/fuse_static_linear.cpp
pass_level5/fuse_swish.cpp
pass_level5/normalize_einsum_equation.cpp
pass_level5/unroll_rnn_op.cpp
)


+ 9
- 6
tools/pnnx/src/pass_level2/F_batch_norm.cpp View File

@@ -100,11 +100,19 @@ pnnx.Output output 1 0 out
{
return "F.batch_norm";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
GraphRewriterPass::write(op, captured_params, captured_attrs);

std::swap(op->inputs[1], op->inputs[3]);
std::swap(op->inputs[2], op->inputs[4]);
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_batch_norm_onnx, 10)

class F_batch_norm_onnx_1 : public GraphRewriterPass
class F_batch_norm_onnx_1 : public F_batch_norm_onnx
{
public:
const char* match_pattern_graph() const
@@ -120,11 +128,6 @@ BatchNormalization op_0 5 1 input weight bias running_mean running_v
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "F.batch_norm";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_batch_norm_onnx_1, 10)


+ 2
- 2
tools/pnnx/src/pass_level5.cpp View File

@@ -46,6 +46,7 @@
#include "pass_level5/fuse_pad_conv2d.h"
#include "pass_level5/fuse_scaled_dot_product_attention.h"
#include "pass_level5/fuse_select_to_unbind.h"
#include "pass_level5/fuse_silu.h"
#include "pass_level5/fuse_slice_copy.h"
#include "pass_level5/fuse_slice_indices.h"
#include "pass_level5/fuse_slice_to_tensor_split.h"
@@ -58,7 +59,6 @@
#include "pass_level5/fuse_static_instancenorm.h"
#include "pass_level5/fuse_static_layernorm.h"
#include "pass_level5/fuse_static_linear.h"
#include "pass_level5/fuse_swish.h"
#include "pass_level5/normalize_einsum_equation.h"
#include "pass_level4/dead_code_elimination.h"
#include "pass_level4/canonicalize.h"
@@ -144,7 +144,7 @@ void pass_level5(Graph& g, const std::set<std::string>& foldable_constants, cons
fuse_multiheadattention(g);
fuse_scaled_dot_product_attention(g);

fuse_swish(g);
fuse_silu(g);

fuse_index_expression(g);



tools/pnnx/src/pass_level5/fuse_swish.cpp → tools/pnnx/src/pass_level5/fuse_silu.cpp View File

@@ -12,13 +12,13 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "fuse_swish.h"
#include "fuse_silu.h"

#include "pass_level2.h"

namespace pnnx {

class fuse_swish_pass : public GraphRewriterPass
class fuse_silu_pass : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
@@ -34,16 +34,16 @@ pnnx.Output output 1 0 out

const char* type_str() const
{
return "F.swish";
return "F.silu";
}

const char* name_str() const
{
return "swish";
return "silu";
}
};

class fuse_swish_pass_1 : public fuse_swish_pass
class fuse_silu_pass_1 : public fuse_silu_pass
{
public:
const char* match_pattern_graph() const
@@ -58,10 +58,10 @@ pnnx.Output output 1 0 out
}
};

void fuse_swish(Graph& graph)
void fuse_silu(Graph& graph)
{
fuse_swish_pass a;
fuse_swish_pass_1 b;
fuse_silu_pass a;
fuse_silu_pass_1 b;
int opindex = 0;

pnnx_graph_rewrite(graph, &a, opindex);

tools/pnnx/src/pass_level5/fuse_swish.h → tools/pnnx/src/pass_level5/fuse_silu.h View File

@@ -16,6 +16,6 @@

namespace pnnx {

void fuse_swish(Graph& graph);
void fuse_silu(Graph& graph);

} // namespace pnnx

Loading…
Cancel
Save