make build possible at 8G ddr env, when -j8
GitOrigin-RevId: d0c442b41d
tags/v1.8.0
| @@ -1,6 +1,6 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1.cpp | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| @@ -11,4 +11,5 @@ | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
| INSTANTIATION_CONV_S1(2); | |||
| INSTANTIATION_CONV_S1_BIAS(2); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_broadcast_channel_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
| INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(2); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_no_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
| INSTANTIATION_CONV_S1_NO_BIAS(2); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -1,6 +1,6 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2.cpp | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| @@ -11,4 +11,5 @@ | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
| INSTANTIATION_CONV_S2(5); | |||
| INSTANTIATION_CONV_S2_BIAS(2); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_broadcast_channel_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
| INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(2); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_no_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
| INSTANTIATION_CONV_S2_NO_BIAS(2); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -1,6 +1,6 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1.cpp | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| @@ -11,4 +11,5 @@ | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
| INSTANTIATION_CONV_S1(5); | |||
| INSTANTIATION_CONV_S1_BIAS(3); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_broadcast_channel_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
| INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(3); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_no_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
| INSTANTIATION_CONV_S1_NO_BIAS(3); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -1,6 +1,6 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2.cpp | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| @@ -11,4 +11,5 @@ | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
| INSTANTIATION_CONV_S2(2); | |||
| INSTANTIATION_CONV_S2_BIAS(3); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_broadcast_channel_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
| INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(3); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_no_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
| INSTANTIATION_CONV_S2_NO_BIAS(3); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -1,6 +1,6 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1.cpp | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| @@ -11,4 +11,5 @@ | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
| INSTANTIATION_CONV_S1(3); | |||
| INSTANTIATION_CONV_S1_BIAS(5); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_broadcast_channel_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
| INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(5); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_no_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
| INSTANTIATION_CONV_S1_NO_BIAS(5); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -1,6 +1,6 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2.cpp | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| @@ -11,4 +11,5 @@ | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
| INSTANTIATION_CONV_S2(7); | |||
| INSTANTIATION_CONV_S2_BIAS(5); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_broadcast_channel_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
| INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(5); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_no_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
| INSTANTIATION_CONV_S2_NO_BIAS(5); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -1,6 +1,6 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1.cpp | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| @@ -11,4 +11,5 @@ | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
| INSTANTIATION_CONV_S1(7); | |||
| INSTANTIATION_CONV_S1_BIAS(7); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_broadcast_channel_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
| INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(7); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_no_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
| INSTANTIATION_CONV_S1_NO_BIAS(7); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -1,6 +1,6 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2.cpp | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| @@ -11,4 +11,5 @@ | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
| INSTANTIATION_CONV_S2(3); | |||
| INSTANTIATION_CONV_S2_BIAS(7); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_broadcast_channel_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
| INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(7); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_no_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
| INSTANTIATION_CONV_S2_NO_BIAS(7); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -469,9 +469,12 @@ void conv_bias::conv_direct_fp32_nchw44( | |||
| INSTANTIATION(filter_size, bias, HSwishOp<dt_float32>) \ | |||
| INSTANTIATION(filter_size, bias, SigmoidOp<dt_float32>) | |||
| #define INSTANTIATION_CONV_S1(filter_size) \ | |||
| FOR_OP(filter_size, BiasMode::NO_BIAS) \ | |||
| FOR_OP(filter_size, BiasMode::BROADCAST_CHANNEL_BIAS) \ | |||
| FOR_OP(filter_size, BiasMode::BIAS) | |||
| #define INSTANTIATION_CONV_S1_NO_BIAS(filter_size) \ | |||
| FOR_OP(filter_size, BiasMode::NO_BIAS) | |||
| // vim: syntax=cpp.doxygen | |||
| #define INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(filter_size) \ | |||
| FOR_OP(filter_size, BiasMode::BROADCAST_CHANNEL_BIAS) | |||
| #define INSTANTIATION_CONV_S1_BIAS(filter_size) FOR_OP(filter_size, BiasMode::BIAS) | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -550,9 +550,12 @@ void conv_bias::conv_direct_fp32_nchw44( | |||
| INSTANTIATION(filter_size, bias, HSwishOp<dt_float32>) \ | |||
| INSTANTIATION(filter_size, bias, SigmoidOp<dt_float32>) | |||
| #define INSTANTIATION_CONV_S2(filter_size) \ | |||
| FOR_OP(filter_size, BiasMode::NO_BIAS) \ | |||
| FOR_OP(filter_size, BiasMode::BROADCAST_CHANNEL_BIAS) \ | |||
| FOR_OP(filter_size, BiasMode::BIAS) | |||
| #define INSTANTIATION_CONV_S2_NO_BIAS(filter_size) \ | |||
| FOR_OP(filter_size, BiasMode::NO_BIAS) | |||
| // vim: syntax=cpp.doxygen | |||
| #define INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(filter_size) \ | |||
| FOR_OP(filter_size, BiasMode::BROADCAST_CHANNEL_BIAS) | |||
| #define INSTANTIATION_CONV_S2_BIAS(filter_size) FOR_OP(filter_size, BiasMode::BIAS) | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -1,6 +1,6 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1.cpp | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| @@ -11,4 +11,5 @@ | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
| INSTANCE_CONV(2, 1); | |||
| INSTANCE_CONV_BIAS(2, 1); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_broadcast_channel_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
| INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(2, 1); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_no_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
| INSTANCE_CONV_NO_BIAS(2, 1); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -1,6 +1,6 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2.cpp | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| @@ -11,4 +11,5 @@ | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
| INSTANCE_CONV(2, 2); | |||
| INSTANCE_CONV_BIAS(2, 2); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_broadcast_channel_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
| INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(2, 2); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_no_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
| INSTANCE_CONV_NO_BIAS(2, 2); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -1,6 +1,6 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1.cpp | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| @@ -11,4 +11,5 @@ | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
| INSTANCE_CONV(3, 1); | |||
| INSTANCE_CONV_BIAS(3, 1); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_broadcast_channel_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
| INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(3, 1); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_no_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
| INSTANCE_CONV_NO_BIAS(3, 1); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -1,6 +1,6 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2.cpp | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| @@ -11,4 +11,5 @@ | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
| INSTANCE_CONV(3, 2); | |||
| INSTANCE_CONV_BIAS(3, 2); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_broadcast_channel_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
| INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(3, 2); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_no_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
| INSTANCE_CONV_NO_BIAS(3, 2); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
| INSTANCE_CONV_BIAS(5, 1); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_broadcast_channel_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
| INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(5, 1); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_no_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
| INSTANCE_CONV_NO_BIAS(5, 1); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
| INSTANCE_CONV_BIAS(5, 2); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_broadcast_channel_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
| INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(5, 2); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_no_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
| INSTANCE_CONV_NO_BIAS(5, 2); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
| INSTANCE_CONV_BIAS(7, 1); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_broadcast_channel_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
| INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(7, 1); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_no_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
| INSTANCE_CONV_NO_BIAS(7, 1); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
| INSTANCE_CONV_BIAS(7, 2); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_broadcast_channel_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
| INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(7, 2); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,15 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_no_bias.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
| INSTANCE_CONV_NO_BIAS(7, 2); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -928,9 +928,11 @@ void fp32_direct_nchw_nchw44::conv_direct_fp32_nchw_nchw44( | |||
| INSTANTIATION(stride, filter, bias, ReluOp<dt_float32>) \ | |||
| INSTANTIATION(stride, filter, bias, HSwishOp<dt_float32>) | |||
| #define INSTANCE_CONV(filter, stride) \ | |||
| FOR_OP(stride, filter, BiasMode::NO_BIAS) \ | |||
| FOR_OP(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ | |||
| FOR_OP(stride, filter, BiasMode::BIAS) | |||
| #define INSTANCE_CONV_NO_BIAS(filter, stride) FOR_OP(stride, filter, BiasMode::NO_BIAS) | |||
| #define INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(filter, stride) \ | |||
| FOR_OP(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) | |||
| #define INSTANCE_CONV_BIAS(filter, stride) FOR_OP(stride, filter, BiasMode::BIAS) | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -1,6 +1,6 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp | |||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| @@ -265,7 +265,8 @@ void conv_direct_sdot_int8_nchw44( | |||
| #define INSTANTIATION(dst_type, stride, filter_size, bias_mode, Op) \ | |||
| template void \ | |||
| conv_direct_sdot_int8_nchw44<dst_type, stride, bias_mode, Op, filter_size>( \ | |||
| megdnn::arm_common::direct_dotprod_nchw44::conv_direct_sdot_int8_nchw44< \ | |||
| dst_type, stride, bias_mode, Op, filter_size>( \ | |||
| dst_type * dst, const int oh, const int ow, const int8_t* src, \ | |||
| const int ih, const int iw, const int8_t* weight, const int32_t* bias, \ | |||
| const int oh_size, const int oc, const int ic, const Op& op); | |||
| @@ -284,22 +285,6 @@ void conv_direct_sdot_int8_nchw44( | |||
| FOR_OP(stride, i, BiasMode::NO_BIAS) \ | |||
| FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) | |||
| #define FOR_FILTER(stride) \ | |||
| FOR_BIAS(stride, 2) \ | |||
| FOR_BIAS(stride, 3) \ | |||
| FOR_BIAS(stride, 5) \ | |||
| FOR_BIAS(stride, 7) | |||
| FOR_FILTER(1) | |||
| #undef FOR_STRIDE | |||
| #undef FOR_FILTER | |||
| #undef FOR_IC | |||
| #undef FOR_BIAS | |||
| #undef FOR_NONLINEAR | |||
| #undef FOR_REMAIN | |||
| #undef INSTANTIATION | |||
| } // namespace direct_dotprod_nchw44 | |||
| } // namespace arm_common | |||
| } // namespace megdnn | |||
| @@ -0,0 +1,21 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_2x2.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.h" | |||
| #if MGB_ENABLE_DOT | |||
| using namespace megdnn; | |||
| using namespace arm_common; | |||
| FOR_BIAS(1, 2); | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,21 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_3x3.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.h" | |||
| #if MGB_ENABLE_DOT | |||
| using namespace megdnn; | |||
| using namespace arm_common; | |||
| FOR_BIAS(1, 3); | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,21 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_5x5.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.h" | |||
| #if MGB_ENABLE_DOT | |||
| using namespace megdnn; | |||
| using namespace arm_common; | |||
| FOR_BIAS(1, 5); | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,21 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_7x7.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.h" | |||
| #if MGB_ENABLE_DOT | |||
| using namespace megdnn; | |||
| using namespace arm_common; | |||
| FOR_BIAS(1, 7); | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -1,6 +1,6 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp | |||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| @@ -266,7 +266,8 @@ void conv_direct_sdot_int8_nchw44( | |||
| #define INSTANTIATION(dst_type, stride, filter_size, bias_mode, Op) \ | |||
| template void \ | |||
| conv_direct_sdot_int8_nchw44<dst_type, stride, bias_mode, Op, filter_size>( \ | |||
| megdnn::arm_common::direct_dotprod_nchw44::conv_direct_sdot_int8_nchw44< \ | |||
| dst_type, stride, bias_mode, Op, filter_size>( \ | |||
| dst_type * dst, const int oh, const int ow, const int8_t* src, \ | |||
| const int ih, const int iw, const int8_t* weight, const int32_t* bias, \ | |||
| const int oh_size, const int oc, const int ic, const Op& op); | |||
| @@ -285,22 +286,6 @@ void conv_direct_sdot_int8_nchw44( | |||
| FOR_OP(stride, i, BiasMode::NO_BIAS) \ | |||
| FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) | |||
| #define FOR_FILTER(stride) \ | |||
| FOR_BIAS(stride, 2) \ | |||
| FOR_BIAS(stride, 3) \ | |||
| FOR_BIAS(stride, 5) \ | |||
| FOR_BIAS(stride, 7) | |||
| FOR_FILTER(2) | |||
| #undef FOR_STRIDE | |||
| #undef FOR_FILTER | |||
| #undef FOR_IC | |||
| #undef FOR_BIAS | |||
| #undef FOR_NONLINEAR | |||
| #undef FOR_REMAIN | |||
| #undef INSTANTIATION | |||
| } // namespace direct_dotprod_nchw44 | |||
| } // namespace arm_common | |||
| } // namespace megdnn | |||
| @@ -0,0 +1,21 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_2x2.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.h" | |||
| #if MGB_ENABLE_DOT | |||
| using namespace megdnn; | |||
| using namespace arm_common; | |||
| FOR_BIAS(2, 2); | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,21 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_3x3.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.h" | |||
| #if MGB_ENABLE_DOT | |||
| using namespace megdnn; | |||
| using namespace arm_common; | |||
| FOR_BIAS(2, 3); | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,21 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_5x5.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.h" | |||
| #if MGB_ENABLE_DOT | |||
| using namespace megdnn; | |||
| using namespace arm_common; | |||
| FOR_BIAS(2, 5); | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,21 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_7x7.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.h" | |||
| #if MGB_ENABLE_DOT | |||
| using namespace megdnn; | |||
| using namespace arm_common; | |||
| FOR_BIAS(2, 7); | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -1,6 +1,6 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp | |||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| @@ -45,4 +45,4 @@ public: | |||
| } // namespace arm_common | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -13,336 +13,9 @@ | |||
| #include "src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h" | |||
| #include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h" | |||
| namespace megdnn { | |||
| namespace arm_common { | |||
| namespace { | |||
| /** | |||
| * @brief core code for calculation pattern | |||
| * | |||
| * @tparam src_idx is offset of src reg | |||
| * @tparam weight_idx is offset of weight reg | |||
| * @tparam c_dim is output channel | |||
| * @tparam Func mla operation function | |||
| * @tparam stride | |||
| * @tparam T output regs type | |||
| * @tparam T2 src regs type | |||
| * @tparam T3 weight regs type | |||
| * @tparam T4 temp regs type | |||
| */ | |||
| template < | |||
| int src_idx, int weight_idx, int c_dim, int stride, typename T, typename T2, | |||
| typename T3, typename T4> | |||
| struct ShiftCalHelper { | |||
| static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp); | |||
| static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); | |||
| }; | |||
| template < | |||
| int src_idx, int weight_idx, int c_dim, int stride, typename T, typename T2, | |||
| typename T3, typename T4> | |||
| MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight, T4& temp) { | |||
| ShiftCalHelper<src_idx, weight_idx, c_dim, stride, T, T2, T3, T4>::impl( | |||
| c, src, weight, temp); | |||
| } | |||
| template < | |||
| int src_idx, int weight_idx, int c_dim, int stride, typename T, typename T2, | |||
| typename T3> | |||
| MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { | |||
| ShiftCalHelper<src_idx, weight_idx, c_dim, stride, T, T2, T3, int>::impl( | |||
| c, src, weight); | |||
| }; | |||
| template < | |||
| int src_idx, int weight_idx, typename T, typename T2, typename T3, typename T4> | |||
| struct ShiftCalHelper<src_idx, weight_idx, 2, 1, T, T2, T3, T4> { | |||
| static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) { | |||
| c[0][0] = vdotq_s32_h( | |||
| src[(0 + src_idx) % 8], weight[0][weight_idx], c[0][0], temp[0]); | |||
| c[1][0] = vdotq_s32_h( | |||
| src[(0 + src_idx) % 8], weight[1][weight_idx], c[1][0], temp[1]); | |||
| c[0][1] = vdotq_s32_h( | |||
| src[(1 + src_idx) % 8], weight[0][weight_idx], c[0][1], temp[2]); | |||
| c[1][1] = vdotq_s32_h( | |||
| src[(1 + src_idx) % 8], weight[1][weight_idx], c[1][1], temp[3]); | |||
| c[0][2] = vdotq_s32_h( | |||
| src[(2 + src_idx) % 8], weight[0][weight_idx], c[0][2], temp[0]); | |||
| c[1][2] = vdotq_s32_h( | |||
| src[(2 + src_idx) % 8], weight[1][weight_idx], c[1][2], temp[1]); | |||
| c[0][3] = vdotq_s32_h( | |||
| src[(3 + src_idx) % 8], weight[0][weight_idx], c[0][3], temp[2]); | |||
| c[1][3] = vdotq_s32_h( | |||
| src[(3 + src_idx) % 8], weight[1][weight_idx], c[1][3], temp[3]); | |||
| c[0][4] = vdotq_s32_h( | |||
| src[(4 + src_idx) % 8], weight[0][weight_idx], c[0][4], temp[0]); | |||
| c[1][4] = vdotq_s32_h( | |||
| src[(4 + src_idx) % 8], weight[1][weight_idx], c[1][4], temp[1]); | |||
| c[0][5] = vdotq_s32_h( | |||
| src[(5 + src_idx) % 8], weight[0][weight_idx], c[0][5], temp[2]); | |||
| c[1][5] = vdotq_s32_h( | |||
| src[(5 + src_idx) % 8], weight[1][weight_idx], c[1][5], temp[3]); | |||
| c[0][6] = vdotq_s32_h( | |||
| src[(6 + src_idx) % 8], weight[0][weight_idx], c[0][6], temp[0]); | |||
| c[1][6] = vdotq_s32_h( | |||
| src[(6 + src_idx) % 8], weight[1][weight_idx], c[1][6], temp[1]); | |||
| c[0][7] = vdotq_s32_h( | |||
| src[(7 + src_idx) % 8], weight[0][weight_idx], c[0][7], temp[2]); | |||
| c[1][7] = vdotq_s32_h( | |||
| src[(7 + src_idx) % 8], weight[1][weight_idx], c[1][7], temp[3]); | |||
| } | |||
| static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&); | |||
| }; | |||
| template < | |||
| int src_idx, int weight_idx, typename T, typename T2, typename T3, typename T4> | |||
| struct ShiftCalHelper<src_idx, weight_idx, 1, 1, T, T2, T3, T4> { | |||
| static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) { | |||
| c[0][0] = vdotq_s32_h( | |||
| src[(0 + src_idx) % 8], weight[0][weight_idx], c[0][0], temp[0]); | |||
| c[0][1] = vdotq_s32_h( | |||
| src[(1 + src_idx) % 8], weight[0][weight_idx], c[0][1], temp[1]); | |||
| c[0][2] = vdotq_s32_h( | |||
| src[(2 + src_idx) % 8], weight[0][weight_idx], c[0][2], temp[2]); | |||
| c[0][3] = vdotq_s32_h( | |||
| src[(3 + src_idx) % 8], weight[0][weight_idx], c[0][3], temp[3]); | |||
| c[0][4] = vdotq_s32_h( | |||
| src[(4 + src_idx) % 8], weight[0][weight_idx], c[0][4], temp[0]); | |||
| c[0][5] = vdotq_s32_h( | |||
| src[(5 + src_idx) % 8], weight[0][weight_idx], c[0][5], temp[1]); | |||
| c[0][6] = vdotq_s32_h( | |||
| src[(6 + src_idx) % 8], weight[0][weight_idx], c[0][6], temp[2]); | |||
| c[0][7] = vdotq_s32_h( | |||
| src[(7 + src_idx) % 8], weight[0][weight_idx], c[0][7], temp[3]); | |||
| } | |||
| static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&); | |||
| }; | |||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block> | |||
| struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 1, oc_block, 1> { | |||
| static void impl( | |||
| const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, | |||
| int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { | |||
| constexpr int stride = 1; | |||
| constexpr int filter_height = 1; | |||
| constexpr int filter_width = 4; | |||
| constexpr int oc_step = 4; | |||
| constexpr int loop_ic_step = 1; | |||
| constexpr int simd_len = 16; | |||
| constexpr int pack_iw_len = 16; | |||
| constexpr int src_reg = 8; | |||
| constexpr int weight_reg = 1; | |||
| const int ic_stride = ih * iw * pack_iw_len; | |||
| const int ld_weight_oc = oc_step * filter_height * filter_width * ic; | |||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||
| int32x4_t c[c_dim][8]; | |||
| init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||
| const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; | |||
| int8x16_t src[src_reg]; | |||
| int8x16_t dot4_weight[c_dim][weight_reg]; | |||
| int16x8_t temp_c[4]; | |||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( | |||
| dot4_weight, weight_ptr, ld_weight_oc); | |||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||
| src, nchw_src_ptr + 0 * iw * pack_iw_len, 0); | |||
| cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); | |||
| weight_ptr += oc_step * filter_height * filter_width; | |||
| } | |||
| store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||
| c, op, dst_ptr, ld_dst_oc); | |||
| } | |||
| }; | |||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block> | |||
| struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 2, oc_block, 1> { | |||
| static void impl( | |||
| const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, | |||
| int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { | |||
| constexpr int stride = 1; | |||
| constexpr int filter_height = 2; | |||
| constexpr int filter_width = 4; | |||
| constexpr int oc_step = 4; | |||
| constexpr int loop_ic_step = 1; | |||
| constexpr int simd_len = 16; | |||
| constexpr int pack_iw_len = 16; | |||
| constexpr int src_reg = 8; | |||
| constexpr int weight_reg = 1; | |||
| const int ic_stride = ih * iw * pack_iw_len; | |||
| const int ld_weight_oc = oc_step * filter_height * filter_width * ic; | |||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||
| int32x4_t c[c_dim][8]; | |||
| init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||
| const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; | |||
| int8x16_t src[src_reg]; | |||
| int8x16_t dot4_weight[c_dim][weight_reg]; | |||
| int16x8_t temp_c[4]; | |||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( | |||
| dot4_weight, weight_ptr, ld_weight_oc); | |||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||
| src, nchw_src_ptr + 0 * iw * pack_iw_len, 0); | |||
| cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); | |||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( | |||
| dot4_weight, weight_ptr + 1 * filter_width * oc_step, ld_weight_oc); | |||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||
| src, nchw_src_ptr + 1 * iw * pack_iw_len, 0); | |||
| cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); | |||
| weight_ptr += oc_step * filter_height * filter_width; | |||
| } | |||
| store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||
| c, op, dst_ptr, ld_dst_oc); | |||
| } | |||
| }; | |||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block> | |||
| struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 3, oc_block, 1> { | |||
| static void impl( | |||
| const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, | |||
| int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { | |||
| constexpr int stride = 1; | |||
| constexpr int filter_height = 3; | |||
| constexpr int filter_width = 4; | |||
| constexpr int oc_step = 4; | |||
| constexpr int loop_ic_step = 1; | |||
| constexpr int simd_len = 16; | |||
| constexpr int pack_iw_len = 16; | |||
| constexpr int src_reg = 8; | |||
| constexpr int weight_reg = 1; | |||
| const int ic_stride = ih * iw * pack_iw_len; | |||
| const int ld_weight_oc = oc_step * filter_height * filter_width * ic; | |||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||
| int32x4_t c[c_dim][8]; | |||
| init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||
| const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; | |||
| int8x16_t src[src_reg]; | |||
| int8x16_t dot4_weight[c_dim][weight_reg]; | |||
| int16x8_t temp_c[4]; | |||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( | |||
| dot4_weight, weight_ptr, ld_weight_oc); | |||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||
| src, nchw_src_ptr + 0 * iw * pack_iw_len, 0); | |||
| cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); | |||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( | |||
| dot4_weight, weight_ptr + 1 * filter_width * oc_step, ld_weight_oc); | |||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||
| src, nchw_src_ptr + 1 * iw * pack_iw_len, 0); | |||
| cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); | |||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( | |||
| dot4_weight, weight_ptr + 2 * filter_width * oc_step, ld_weight_oc); | |||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||
| src, nchw_src_ptr + 2 * iw * pack_iw_len, 0); | |||
| cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); | |||
| weight_ptr += oc_step * filter_height * filter_width; | |||
| } | |||
| store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||
| c, op, dst_ptr, ld_dst_oc); | |||
| } | |||
| }; | |||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block> | |||
| struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 5, oc_block, 1> { | |||
| static void impl( | |||
| const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, | |||
| int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { | |||
| constexpr int stride = 1; | |||
| constexpr int filter_height = 5; | |||
| constexpr int filter_width = 8; | |||
| constexpr int oc_step = 4; | |||
| constexpr int loop_ic_step = 1; | |||
| constexpr int simd_len = 16; | |||
| constexpr int pack_iw_len = 16; | |||
| constexpr int src_reg = 8; | |||
| constexpr int weight_reg = 2; | |||
| const int ic_stride = ih * iw * pack_iw_len; | |||
| const int ld_weight_oc = oc_step * filter_height * filter_width * ic; | |||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||
| int32x4_t c[c_dim][8]; | |||
| init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||
| const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; | |||
| int8x16_t src[src_reg]; | |||
| int8x16_t dot4_weight[c_dim][weight_reg]; | |||
| int16x8_t temp_c[4]; | |||
| #define cb(step) \ | |||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \ | |||
| dot4_weight, weight_ptr + step * filter_width * oc_step, ld_weight_oc); \ | |||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( \ | |||
| src, nchw_src_ptr + step * iw * pack_iw_len, 0); \ | |||
| cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); \ | |||
| load_helper<4, 0, simd_len, 0, Vld1q_s8>( \ | |||
| src, nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, 0); \ | |||
| cal_helper<4, 1, c_dim, stride>(c, src, dot4_weight, temp_c); | |||
| UNROLL_CALL_RAW(5, cb); | |||
| #undef cb | |||
| weight_ptr += oc_step * filter_height * filter_width; | |||
| } | |||
| store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||
| c, op, dst_ptr, ld_dst_oc); | |||
| } | |||
| }; | |||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block> | |||
| struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 7, oc_block, 1> { | |||
| static void impl( | |||
| const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, | |||
| int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { | |||
| constexpr int stride = 1; | |||
| constexpr int filter_height = 7; | |||
| constexpr int filter_width = 8; | |||
| constexpr int oc_step = 4; | |||
| constexpr int loop_ic_step = 1; | |||
| constexpr int simd_len = 16; | |||
| constexpr int pack_iw_len = 16; | |||
| constexpr int src_reg = 8; | |||
| constexpr int weight_reg = 2; | |||
| const int ic_stride = ih * iw * pack_iw_len; | |||
| const int ld_weight_oc = oc_step * filter_height * filter_width * ic; | |||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||
| int32x4_t c[c_dim][8]; | |||
| init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||
| const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; | |||
| int8x16_t src[src_reg]; | |||
| int8x16_t dot4_weight[c_dim][weight_reg]; | |||
| int16x8_t temp_c[4]; | |||
| #define cb(step) \ | |||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \ | |||
| dot4_weight, weight_ptr + step * filter_width * oc_step, ld_weight_oc); \ | |||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( \ | |||
| src, nchw_src_ptr + step * iw * pack_iw_len, 0); \ | |||
| cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); \ | |||
| load_helper<4, 0, simd_len, 0, Vld1q_s8>( \ | |||
| src, nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, 0); \ | |||
| cal_helper<4, 1, c_dim, stride>(c, src, dot4_weight, temp_c); | |||
| UNROLL_CALL_RAW(7, cb); | |||
| #undef cb | |||
| weight_ptr += oc_step * filter_height * filter_width; | |||
| } | |||
| store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||
| c, op, dst_ptr, ld_dst_oc); | |||
| } | |||
| }; | |||
| } // namespace | |||
| namespace int8_direct_nchw_nchw44 { | |||
| /** | |||
| * pack {oc / 4, fh, fw, ic, 4(oc)} to {oc / 4, ic, fh ,fw/4, 4(oc)*4(fw)} | |||
| @@ -444,115 +117,9 @@ void pack_nchw_src_for_nchw44_conv<1>( | |||
| } | |||
| } | |||
template <BiasMode bias_mode, typename Op, size_t filter_size>
struct ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, 1> {
    // Stride-1 driver: walks the output in tiles of (8 oc x 1 oh x 8 ow) and
    // dispatches each tile to the matching KerNeonXXs2NchwNchw44 micro-kernel.
    // Right-edge tiles (ow_remain) and the 4-channel oc tail (oc_remain) go
    // through kernels selected once, before the hot loops.
    static void impl(
            const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp,
            int8_t* dst, const size_t oc, const size_t ic, const size_t ih,
            const size_t iw, const size_t oh, const size_t ow, const Op& op) {
        // `temp` is part of the shared kernel interface but unused on this path
        MEGDNN_MARK_USED_VAR(temp);
        constexpr int stride = 1;
        constexpr size_t fh = filter_size;
        // weight packing pads the filter width up to a multiple of 4
        constexpr size_t fw = (filter_size + 3) / 4 * 4;
        constexpr size_t ic_step = 1;
        constexpr size_t big_oc_step = 8;  // main loop: 8 output channels per tile
        constexpr size_t oc_step = 4;      // tail loop: one nchw44 pack of 4
        constexpr size_t ih_step = 1;
        constexpr size_t oh_step = 1;
        constexpr size_t ow_step = 8;      // 8 output columns per kernel call
        constexpr size_t stride_h = stride;
        constexpr size_t stride_w = stride;
        constexpr int pack_iw_len = 16;    // packed source: 16 bytes per input pixel
        const size_t img_stride = oh * ow;
        const size_t ow_end = ow / ow_step * ow_step;
        const size_t ow_remain = ow - ow_end;
        const size_t oc_end = oc / big_oc_step * big_oc_step;
        const size_t oc_remain = oc - oc_end;
        const int ld_dst_oc = oc_step * img_stride;
        using remain_fun = std::function<void(
                const int8_t* src_ptr, const int8_t* weight_ptr,
                const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, int iw,
                int ld_dst_oc, const Op& op)>;
        remain_fun kern_big_oc_remain = nullptr;
        remain_fun kern_small_oc_remain = nullptr;
        // resolve the right-edge kernels (remain_w == ow_remain) up front
        switch (ow_remain) {
#define cb(step)                                                              \
    case step:                                                                \
        kern_big_oc_remain = KerNeonXXs2NchwNchw44<                           \
                bias_mode, Op, step, filter_size, big_oc_step, stride>::impl; \
        kern_small_oc_remain = KerNeonXXs2NchwNchw44<                         \
                bias_mode, Op, step, filter_size, oc_step, stride>::impl;     \
        break;
            UNROLL_CALL_RAW(8, cb);
            default:
                megdnn_assert(0, "no remain %zu for kern", ow_remain);
        }
        // NOTE(review): `cb` stays defined past this switch — consider an
        // #undef here; harmless in this TU but fragile if code is appended.
        // --- main tiles: 8 output channels at a time ---
        for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
            const size_t weight_offset = oc_idx * ic * fh * fw;
            for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
                for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                    const size_t src_offset =
                            (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
                            ic_step * pack_iw_len;
                    const size_t dst_offset =
                            oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
                    KerNeonXXs2NchwNchw44<
                            bias_mode, Op, ow_step, filter_size, big_oc_step, stride>::
                            impl(src + src_offset, filter + weight_offset,
                                 bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_dst_oc,
                                 op);
                }
                if (ow_remain > 0) {  // ragged right edge of this row
                    const size_t src_offset =
                            (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
                            ic_step * pack_iw_len;
                    const size_t dst_offset =
                            oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
                    kern_big_oc_remain(
                            src + src_offset, filter + weight_offset, bias + oc_idx,
                            dst + dst_offset, ic, ih, iw, ld_dst_oc, op);
                }
            }
        }
        // --- oc tail: remaining 4-channel pack, same tiling as above ---
        if (oc_remain > 0) {
            size_t oc_idx = oc_end;
            const size_t weight_offset = oc_idx * ic * fh * fw;
            for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
                for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                    const size_t src_offset =
                            (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
                            ic_step * pack_iw_len;
                    const size_t dst_offset =
                            oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
                    KerNeonXXs2NchwNchw44<
                            bias_mode, Op, ow_step, filter_size, oc_step, stride>::
                            impl(src + src_offset, filter + weight_offset,
                                 bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_dst_oc,
                                 op);
                }
                if (ow_remain > 0) {
                    const size_t src_offset =
                            (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
                            ic_step * pack_iw_len;
                    const size_t dst_offset =
                            oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
                    kern_small_oc_remain(
                            src + src_offset, filter + weight_offset, bias + oc_idx,
                            dst + dst_offset, ic, ih, iw, ld_dst_oc, op);
                }
            }
        }
    }
};
| #define INSTANCE_CONV_KERN_FUN(stride, filter_size, bias_mode, Op) \ | |||
| template struct ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, stride>; | |||
| template struct megdnn::arm_common::int8_direct_nchw_nchw44:: \ | |||
| ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, stride>; | |||
| #define INSTANCE_OP_PARAM(stride, filter, bias_mode) \ | |||
| INSTANCE_CONV_KERN_FUN( \ | |||
| @@ -566,17 +133,10 @@ struct ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, 1> { | |||
| INSTANCE_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ | |||
| INSTANCE_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) | |||
| #define INSTANCE_CONV_KERN(stride) \ | |||
| INSTANCE_BIAS_MODE_PARAM(stride, 1) \ | |||
| INSTANCE_BIAS_MODE_PARAM(stride, 2) \ | |||
| INSTANCE_BIAS_MODE_PARAM(stride, 3) \ | |||
| INSTANCE_BIAS_MODE_PARAM(stride, 5) \ | |||
| INSTANCE_BIAS_MODE_PARAM(stride, 7) | |||
| INSTANCE_CONV_KERN(1); | |||
| #define INSTANCE_CONV_KERN(stride, filter) INSTANCE_BIAS_MODE_PARAM(stride, filter) | |||
| } // namespace int8_direct_nchw_nchw44 | |||
| } // namespace arm_common | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,481 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h" | |||
| #include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h" | |||
| namespace megdnn { | |||
| namespace arm_common { | |||
| namespace { | |||
| /** | |||
| * @brief core code for calculation patten | |||
| * | |||
| * @tparam src_idx is offset of src reg | |||
| * @tparam weight_idx is offset of weight reg | |||
| * @tparam c_dim is output channel | |||
| * @tparam Func mla operation function | |||
| * @tparam stride | |||
| * @tparam T output regs type | |||
| * @tparam T2 src regs type | |||
| * @tparam T3 weight regs type | |||
| * @tparam T4 temp regs type | |||
| */ | |||
template <
        int src_idx, int weight_idx, int c_dim, int stride, typename T, typename T2,
        typename T3, typename T4>
struct ShiftCalHelper {
    // Accumulate one dot-product step into the output registers `c`.
    // The 4-arg form takes `temp` as scratch registers; the 3-arg form
    // is the variant without scratch.  Only specializations are defined.
    static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp);
    static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight);
};
template <
        int src_idx, int weight_idx, int c_dim, int stride, typename T, typename T2,
        typename T3, typename T4>
MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight, T4& temp) {
    // Thin forwarder to the ShiftCalHelper specialization selected by
    // <src_idx, weight_idx, c_dim, stride>; `temp` is passed through as
    // scratch registers.
    ShiftCalHelper<src_idx, weight_idx, c_dim, stride, T, T2, T3, T4>::impl(
            c, src, weight, temp);
}
| template < | |||
| int src_idx, int weight_idx, int c_dim, int stride, typename T, typename T2, | |||
| typename T3> | |||
| MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { | |||
| ShiftCalHelper<src_idx, weight_idx, c_dim, stride, T, T2, T3, int>::impl( | |||
| c, src, weight); | |||
| }; | |||
template <
        int src_idx, int weight_idx, typename T, typename T2, typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, 1, T, T2, T3, T4> {
    // c_dim == 2, stride == 1: accumulate 8 output pixels for two output
    // channel groups.  `src` is a ring of 8 preloaded registers indexed
    // modulo 8, so callers can rotate fresh data in without shifting;
    // temp[0..3] are int16 scratch registers reused round-robin.
    static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) {
        c[0][0] = vdotq_s32_h(
                src[(0 + src_idx) % 8], weight[0][weight_idx], c[0][0], temp[0]);
        c[1][0] = vdotq_s32_h(
                src[(0 + src_idx) % 8], weight[1][weight_idx], c[1][0], temp[1]);
        c[0][1] = vdotq_s32_h(
                src[(1 + src_idx) % 8], weight[0][weight_idx], c[0][1], temp[2]);
        c[1][1] = vdotq_s32_h(
                src[(1 + src_idx) % 8], weight[1][weight_idx], c[1][1], temp[3]);
        c[0][2] = vdotq_s32_h(
                src[(2 + src_idx) % 8], weight[0][weight_idx], c[0][2], temp[0]);
        c[1][2] = vdotq_s32_h(
                src[(2 + src_idx) % 8], weight[1][weight_idx], c[1][2], temp[1]);
        c[0][3] = vdotq_s32_h(
                src[(3 + src_idx) % 8], weight[0][weight_idx], c[0][3], temp[2]);
        c[1][3] = vdotq_s32_h(
                src[(3 + src_idx) % 8], weight[1][weight_idx], c[1][3], temp[3]);
        c[0][4] = vdotq_s32_h(
                src[(4 + src_idx) % 8], weight[0][weight_idx], c[0][4], temp[0]);
        c[1][4] = vdotq_s32_h(
                src[(4 + src_idx) % 8], weight[1][weight_idx], c[1][4], temp[1]);
        c[0][5] = vdotq_s32_h(
                src[(5 + src_idx) % 8], weight[0][weight_idx], c[0][5], temp[2]);
        c[1][5] = vdotq_s32_h(
                src[(5 + src_idx) % 8], weight[1][weight_idx], c[1][5], temp[3]);
        c[0][6] = vdotq_s32_h(
                src[(6 + src_idx) % 8], weight[0][weight_idx], c[0][6], temp[0]);
        c[1][6] = vdotq_s32_h(
                src[(6 + src_idx) % 8], weight[1][weight_idx], c[1][6], temp[1]);
        c[0][7] = vdotq_s32_h(
                src[(7 + src_idx) % 8], weight[0][weight_idx], c[0][7], temp[2]);
        c[1][7] = vdotq_s32_h(
                src[(7 + src_idx) % 8], weight[1][weight_idx], c[1][7], temp[3]);
    }
    // declared but not defined: the scratch-free form is not used here
    static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&);
};
template <
        int src_idx, int weight_idx, typename T, typename T2, typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, 1, T, T2, T3, T4> {
    // c_dim == 1, stride == 1: same 8-pixel accumulation as the c_dim == 2
    // specialization above, but for a single output channel group.
    // `src` indices wrap modulo 8; temp[0..3] are scratch, reused in turn.
    static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) {
        c[0][0] = vdotq_s32_h(
                src[(0 + src_idx) % 8], weight[0][weight_idx], c[0][0], temp[0]);
        c[0][1] = vdotq_s32_h(
                src[(1 + src_idx) % 8], weight[0][weight_idx], c[0][1], temp[1]);
        c[0][2] = vdotq_s32_h(
                src[(2 + src_idx) % 8], weight[0][weight_idx], c[0][2], temp[2]);
        c[0][3] = vdotq_s32_h(
                src[(3 + src_idx) % 8], weight[0][weight_idx], c[0][3], temp[3]);
        c[0][4] = vdotq_s32_h(
                src[(4 + src_idx) % 8], weight[0][weight_idx], c[0][4], temp[0]);
        c[0][5] = vdotq_s32_h(
                src[(5 + src_idx) % 8], weight[0][weight_idx], c[0][5], temp[1]);
        c[0][6] = vdotq_s32_h(
                src[(6 + src_idx) % 8], weight[0][weight_idx], c[0][6], temp[2]);
        c[0][7] = vdotq_s32_h(
                src[(7 + src_idx) % 8], weight[0][weight_idx], c[0][7], temp[3]);
    }
    // declared but not defined: the scratch-free form is not used here
    static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&);
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 1, oc_block, 1> {
    // 1x1 filter (width padded to 4 by the weight packing), stride 1:
    // one packed weight load + one packed source row load + one
    // dot-product step per input channel, producing 8 output pixels.
    static void impl(
            const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr,
            int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) {
        constexpr int stride = 1;
        constexpr int filter_height = 1;
        constexpr int filter_width = 4;
        constexpr int oc_step = 4;
        constexpr int loop_ic_step = 1;
        constexpr int simd_len = 16;
        constexpr int pack_iw_len = 16;  // packed source: 16 bytes per input pixel
        constexpr int src_reg = 8;
        constexpr int weight_reg = 1;
        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];  // accumulators: c_dim oc groups x 8 ow pixels
        init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
            int8x16_t src[src_reg];
            int8x16_t dot4_weight[c_dim][weight_reg];
            int16x8_t temp_c[4];
            load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
                    dot4_weight, weight_ptr, ld_weight_oc);
            load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
                    src, nchw_src_ptr + 0 * iw * pack_iw_len, 0);
            cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);
            weight_ptr += oc_step * filter_height * filter_width;
        }
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block> | |||
| struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 2, oc_block, 1> { | |||
| static void impl( | |||
| const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, | |||
| int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { | |||
| constexpr int stride = 1; | |||
| constexpr int filter_height = 2; | |||
| constexpr int filter_width = 4; | |||
| constexpr int oc_step = 4; | |||
| constexpr int loop_ic_step = 1; | |||
| constexpr int simd_len = 16; | |||
| constexpr int pack_iw_len = 16; | |||
| constexpr int src_reg = 8; | |||
| constexpr int weight_reg = 1; | |||
| const int ic_stride = ih * iw * pack_iw_len; | |||
| const int ld_weight_oc = oc_step * filter_height * filter_width * ic; | |||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||
| int32x4_t c[c_dim][8]; | |||
| init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||
| const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; | |||
| int8x16_t src[src_reg]; | |||
| int8x16_t dot4_weight[c_dim][weight_reg]; | |||
| int16x8_t temp_c[4]; | |||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( | |||
| dot4_weight, weight_ptr, ld_weight_oc); | |||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||
| src, nchw_src_ptr + 0 * iw * pack_iw_len, 0); | |||
| cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); | |||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( | |||
| dot4_weight, weight_ptr + 1 * filter_width * oc_step, ld_weight_oc); | |||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||
| src, nchw_src_ptr + 1 * iw * pack_iw_len, 0); | |||
| cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); | |||
| weight_ptr += oc_step * filter_height * filter_width; | |||
| } | |||
| store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||
| c, op, dst_ptr, ld_dst_oc); | |||
| } | |||
| }; | |||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block> | |||
| struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 3, oc_block, 1> { | |||
| static void impl( | |||
| const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, | |||
| int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { | |||
| constexpr int stride = 1; | |||
| constexpr int filter_height = 3; | |||
| constexpr int filter_width = 4; | |||
| constexpr int oc_step = 4; | |||
| constexpr int loop_ic_step = 1; | |||
| constexpr int simd_len = 16; | |||
| constexpr int pack_iw_len = 16; | |||
| constexpr int src_reg = 8; | |||
| constexpr int weight_reg = 1; | |||
| const int ic_stride = ih * iw * pack_iw_len; | |||
| const int ld_weight_oc = oc_step * filter_height * filter_width * ic; | |||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||
| int32x4_t c[c_dim][8]; | |||
| init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||
| const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; | |||
| int8x16_t src[src_reg]; | |||
| int8x16_t dot4_weight[c_dim][weight_reg]; | |||
| int16x8_t temp_c[4]; | |||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( | |||
| dot4_weight, weight_ptr, ld_weight_oc); | |||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||
| src, nchw_src_ptr + 0 * iw * pack_iw_len, 0); | |||
| cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); | |||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( | |||
| dot4_weight, weight_ptr + 1 * filter_width * oc_step, ld_weight_oc); | |||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||
| src, nchw_src_ptr + 1 * iw * pack_iw_len, 0); | |||
| cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); | |||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( | |||
| dot4_weight, weight_ptr + 2 * filter_width * oc_step, ld_weight_oc); | |||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||
| src, nchw_src_ptr + 2 * iw * pack_iw_len, 0); | |||
| cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); | |||
| weight_ptr += oc_step * filter_height * filter_width; | |||
| } | |||
| store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||
| c, op, dst_ptr, ld_dst_oc); | |||
| } | |||
| }; | |||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 5, oc_block, 1> {
    // 5x5 filter (width padded to 8), stride 1.  Per filter row: the first
    // cal_helper<0,0> covers the lower 4-wide weight chunk against 8
    // preloaded source pixels; the second load rotates 4 fresh pixels into
    // the 8-register source ring (indices wrap mod 8 inside cal_helper) and
    // cal_helper<4,1> applies the upper 4-wide weight chunk.
    static void impl(
            const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr,
            int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) {
        constexpr int stride = 1;
        constexpr int filter_height = 5;
        constexpr int filter_width = 8;
        constexpr int oc_step = 4;
        constexpr int loop_ic_step = 1;
        constexpr int simd_len = 16;
        constexpr int pack_iw_len = 16;  // packed source: 16 bytes per input pixel
        constexpr int src_reg = 8;
        constexpr int weight_reg = 2;
        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];  // accumulators: c_dim oc groups x 8 ow pixels
        init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
            int8x16_t src[src_reg];
            int8x16_t dot4_weight[c_dim][weight_reg];
            int16x8_t temp_c[4];
// one filter row per step (see the struct comment for the two-phase scheme)
#define cb(step)                                                                  \
    load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(                        \
            dot4_weight, weight_ptr + step * filter_width * oc_step, ld_weight_oc); \
    load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(                               \
            src, nchw_src_ptr + step * iw * pack_iw_len, 0);                      \
    cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);                 \
    load_helper<4, 0, simd_len, 0, Vld1q_s8>(                                     \
            src, nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, 0); \
    cal_helper<4, 1, c_dim, stride>(c, src, dot4_weight, temp_c);
            UNROLL_CALL_RAW(5, cb);
#undef cb
            weight_ptr += oc_step * filter_height * filter_width;
        }
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 7, oc_block, 1> {
    // 7x7 filter (width padded to 8), stride 1.  Same two-phase per-row
    // scheme as the 5x5 kernel above: cal_helper<0,0> applies the lower
    // 4-wide weight chunk, then 4 fresh pixels are rotated into the source
    // ring and cal_helper<4,1> applies the upper chunk.
    static void impl(
            const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr,
            int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) {
        constexpr int stride = 1;
        constexpr int filter_height = 7;
        constexpr int filter_width = 8;
        constexpr int oc_step = 4;
        constexpr int loop_ic_step = 1;
        constexpr int simd_len = 16;
        constexpr int pack_iw_len = 16;  // packed source: 16 bytes per input pixel
        constexpr int src_reg = 8;
        constexpr int weight_reg = 2;
        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];  // accumulators: c_dim oc groups x 8 ow pixels
        init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
            int8x16_t src[src_reg];
            int8x16_t dot4_weight[c_dim][weight_reg];
            int16x8_t temp_c[4];
// one filter row per step (see the struct comment for the two-phase scheme)
#define cb(step)                                                                  \
    load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(                        \
            dot4_weight, weight_ptr + step * filter_width * oc_step, ld_weight_oc); \
    load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(                               \
            src, nchw_src_ptr + step * iw * pack_iw_len, 0);                      \
    cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);                 \
    load_helper<4, 0, simd_len, 0, Vld1q_s8>(                                     \
            src, nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, 0); \
    cal_helper<4, 1, c_dim, stride>(c, src, dot4_weight, temp_c);
            UNROLL_CALL_RAW(7, cb);
#undef cb
            weight_ptr += oc_step * filter_height * filter_width;
        }
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};
| } // namespace | |||
| namespace int8_direct_nchw_nchw44 { | |||
| /** | |||
| * pack {oc / 4, fh, fw, ic, 4(oc)} to {oc / 4, ic, fh ,fw/4, 4(oc)*4(fw)} | |||
| * pack interleave two adjacent row in filter to one row | |||
| * */ | |||
| template <BiasMode bias_mode, typename Op, size_t filter_size> | |||
| struct ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, 1> { | |||
| static void impl( | |||
| const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, | |||
| int8_t* dst, const size_t oc, const size_t ic, const size_t ih, | |||
| const size_t iw, const size_t oh, const size_t ow, const Op& op) { | |||
| MEGDNN_MARK_USED_VAR(temp); | |||
| constexpr int stride = 1; | |||
| constexpr size_t fh = filter_size; | |||
| constexpr size_t fw = (filter_size + 3) / 4 * 4; | |||
| constexpr size_t ic_step = 1; | |||
| constexpr size_t big_oc_step = 8; | |||
| constexpr size_t oc_step = 4; | |||
| constexpr size_t ih_step = 1; | |||
| constexpr size_t oh_step = 1; | |||
| constexpr size_t ow_step = 8; | |||
| constexpr size_t stride_h = stride; | |||
| constexpr size_t stride_w = stride; | |||
| constexpr int pack_iw_len = 16; | |||
| const size_t img_stride = oh * ow; | |||
| const size_t ow_end = ow / ow_step * ow_step; | |||
| const size_t ow_remain = ow - ow_end; | |||
| const size_t oc_end = oc / big_oc_step * big_oc_step; | |||
| const size_t oc_remain = oc - oc_end; | |||
| const int ld_dst_oc = oc_step * img_stride; | |||
| using remain_fun = std::function<void( | |||
| const int8_t* src_ptr, const int8_t* weight_ptr, | |||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, int iw, | |||
| int ld_dst_oc, const Op& op)>; | |||
| remain_fun kern_big_oc_remain = nullptr; | |||
| remain_fun kern_small_oc_remain = nullptr; | |||
| switch (ow_remain) { | |||
| #define cb(step) \ | |||
| case step: \ | |||
| kern_big_oc_remain = KerNeonXXs2NchwNchw44< \ | |||
| bias_mode, Op, step, filter_size, big_oc_step, stride>::impl; \ | |||
| kern_small_oc_remain = KerNeonXXs2NchwNchw44< \ | |||
| bias_mode, Op, step, filter_size, oc_step, stride>::impl; \ | |||
| break; | |||
| UNROLL_CALL_RAW(8, cb); | |||
| default: | |||
| megdnn_assert(0, "no remain %zu for kern", ow_remain); | |||
| } | |||
| for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { | |||
| const size_t weight_offset = oc_idx * ic * fh * fw; | |||
| for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||
| for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||
| const size_t src_offset = | |||
| (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * | |||
| ic_step * pack_iw_len; | |||
| const size_t dst_offset = | |||
| oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||
| KerNeonXXs2NchwNchw44< | |||
| bias_mode, Op, ow_step, filter_size, big_oc_step, stride>:: | |||
| impl(src + src_offset, filter + weight_offset, | |||
| bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_dst_oc, | |||
| op); | |||
| } | |||
| if (ow_remain > 0) { | |||
| const size_t src_offset = | |||
| (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * | |||
| ic_step * pack_iw_len; | |||
| const size_t dst_offset = | |||
| oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||
| kern_big_oc_remain( | |||
| src + src_offset, filter + weight_offset, bias + oc_idx, | |||
| dst + dst_offset, ic, ih, iw, ld_dst_oc, op); | |||
| } | |||
| } | |||
| } | |||
| if (oc_remain > 0) { | |||
| size_t oc_idx = oc_end; | |||
| const size_t weight_offset = oc_idx * ic * fh * fw; | |||
| for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||
| for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||
| const size_t src_offset = | |||
| (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * | |||
| ic_step * pack_iw_len; | |||
| const size_t dst_offset = | |||
| oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||
| KerNeonXXs2NchwNchw44< | |||
| bias_mode, Op, ow_step, filter_size, oc_step, stride>:: | |||
| impl(src + src_offset, filter + weight_offset, | |||
| bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_dst_oc, | |||
| op); | |||
| } | |||
| if (ow_remain > 0) { | |||
| const size_t src_offset = | |||
| (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * | |||
| ic_step * pack_iw_len; | |||
| const size_t dst_offset = | |||
| oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||
| kern_small_oc_remain( | |||
| src + src_offset, filter + weight_offset, bias + oc_idx, | |||
| dst + dst_offset, ic, ih, iw, ld_dst_oc, op); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| }; | |||
// Explicitly instantiate the driver for one (stride, filter, bias, Op) combo.
#define INSTANCE_CONV_KERN_FUN(stride, filter_size, bias_mode, Op) \
    template struct megdnn::arm_common::int8_direct_nchw_nchw44::  \
            ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, stride>;
// One instantiation per supported activation op (identity / relu / hswish).
#define INSTANCE_OP_PARAM(stride, filter, bias_mode)                          \
    INSTANCE_CONV_KERN_FUN(                                                   \
            stride, filter, bias_mode, TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
    INSTANCE_CONV_KERN_FUN(                                                   \
            stride, filter, bias_mode, ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>)    \
    INSTANCE_CONV_KERN_FUN(                                                   \
            stride, filter, bias_mode, HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)
// One instantiation per supported bias mode.
#define INSTANCE_BIAS_MODE_PARAM(stride, filter)         \
    INSTANCE_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
    INSTANCE_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)
// Entry point used by the per-filter-size .cpp files; splitting the
// instantiations across TUs keeps per-TU compile memory down
// (presumably the motivation is parallel builds on small-RAM hosts — see
// commit message; TODO confirm).
#define INSTANCE_CONV_KERN(stride, filter) INSTANCE_BIAS_MODE_PARAM(stride, filter)
| } // namespace int8_direct_nchw_nchw44 | |||
| } // namespace arm_common | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,19 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_1x1.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.h" | |||
| using namespace megdnn; | |||
| using namespace arm_common; | |||
| INSTANCE_CONV_KERN(1, 1); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,19 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_2x2.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.h" | |||
| using namespace megdnn; | |||
| using namespace arm_common; | |||
| INSTANCE_CONV_KERN(1, 2); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,19 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_3x3.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.h" | |||
| using namespace megdnn; | |||
| using namespace arm_common; | |||
| INSTANCE_CONV_KERN(1, 3); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,19 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_5x5.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.h" | |||
| using namespace megdnn; | |||
| using namespace arm_common; | |||
| INSTANCE_CONV_KERN(1, 5); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,19 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_7x7.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.h" | |||
| using namespace megdnn; | |||
| using namespace arm_common; | |||
| INSTANCE_CONV_KERN(1, 7); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -1,6 +1,6 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_s1.cpp | |||
| * dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_2x2s1.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| @@ -12,8 +12,5 @@ | |||
| */ | |||
| #include "src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h" | |||
| INSTANCE_CONV(2, 1); | |||
| INSTANCE_CONV(3, 1); | |||
| INSTANCE_CONV(5, 1); | |||
| INSTANCE_CONV(7, 1); | |||
| // vim: syntax=cpp.doxygen | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -1,6 +1,6 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_s2.cpp | |||
| * dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_2x2s2.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| @@ -12,8 +12,5 @@ | |||
| */ | |||
| #include "src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h" | |||
| INSTANCE_CONV(2, 2); | |||
| INSTANCE_CONV(3, 2); | |||
| INSTANCE_CONV(5, 2); | |||
| INSTANCE_CONV(7, 2); | |||
| // vim: syntax=cpp.doxygen | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,16 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_3x3s1.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h" | |||
| INSTANCE_CONV(3, 1); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,16 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_3x3s2.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h" | |||
| INSTANCE_CONV(3, 2); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -1,6 +1,6 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1.cpp | |||
| * dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_5x5s1.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| @@ -10,5 +10,7 @@ | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
| #include "src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h" | |||
| INSTANCE_CONV(5, 1); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -1,6 +1,6 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2.cpp | |||
| * dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_5x5s2.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| @@ -10,5 +10,7 @@ | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
| #include "src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h" | |||
| INSTANCE_CONV(5, 2); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -1,6 +1,6 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1.cpp | |||
| * dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_7x7s1.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| @@ -10,5 +10,7 @@ | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
| #include "src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h" | |||
| INSTANCE_CONV(7, 1); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -1,6 +1,6 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2.cpp | |||
| * dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_7x7s2.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| @@ -10,5 +10,7 @@ | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
| #include "src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h" | |||
| INSTANCE_CONV(7, 2); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,275 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise/opr_binary_impl.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "./opr_impl.h" | |||
| #include "src/common/elemwise/kern_defs.cuh" | |||
| #include "src/common/utils.h" | |||
| #include "src/naive/handle.h" | |||
| #include "midout.h" | |||
| MIDOUT_DECL(megdnn_fallback_elemwise_binary) | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| template <typename dtype, uint32_t mode> | |||
| void ElemwiseImpl::binary_kern(const ElemwiseOpParamN<2>& param) { | |||
| using ctype = typename DTypeTrait<dtype>::ctype; | |||
| using Kern = ElemwiseKern<megcorePlatformCPU, mode, ctype>; | |||
| MIDOUT_BEGIN(megdnn_fallback_elemwise_binary, ctype, midout_iv(mode)) { | |||
| if (param.max_ndim == 1) { | |||
| MIDOUT_BEGIN( | |||
| megdnn_fallback_elemwise_binary, ctype, midout_iv(mode), | |||
| midout_iv(1)) { | |||
| auto tot = param.size; | |||
| auto as = param[0].layout.stride[0], bs = param[1].layout.stride[0]; | |||
| auto src0 = param[0]; | |||
| auto src1 = param[1]; | |||
| auto dst_tensor = *m_dst; | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR({ | |||
| ctype* __restrict a = static_cast<ctype*>(src0.raw_ptr()); | |||
| ctype* __restrict b = static_cast<ctype*>(src1.raw_ptr()); | |||
| ctype* __restrict dst = static_cast<ctype*>(dst_tensor.raw_ptr()); | |||
| for (size_t i = 0; i < tot; ++i) { | |||
| dst[i] = Kern::apply(a[i * as], b[i * bs]); | |||
| } | |||
| }); | |||
| return; | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| if (std::min(param[0].layout.ndim, param[1].layout.ndim) > 1) { | |||
| return naive::ElemwiseForwardImpl::exec(*m_src, *m_dst); | |||
| } | |||
| if (param.max_ndim == 2) { | |||
| if (param[0].layout.ndim == 1) { | |||
| MIDOUT_BEGIN( | |||
| megdnn_fallback_elemwise_binary, ctype, midout_iv(mode), | |||
| midout_iv(21)) { | |||
| auto as = param[0].layout.stride[0], | |||
| bs0 = param[1].layout.stride[0], | |||
| bs1 = param[1].layout.stride[1]; | |||
| auto n0 = param[1].layout.shape[0], n1 = param[1].layout.shape[1]; | |||
| auto src0 = param[0]; | |||
| auto src1 = param[1]; | |||
| auto dst_tensor = *m_dst; | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR({ | |||
| ctype* __restrict a = static_cast<ctype*>(src0.raw_ptr()); | |||
| ctype* __restrict b = static_cast<ctype*>(src1.raw_ptr()); | |||
| ctype* __restrict dst = | |||
| static_cast<ctype*>(dst_tensor.raw_ptr()); | |||
| ptrdiff_t toff = 0; | |||
| for (size_t i = 0; i < n0; ++i) { | |||
| for (size_t j = 0; j < n1; ++j) { | |||
| dst[toff] = | |||
| Kern::apply(a[as * toff], b[bs0 * i + bs1 * j]); | |||
| ++toff; | |||
| } | |||
| } | |||
| }); | |||
| return; | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| MIDOUT_BEGIN( | |||
| megdnn_fallback_elemwise_binary, ctype, midout_iv(mode), | |||
| midout_iv(22)) { | |||
| megdnn_assert(param[1].layout.ndim == 1); | |||
| auto bs = param[1].layout.stride[0], as0 = param[0].layout.stride[0], | |||
| as1 = param[0].layout.stride[1]; | |||
| auto n0 = param[0].layout.shape[0], n1 = param[0].layout.shape[1]; | |||
| auto src0 = param[0]; | |||
| auto src1 = param[1]; | |||
| auto dst_tensor = *m_dst; | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR({ | |||
| ctype* __restrict a = static_cast<ctype*>(src0.raw_ptr()); | |||
| ctype* __restrict b = static_cast<ctype*>(src1.raw_ptr()); | |||
| ctype* __restrict dst = static_cast<ctype*>(dst_tensor.raw_ptr()); | |||
| ptrdiff_t toff = 0; | |||
| for (size_t i = 0; i < n0; ++i) { | |||
| for (size_t j = 0; j < n1; ++j) { | |||
| dst[toff] = Kern::apply(a[as0 * i + as1 * j], b[toff * bs]); | |||
| ++toff; | |||
| } | |||
| } | |||
| }); | |||
| return; | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| if (param.max_ndim == 3) { | |||
| auto brd_101 = [](const TensorND& t) { | |||
| auto&& l = t.layout; | |||
| return l.ndim == 3 && l.stride[0] == 0 && l.stride[2] == 0; | |||
| }; | |||
| if (param[0].layout.ndim == 1 && brd_101(param[1])) { | |||
| MIDOUT_BEGIN( | |||
| megdnn_fallback_elemwise_binary, ctype, midout_iv(mode), | |||
| midout_iv(31)) { | |||
| auto as = param[0].layout.stride[0], bs = param[1].layout.stride[1]; | |||
| auto n0 = param[1].layout.shape[0], n1 = param[1].layout.shape[1], | |||
| n2 = param[1].layout.shape[2]; | |||
| auto src0 = param[0]; | |||
| auto src1 = param[1]; | |||
| auto dst_tensor = *m_dst; | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR({ | |||
| ctype* __restrict a = static_cast<ctype*>(src0.raw_ptr()); | |||
| ctype* __restrict b = static_cast<ctype*>(src1.raw_ptr()); | |||
| ctype* __restrict dst = | |||
| static_cast<ctype*>(dst_tensor.raw_ptr()); | |||
| size_t toff = 0; | |||
| for (size_t i = 0; i < n0; ++i) { | |||
| for (size_t j = 0; j < n1; ++j) { | |||
| for (size_t k = 0; k < n2; ++k) { | |||
| dst[toff] = Kern::apply(a[as * toff], b[bs * j]); | |||
| ++toff; | |||
| } | |||
| } | |||
| } | |||
| }); | |||
| return; | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| if (param[1].layout.ndim == 1 && brd_101(param[0])) { | |||
| MIDOUT_BEGIN( | |||
| megdnn_fallback_elemwise_binary, ctype, midout_iv(mode), | |||
| midout_iv(32)) { | |||
| auto as = param[0].layout.stride[1], bs = param[1].layout.stride[0]; | |||
| auto n0 = param[0].layout.shape[0], n1 = param[0].layout.shape[1], | |||
| n2 = param[0].layout.shape[2]; | |||
| auto src0 = param[0]; | |||
| auto src1 = param[1]; | |||
| auto dst_tensor = *m_dst; | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR({ | |||
| ctype* __restrict a = static_cast<ctype*>(src0.raw_ptr()); | |||
| ctype* __restrict b = static_cast<ctype*>(src1.raw_ptr()); | |||
| ctype* __restrict dst = | |||
| static_cast<ctype*>(dst_tensor.raw_ptr()); | |||
| size_t toff = 0; | |||
| for (size_t i = 0; i < n0; ++i) { | |||
| for (size_t j = 0; j < n1; ++j) { | |||
| for (size_t k = 0; k < n2; ++k) { | |||
| dst[toff] = Kern::apply(a[as * j], b[bs * toff]); | |||
| ++toff; | |||
| } | |||
| } | |||
| } | |||
| }); | |||
| return; | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| } | |||
| naive::ElemwiseForwardImpl::exec(*m_src, *m_dst); | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| #define SWITCH_DTYPE(_cat, _cb) \ | |||
| switch (m_dst->layout.dtype.enumv()) { \ | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE_##_cat(_cb) default \ | |||
| : megdnn_throw("bad dtype"); \ | |||
| } | |||
| template <uint32_t mode> | |||
| void ElemwiseImpl::exec_BINARY_INT() { | |||
| auto param = make_elemwise_op_param<2>(); | |||
| #define cb(_dt) \ | |||
| case DTypeTrait<_dt>::enumv: \ | |||
| return binary_kern<_dt, mode>(param); | |||
| SWITCH_DTYPE(INT, cb) | |||
| #undef cb | |||
| } | |||
| template <uint32_t mode> | |||
| void ElemwiseImpl::exec_BINARY_FLOAT() { | |||
| auto param = make_elemwise_op_param<2>(); | |||
| #define cb(_dt) \ | |||
| case DTypeTrait<_dt>::enumv: \ | |||
| return binary_kern<_dt, mode>(param); | |||
| SWITCH_DTYPE(FLOAT, cb) | |||
| #undef cb | |||
| } | |||
| #undef SWITCH_DTYPE | |||
| #undef SWITCH_DTYPE | |||
| using Mode = param_enumv::Elemwise::Mode; | |||
| #define INST(mode) template void megdnn::fallback::ElemwiseImpl::exec_BINARY_INT<mode>() | |||
| INST(Mode::ABS_GRAD); | |||
| INST(Mode::ADD); | |||
| INST(Mode::FLOOR_DIV); | |||
| INST(Mode::MAX); | |||
| INST(Mode::MIN); | |||
| INST(Mode::MOD); | |||
| INST(Mode::MUL); | |||
| INST(Mode::SIGMOID_GRAD); | |||
| INST(Mode::SUB); | |||
| INST(Mode::SWITCH_GT0); | |||
| INST(Mode::TANH_GRAD); | |||
| INST(Mode::LT); | |||
| INST(Mode::LEQ); | |||
| INST(Mode::EQ); | |||
| INST(Mode::SHL); | |||
| INST(Mode::SHR); | |||
| INST(Mode::FUSE_ADD_RELU); | |||
| INST(Mode::RMULH); | |||
| #undef INST | |||
| #define INST(mode) \ | |||
| template void megdnn::fallback::ElemwiseImpl::exec_BINARY_FLOAT<mode>() | |||
| INST(Mode::ABS_GRAD); | |||
| INST(Mode::ADD); | |||
| INST(Mode::FLOOR_DIV); | |||
| INST(Mode::MAX); | |||
| INST(Mode::MIN); | |||
| INST(Mode::MOD); | |||
| INST(Mode::MUL); | |||
| INST(Mode::POW); | |||
| INST(Mode::SIGMOID_GRAD); | |||
| INST(Mode::SUB); | |||
| INST(Mode::SWITCH_GT0); | |||
| INST(Mode::TANH_GRAD); | |||
| INST(Mode::TRUE_DIV); | |||
| INST(Mode::LOG_SUM_EXP); | |||
| INST(Mode::LT); | |||
| INST(Mode::LEQ); | |||
| INST(Mode::EQ); | |||
| INST(Mode::FUSE_ADD_RELU); | |||
| INST(Mode::FUSE_ADD_SIGMOID); | |||
| INST(Mode::FUSE_ADD_TANH); | |||
| INST(Mode::FAST_TANH_GRAD); | |||
| INST(Mode::ATAN2); | |||
| INST(Mode::H_SWISH_GRAD); | |||
| INST(Mode::FUSE_ADD_H_SWISH); | |||
| INST(Mode::SILU_GRAD); | |||
| INST(Mode::GELU_GRAD); | |||
| #undef INST | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -16,8 +16,6 @@ | |||
| #include "midout.h" | |||
| MIDOUT_DECL(megdnn_fallback_elemwise_unary) | |||
| MIDOUT_DECL(megdnn_fallback_elemwise_binary) | |||
| MIDOUT_DECL(megdnn_fallback_elemwise_exec_UNARY_INT) | |||
| MIDOUT_DECL(megdnn_fallback_elemwise_exec_UNARY_FLOAT) | |||
| MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_INT) | |||
| @@ -26,200 +24,6 @@ MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_FLOAT) | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| template <typename dtype, uint32_t mode> | |||
| void ElemwiseImpl::unary_kern(const ElemwiseOpParamN<1>& param) { | |||
| using ctype = typename DTypeTrait<dtype>::ctype; | |||
| using Kern = ElemwiseKern<megcorePlatformCPU, mode, ctype>; | |||
| MIDOUT_BEGIN(megdnn_fallback_elemwise_unary, ctype, midout_iv(mode)) { | |||
| // only specialize for the most common 1-dim case | |||
| auto tot = param.size; | |||
| auto stride = param[0].layout.stride[0]; | |||
| auto src0 = param[0]; | |||
| auto dst_tensor = *m_dst; | |||
| if (param.max_ndim == 1) { | |||
| MIDOUT_BEGIN( | |||
| megdnn_fallback_elemwise_unary, ctype, midout_iv(mode), | |||
| midout_iv(1)) { | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR({ | |||
| ctype* __restrict src = static_cast<ctype*>(src0.raw_ptr()); | |||
| ctype* __restrict dst = static_cast<ctype*>(dst_tensor.raw_ptr()); | |||
| for (size_t i = 0; i < tot; ++i) { | |||
| dst[i] = Kern::apply(src[i * stride]); | |||
| } | |||
| }); | |||
| return; | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| naive::ElemwiseForwardImpl::exec(*m_src, *m_dst); | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| template <typename dtype, uint32_t mode> | |||
| void ElemwiseImpl::binary_kern(const ElemwiseOpParamN<2>& param) { | |||
| using ctype = typename DTypeTrait<dtype>::ctype; | |||
| using Kern = ElemwiseKern<megcorePlatformCPU, mode, ctype>; | |||
| MIDOUT_BEGIN(megdnn_fallback_elemwise_binary, ctype, midout_iv(mode)) { | |||
| if (param.max_ndim == 1) { | |||
| MIDOUT_BEGIN( | |||
| megdnn_fallback_elemwise_binary, ctype, midout_iv(mode), | |||
| midout_iv(1)) { | |||
| auto tot = param.size; | |||
| auto as = param[0].layout.stride[0], bs = param[1].layout.stride[0]; | |||
| auto src0 = param[0]; | |||
| auto src1 = param[1]; | |||
| auto dst_tensor = *m_dst; | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR({ | |||
| ctype* __restrict a = static_cast<ctype*>(src0.raw_ptr()); | |||
| ctype* __restrict b = static_cast<ctype*>(src1.raw_ptr()); | |||
| ctype* __restrict dst = static_cast<ctype*>(dst_tensor.raw_ptr()); | |||
| for (size_t i = 0; i < tot; ++i) { | |||
| dst[i] = Kern::apply(a[i * as], b[i * bs]); | |||
| } | |||
| }); | |||
| return; | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| if (std::min(param[0].layout.ndim, param[1].layout.ndim) > 1) { | |||
| return naive::ElemwiseForwardImpl::exec(*m_src, *m_dst); | |||
| } | |||
| if (param.max_ndim == 2) { | |||
| if (param[0].layout.ndim == 1) { | |||
| MIDOUT_BEGIN( | |||
| megdnn_fallback_elemwise_binary, ctype, midout_iv(mode), | |||
| midout_iv(21)) { | |||
| auto as = param[0].layout.stride[0], | |||
| bs0 = param[1].layout.stride[0], | |||
| bs1 = param[1].layout.stride[1]; | |||
| auto n0 = param[1].layout.shape[0], n1 = param[1].layout.shape[1]; | |||
| auto src0 = param[0]; | |||
| auto src1 = param[1]; | |||
| auto dst_tensor = *m_dst; | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR({ | |||
| ctype* __restrict a = static_cast<ctype*>(src0.raw_ptr()); | |||
| ctype* __restrict b = static_cast<ctype*>(src1.raw_ptr()); | |||
| ctype* __restrict dst = | |||
| static_cast<ctype*>(dst_tensor.raw_ptr()); | |||
| ptrdiff_t toff = 0; | |||
| for (size_t i = 0; i < n0; ++i) { | |||
| for (size_t j = 0; j < n1; ++j) { | |||
| dst[toff] = | |||
| Kern::apply(a[as * toff], b[bs0 * i + bs1 * j]); | |||
| ++toff; | |||
| } | |||
| } | |||
| }); | |||
| return; | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| MIDOUT_BEGIN( | |||
| megdnn_fallback_elemwise_binary, ctype, midout_iv(mode), | |||
| midout_iv(22)) { | |||
| megdnn_assert(param[1].layout.ndim == 1); | |||
| auto bs = param[1].layout.stride[0], as0 = param[0].layout.stride[0], | |||
| as1 = param[0].layout.stride[1]; | |||
| auto n0 = param[0].layout.shape[0], n1 = param[0].layout.shape[1]; | |||
| auto src0 = param[0]; | |||
| auto src1 = param[1]; | |||
| auto dst_tensor = *m_dst; | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR({ | |||
| ctype* __restrict a = static_cast<ctype*>(src0.raw_ptr()); | |||
| ctype* __restrict b = static_cast<ctype*>(src1.raw_ptr()); | |||
| ctype* __restrict dst = static_cast<ctype*>(dst_tensor.raw_ptr()); | |||
| ptrdiff_t toff = 0; | |||
| for (size_t i = 0; i < n0; ++i) { | |||
| for (size_t j = 0; j < n1; ++j) { | |||
| dst[toff] = Kern::apply(a[as0 * i + as1 * j], b[toff * bs]); | |||
| ++toff; | |||
| } | |||
| } | |||
| }); | |||
| return; | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| if (param.max_ndim == 3) { | |||
| auto brd_101 = [](const TensorND& t) { | |||
| auto&& l = t.layout; | |||
| return l.ndim == 3 && l.stride[0] == 0 && l.stride[2] == 0; | |||
| }; | |||
| if (param[0].layout.ndim == 1 && brd_101(param[1])) { | |||
| MIDOUT_BEGIN( | |||
| megdnn_fallback_elemwise_binary, ctype, midout_iv(mode), | |||
| midout_iv(31)) { | |||
| auto as = param[0].layout.stride[0], bs = param[1].layout.stride[1]; | |||
| auto n0 = param[1].layout.shape[0], n1 = param[1].layout.shape[1], | |||
| n2 = param[1].layout.shape[2]; | |||
| auto src0 = param[0]; | |||
| auto src1 = param[1]; | |||
| auto dst_tensor = *m_dst; | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR({ | |||
| ctype* __restrict a = static_cast<ctype*>(src0.raw_ptr()); | |||
| ctype* __restrict b = static_cast<ctype*>(src1.raw_ptr()); | |||
| ctype* __restrict dst = | |||
| static_cast<ctype*>(dst_tensor.raw_ptr()); | |||
| size_t toff = 0; | |||
| for (size_t i = 0; i < n0; ++i) { | |||
| for (size_t j = 0; j < n1; ++j) { | |||
| for (size_t k = 0; k < n2; ++k) { | |||
| dst[toff] = Kern::apply(a[as * toff], b[bs * j]); | |||
| ++toff; | |||
| } | |||
| } | |||
| } | |||
| }); | |||
| return; | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| if (param[1].layout.ndim == 1 && brd_101(param[0])) { | |||
| MIDOUT_BEGIN( | |||
| megdnn_fallback_elemwise_binary, ctype, midout_iv(mode), | |||
| midout_iv(32)) { | |||
| auto as = param[0].layout.stride[1], bs = param[1].layout.stride[0]; | |||
| auto n0 = param[0].layout.shape[0], n1 = param[0].layout.shape[1], | |||
| n2 = param[0].layout.shape[2]; | |||
| auto src0 = param[0]; | |||
| auto src1 = param[1]; | |||
| auto dst_tensor = *m_dst; | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR({ | |||
| ctype* __restrict a = static_cast<ctype*>(src0.raw_ptr()); | |||
| ctype* __restrict b = static_cast<ctype*>(src1.raw_ptr()); | |||
| ctype* __restrict dst = | |||
| static_cast<ctype*>(dst_tensor.raw_ptr()); | |||
| size_t toff = 0; | |||
| for (size_t i = 0; i < n0; ++i) { | |||
| for (size_t j = 0; j < n1; ++j) { | |||
| for (size_t k = 0; k < n2; ++k) { | |||
| dst[toff] = Kern::apply(a[as * j], b[bs * toff]); | |||
| ++toff; | |||
| } | |||
| } | |||
| } | |||
| }); | |||
| return; | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| } | |||
| naive::ElemwiseForwardImpl::exec(*m_src, *m_dst); | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) { | |||
| if (!dst.layout.is_contiguous()) { | |||
| return naive::ElemwiseForwardImpl::exec(srcs, dst); | |||
| @@ -278,62 +82,6 @@ void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) { | |||
| naive::ElemwiseForwardImpl::exec(srcs, dst); | |||
| } | |||
| #define SWITCH_DTYPE(_cat, _cb) \ | |||
| switch (m_dst->layout.dtype.enumv()) { \ | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE_##_cat(_cb) default \ | |||
| : megdnn_throw("bad dtype"); \ | |||
| } | |||
| template <uint32_t mode> | |||
| void ElemwiseImpl::exec_UNARY_INT() { | |||
| auto param = make_elemwise_op_param<1>(); | |||
| #define cb(_dt) \ | |||
| case DTypeTrait<_dt>::enumv: \ | |||
| return unary_kern<_dt, mode>(param); | |||
| SWITCH_DTYPE(INT, cb) | |||
| #undef cb | |||
| } | |||
| template <uint32_t mode> | |||
| void ElemwiseImpl::exec_UNARY_FLOAT() { | |||
| auto param = make_elemwise_op_param<1>(); | |||
| #define cb(_dt) \ | |||
| case DTypeTrait<_dt>::enumv: \ | |||
| return unary_kern<_dt, mode>(param); | |||
| SWITCH_DTYPE(FLOAT, cb) | |||
| #undef cb | |||
| } | |||
| template <uint32_t mode> | |||
| void ElemwiseImpl::exec_BINARY_INT() { | |||
| auto param = make_elemwise_op_param<2>(); | |||
| #define cb(_dt) \ | |||
| case DTypeTrait<_dt>::enumv: \ | |||
| return binary_kern<_dt, mode>(param); | |||
| SWITCH_DTYPE(INT, cb) | |||
| #undef cb | |||
| } | |||
| template <uint32_t mode> | |||
| void ElemwiseImpl::exec_BINARY_FLOAT() { | |||
| auto param = make_elemwise_op_param<2>(); | |||
| #define cb(_dt) \ | |||
| case DTypeTrait<_dt>::enumv: \ | |||
| return binary_kern<_dt, mode>(param); | |||
| SWITCH_DTYPE(FLOAT, cb) | |||
| #undef cb | |||
| } | |||
| #undef SWITCH_DTYPE | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| @@ -0,0 +1,122 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise/opr_unary_impl.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "./opr_impl.h" | |||
| #include "src/common/elemwise/kern_defs.cuh" | |||
| #include "src/common/utils.h" | |||
| #include "src/naive/handle.h" | |||
| #include "midout.h" | |||
| MIDOUT_DECL(megdnn_fallback_elemwise_unary) | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| template <typename dtype, uint32_t mode> | |||
| void ElemwiseImpl::unary_kern(const ElemwiseOpParamN<1>& param) { | |||
| using ctype = typename DTypeTrait<dtype>::ctype; | |||
| using Kern = ElemwiseKern<megcorePlatformCPU, mode, ctype>; | |||
| MIDOUT_BEGIN(megdnn_fallback_elemwise_unary, ctype, midout_iv(mode)) { | |||
| // only specialize for the most common 1-dim case | |||
| auto tot = param.size; | |||
| auto stride = param[0].layout.stride[0]; | |||
| auto src0 = param[0]; | |||
| auto dst_tensor = *m_dst; | |||
| if (param.max_ndim == 1) { | |||
| MIDOUT_BEGIN( | |||
| megdnn_fallback_elemwise_unary, ctype, midout_iv(mode), | |||
| midout_iv(1)) { | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR({ | |||
| ctype* __restrict src = static_cast<ctype*>(src0.raw_ptr()); | |||
| ctype* __restrict dst = static_cast<ctype*>(dst_tensor.raw_ptr()); | |||
| for (size_t i = 0; i < tot; ++i) { | |||
| dst[i] = Kern::apply(src[i * stride]); | |||
| } | |||
| }); | |||
| return; | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| naive::ElemwiseForwardImpl::exec(*m_src, *m_dst); | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| #define SWITCH_DTYPE(_cat, _cb) \ | |||
| switch (m_dst->layout.dtype.enumv()) { \ | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE_##_cat(_cb) default \ | |||
| : megdnn_throw("bad dtype"); \ | |||
| } | |||
| template <uint32_t mode> | |||
| void ElemwiseImpl::exec_UNARY_INT() { | |||
| auto param = make_elemwise_op_param<1>(); | |||
| #define cb(_dt) \ | |||
| case DTypeTrait<_dt>::enumv: \ | |||
| return unary_kern<_dt, mode>(param); | |||
| SWITCH_DTYPE(INT, cb) | |||
| #undef cb | |||
| } | |||
| template <uint32_t mode> | |||
| void ElemwiseImpl::exec_UNARY_FLOAT() { | |||
| auto param = make_elemwise_op_param<1>(); | |||
| #define cb(_dt) \ | |||
| case DTypeTrait<_dt>::enumv: \ | |||
| return unary_kern<_dt, mode>(param); | |||
| SWITCH_DTYPE(FLOAT, cb) | |||
| #undef cb | |||
| } | |||
| #undef SWITCH_DTYPE | |||
| using Mode = param_enumv::Elemwise::Mode; | |||
| #define INST(mode) template void megdnn::fallback::ElemwiseImpl::exec_UNARY_INT<mode>(); | |||
| INST(Mode::RELU); | |||
| INST(Mode::ABS); | |||
| INST(Mode::NEGATE); | |||
| #undef INST | |||
| #define INST(mode) \ | |||
| template void megdnn::fallback::ElemwiseImpl::exec_UNARY_FLOAT<mode>(); | |||
| INST(Mode::RELU); | |||
| INST(Mode::ABS); | |||
| INST(Mode::ACOS); | |||
| INST(Mode::ASIN); | |||
| INST(Mode::CEIL); | |||
| INST(Mode::COS); | |||
| INST(Mode::EXP); | |||
| INST(Mode::EXPM1); | |||
| INST(Mode::FLOOR); | |||
| INST(Mode::LOG); | |||
| INST(Mode::LOG1P); | |||
| INST(Mode::NEGATE); | |||
| INST(Mode::SIGMOID); | |||
| INST(Mode::SIN); | |||
| INST(Mode::TANH); | |||
| INST(Mode::FAST_TANH); | |||
| INST(Mode::ROUND); | |||
| INST(Mode::ERF); | |||
| INST(Mode::ERFINV); | |||
| INST(Mode::ERFC); | |||
| INST(Mode::ERFCINV); | |||
| INST(Mode::H_SWISH); | |||
| INST(Mode::SILU); | |||
| INST(Mode::GELU); | |||
| #undef INST | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,138 @@ | |||
| /** | |||
| * \file dnn/src/naive/elemwise_multi_type/opr_impl_1.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "./opr_impl.h" | |||
| #include "megdnn/tensor_iter.h" | |||
| #include "src/common/elemwise/kern_defs.cuh" | |||
| #include "src/common/elemwise_multi_type/kern_defs.cuh" | |||
| #include "src/naive/handle.h" | |||
| using namespace megdnn; | |||
| using namespace naive; | |||
| void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16x32x32x32( | |||
| const ElemwiseOpParamN<3>& param, const TensorND& dst) { | |||
| auto size = param.size; | |||
| auto src0 = param[0]; | |||
| auto src1 = param[1]; | |||
| auto src2 = param[2]; | |||
| auto work = [src0, src1, src2, size, dst]() { | |||
| auto i0 = tensor_iter_valonly<dt_int16>(src0).begin(); | |||
| auto i1 = tensor_iter_valonly<dt_int32>(src1).begin(); | |||
| auto i2 = tensor_iter_valonly<dt_int32>(src2).begin(); | |||
| auto dst_ptr = dst.ptr<dt_int32>(); | |||
| for (size_t i = 0; i < size; ++i) { | |||
| dst_ptr[i] = (*i0) * (*i1) + (*i2); | |||
| ++i0; | |||
| ++i1; | |||
| ++i2; | |||
| } | |||
| }; | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR(work()); | |||
| } | |||
| void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16xf32xf32xf32( | |||
| const ElemwiseOpParamN<3>& param, const TensorND& dst) { | |||
| auto size = param.size; | |||
| auto src0 = param[0]; | |||
| auto src1 = param[1]; | |||
| auto src2 = param[2]; | |||
| auto work = [src0, src1, src2, size, dst]() { | |||
| auto i0 = tensor_iter_valonly<dt_int16>(src0).begin(); | |||
| auto i1 = tensor_iter_valonly<dt_float32>(src1).begin(); | |||
| auto i2 = tensor_iter_valonly<dt_float32>(src2).begin(); | |||
| auto dst_ptr = dst.ptr<dt_float32>(); | |||
| for (size_t i = 0; i < size; ++i) { | |||
| dst_ptr[i] = (*i0) * (*i1) + (*i2); | |||
| ++i0; | |||
| ++i1; | |||
| ++i2; | |||
| } | |||
| }; | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR(work()); | |||
| } | |||
| void ElemwiseMultiTypeImpl::on_fuse_mul_add3_uint8xf32xf32xf32( | |||
| const ElemwiseOpParamN<3>& param, const TensorND& dst) { | |||
| auto size = param.size; | |||
| auto src0 = param[0]; | |||
| auto src1 = param[1]; | |||
| auto src2 = param[2]; | |||
| auto work = [src0, src1, src2, size, dst]() { | |||
| auto i0 = tensor_iter_valonly<dt_uint8>(src0).begin(); | |||
| auto i1 = tensor_iter_valonly<dt_float32>(src1).begin(); | |||
| auto i2 = tensor_iter_valonly<dt_float32>(src2).begin(); | |||
| auto dst_ptr = dst.ptr<dt_float32>(); | |||
| for (size_t i = 0; i < size; ++i) { | |||
| dst_ptr[i] = (*i0) * (*i1) + (*i2); | |||
| ++i0; | |||
| ++i1; | |||
| ++i2; | |||
| } | |||
| }; | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR(work()); | |||
| } | |||
| void ElemwiseMultiTypeImpl::on_mul_int16xf32xf32( | |||
| const ElemwiseOpParamN<2>& param, const TensorND& dst) { | |||
| auto size = param.size; | |||
| auto src0 = param[0]; | |||
| auto src1 = param[1]; | |||
| auto work = [src0, src1, size, dst]() { | |||
| auto i0 = tensor_iter_valonly<dt_int16>(src0).begin(); | |||
| auto i1 = tensor_iter_valonly<dt_float32>(src1).begin(); | |||
| auto dst_ptr = dst.ptr<dt_float32>(); | |||
| for (size_t i = 0; i < size; ++i) { | |||
| dst_ptr[i] = (*i0) * (*i1); | |||
| ++i0; | |||
| ++i1; | |||
| } | |||
| }; | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR(work()); | |||
| } | |||
| void ElemwiseMultiTypeImpl::on_fuse_mul_add3_iXxf32xf32xi8( | |||
| const ElemwiseOpParamN<3>& param, const TensorND& dst) { | |||
| switch (param[0].layout.dtype.enumv()) { | |||
| #define cb(t) \ | |||
| case DTypeTrait<t>::enumv: \ | |||
| return dispatch_fma3_iXxf32xf32xi8<DTypeTrait<t>::ctype>(param, dst); | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb) | |||
| #undef cb | |||
| default: | |||
| megdnn_throw("unsupported src dtype"); | |||
| } | |||
| } | |||
| template <typename ctype> | |||
| void ElemwiseMultiTypeImpl::dispatch_fma3_iXxf32xf32xi8( | |||
| const ElemwiseOpParamN<3>& param, const TensorND& dst) { | |||
| auto size = param.size; | |||
| auto src0 = param[0]; | |||
| auto src1 = param[1]; | |||
| auto src2 = param[2]; | |||
| auto work = [src0, src1, src2, size, dst]() { | |||
| elemwise_multi_type::Fma3iXxf32xf32xiYOp<ctype, dt_int8> op; | |||
| auto i0 = tensor_iter_valonly<ctype>(src0).begin(); | |||
| auto i1 = tensor_iter_valonly<dt_float32>(src1).begin(); | |||
| auto i2 = tensor_iter_valonly<dt_float32>(src2).begin(); | |||
| auto dst_ptr = dst.ptr<dt_int8>(); | |||
| for (size_t i = 0; i < size; ++i) { | |||
| dst_ptr[i] = op(*i0, *i1, *i2); | |||
| ++i0; | |||
| ++i1; | |||
| ++i2; | |||
| } | |||
| }; | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR(work()); | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,115 @@ | |||
| /** | |||
| * \file dnn/src/naive/elemwise_multi_type/opr_impl_2.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "./opr_impl.h" | |||
| #include "megdnn/tensor_iter.h" | |||
| #include "src/common/elemwise/kern_defs.cuh" | |||
| #include "src/common/elemwise_multi_type/kern_defs.cuh" | |||
| #include "src/naive/handle.h" | |||
| using namespace megdnn; | |||
| using namespace naive; | |||
//! rounding right-shift with saturation: integer src0 shifted by an int8
//! shift amount (src1), result saturated to int8; dispatches on src0's
//! integer dtype to dispatch_round_shr_saturate_iXxi8xiX.
void ElemwiseMultiTypeImpl::on_round_shr_saturate_iXxi8xi8(
        const ElemwiseOpParamN<2>& param, const TensorND& dst) {
    switch (param[0].layout.dtype.enumv()) {
#define cb(t)                                                                  \
    case DTypeTrait<t>::enumv:                                                 \
        return dispatch_round_shr_saturate_iXxi8xiX<DTypeTrait<t>::ctype,      \
                                                    dt_int8>(param, dst);
        MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb)
#undef cb
        default:
            megdnn_throw("unsupported src dtype");
    }
}
| template <typename ctype, typename dst_ctype> | |||
| void ElemwiseMultiTypeImpl::dispatch_round_shr_saturate_iXxi8xiX( | |||
| const ElemwiseOpParamN<2>& param, const TensorND& dst) { | |||
| auto src0 = param[0]; | |||
| auto src1 = param[1]; | |||
| auto size = param.size; | |||
| auto work = [src0, src1, size, dst]() { | |||
| // This is needed as these iterators are captured as const value. | |||
| auto iA = tensor_iter_valonly<ctype>(src0).begin(); | |||
| auto iB = tensor_iter_valonly<dt_int8>(src1).begin(); | |||
| auto pD = dst.ptr<dst_ctype>(); | |||
| for (size_t i = 0; i < size; i++) { | |||
| *pD = elemwise_multi_type::round_shr_saturate<ctype, dst_ctype>(*iA, *iB); | |||
| ++iA; | |||
| ++iB; | |||
| ++pD; | |||
| } | |||
| }; | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR(work()); | |||
| } | |||
//! fused kernel: (src0 + src1) -> rounding multiply-high by src2 ->
//! rounding right-shift (saturating to int8) by src3 -> clamp to
//! [src4, src5], stored into an int8 dst.
template <typename ctype>
void ElemwiseMultiTypeImpl::dispatch_fuse_add_rmulh_round_shr_saturate(
        const ElemwiseOpParamN<6>& param, const TensorND& dst) {
    auto size = param.size;
    auto src0 = param[0];  // addend x
    auto src1 = param[1];  // addend y
    auto src2 = param[2];  // multiplier for round_mulh_saturate
    auto src3 = param[3];  // int8 right-shift amount
    auto src4 = param[4];  // int8 clamp lower bound
    auto src5 = param[5];  // int8 clamp upper bound
    auto work = [size, src0, src1, src2, src3, src4, src5, dst]() {
        auto i0 = tensor_iter_valonly<ctype>(src0).begin();
        auto i1 = tensor_iter_valonly<ctype>(src1).begin();
        auto i2 = tensor_iter_valonly<ctype>(src2).begin();
        auto ioff = tensor_iter_valonly<dt_int8>(src3).begin();
        auto imin = tensor_iter_valonly<dt_int8>(src4).begin();
        auto imax = tensor_iter_valonly<dt_int8>(src5).begin();
        auto dst_ptr = dst.ptr<dt_int8>();
        for (size_t i = 0; i < size; ++i) {
            auto res = elemwise_multi_type::round_shr_saturate<ctype, dt_int8>(
                    round_mulh_saturate<ctype>(*i0 + *i1, *i2), *ioff);
            // clamp: upper bound first, then lower bound
            res = std::min(res, *imax);
            res = std::max(res, *imin);
            dst_ptr[i] = res;
            ++i0;
            ++i1;
            ++i2;
            ++ioff;
            ++imin;
            ++imax;
        }
    };
    MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}
//! int16 instantiation of the fused add + rmulh + round-shift kernel.
void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8(
        const ElemwiseOpParamN<6>& param, const TensorND& dst) {
    dispatch_fuse_add_rmulh_round_shr_saturate<dt_int16>(param, dst);
}
//! int32 instantiation of the fused add + rmulh + round-shift kernel.
void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8(
        const ElemwiseOpParamN<6>& param, const TensorND& dst) {
    dispatch_fuse_add_rmulh_round_shr_saturate<dt_int32>(param, dst);
}
//! like on_round_shr_saturate_iXxi8xi8 but saturating the result to int16;
//! the cases are expanded explicitly for Int32 and Int16 sources instead of
//! via MEGDNN_FOREACH_COMPUTING_DTYPE_INT.
void ElemwiseMultiTypeImpl::on_round_shr_saturate_iXxi8xi16(
        const ElemwiseOpParamN<2>& param, const TensorND& dst) {
    switch (param[0].layout.dtype.enumv()) {
#define cb(t)                                                                  \
    case DTypeTrait<t>::enumv:                                                 \
        return dispatch_round_shr_saturate_iXxi8xiX<DTypeTrait<t>::ctype,      \
                                                    dt_int16>(param, dst);
        cb(::megdnn::dtype::Int32);
        cb(::megdnn::dtype::Int16);
#undef cb
        default:
            megdnn_throw("unsupported src dtype");
    }
}
| // vim: syntax=cpp.doxygen | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * \file dnn/src/naive/elemwise_multi_type/opr_impl.cpp | |||
| * \file dnn/src/naive/elemwise_multi_type/opr_impl_3.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| @@ -18,218 +18,6 @@ | |||
| using namespace megdnn; | |||
| using namespace naive; | |||
| void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16x32x32x32( | |||
| const ElemwiseOpParamN<3>& param, const TensorND& dst) { | |||
| auto size = param.size; | |||
| auto src0 = param[0]; | |||
| auto src1 = param[1]; | |||
| auto src2 = param[2]; | |||
| auto work = [src0, src1, src2, size, dst]() { | |||
| auto i0 = tensor_iter_valonly<dt_int16>(src0).begin(); | |||
| auto i1 = tensor_iter_valonly<dt_int32>(src1).begin(); | |||
| auto i2 = tensor_iter_valonly<dt_int32>(src2).begin(); | |||
| auto dst_ptr = dst.ptr<dt_int32>(); | |||
| for (size_t i = 0; i < size; ++i) { | |||
| dst_ptr[i] = (*i0) * (*i1) + (*i2); | |||
| ++i0; | |||
| ++i1; | |||
| ++i2; | |||
| } | |||
| }; | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR(work()); | |||
| } | |||
| void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16xf32xf32xf32( | |||
| const ElemwiseOpParamN<3>& param, const TensorND& dst) { | |||
| auto size = param.size; | |||
| auto src0 = param[0]; | |||
| auto src1 = param[1]; | |||
| auto src2 = param[2]; | |||
| auto work = [src0, src1, src2, size, dst]() { | |||
| auto i0 = tensor_iter_valonly<dt_int16>(src0).begin(); | |||
| auto i1 = tensor_iter_valonly<dt_float32>(src1).begin(); | |||
| auto i2 = tensor_iter_valonly<dt_float32>(src2).begin(); | |||
| auto dst_ptr = dst.ptr<dt_float32>(); | |||
| for (size_t i = 0; i < size; ++i) { | |||
| dst_ptr[i] = (*i0) * (*i1) + (*i2); | |||
| ++i0; | |||
| ++i1; | |||
| ++i2; | |||
| } | |||
| }; | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR(work()); | |||
| } | |||
| void ElemwiseMultiTypeImpl::on_fuse_mul_add3_uint8xf32xf32xf32( | |||
| const ElemwiseOpParamN<3>& param, const TensorND& dst) { | |||
| auto size = param.size; | |||
| auto src0 = param[0]; | |||
| auto src1 = param[1]; | |||
| auto src2 = param[2]; | |||
| auto work = [src0, src1, src2, size, dst]() { | |||
| auto i0 = tensor_iter_valonly<dt_uint8>(src0).begin(); | |||
| auto i1 = tensor_iter_valonly<dt_float32>(src1).begin(); | |||
| auto i2 = tensor_iter_valonly<dt_float32>(src2).begin(); | |||
| auto dst_ptr = dst.ptr<dt_float32>(); | |||
| for (size_t i = 0; i < size; ++i) { | |||
| dst_ptr[i] = (*i0) * (*i1) + (*i2); | |||
| ++i0; | |||
| ++i1; | |||
| ++i2; | |||
| } | |||
| }; | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR(work()); | |||
| } | |||
| void ElemwiseMultiTypeImpl::on_mul_int16xf32xf32( | |||
| const ElemwiseOpParamN<2>& param, const TensorND& dst) { | |||
| auto size = param.size; | |||
| auto src0 = param[0]; | |||
| auto src1 = param[1]; | |||
| auto work = [src0, src1, size, dst]() { | |||
| auto i0 = tensor_iter_valonly<dt_int16>(src0).begin(); | |||
| auto i1 = tensor_iter_valonly<dt_float32>(src1).begin(); | |||
| auto dst_ptr = dst.ptr<dt_float32>(); | |||
| for (size_t i = 0; i < size; ++i) { | |||
| dst_ptr[i] = (*i0) * (*i1); | |||
| ++i0; | |||
| ++i1; | |||
| } | |||
| }; | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR(work()); | |||
| } | |||
//! fused multiply-add (src0 * src1 + src2) with an integer src0, float32
//! src1/src2 and int8 result; dispatches on src0's dtype to the templated
//! kernel dispatch_fma3_iXxf32xf32xi8.
void ElemwiseMultiTypeImpl::on_fuse_mul_add3_iXxf32xf32xi8(
        const ElemwiseOpParamN<3>& param, const TensorND& dst) {
    switch (param[0].layout.dtype.enumv()) {
#define cb(t)                                                         \
    case DTypeTrait<t>::enumv:                                        \
        return dispatch_fma3_iXxf32xf32xi8<DTypeTrait<t>::ctype>(param, dst);
        MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb)
#undef cb
        default:
            megdnn_throw("unsupported src dtype");
    }
}
| template <typename ctype> | |||
| void ElemwiseMultiTypeImpl::dispatch_fma3_iXxf32xf32xi8( | |||
| const ElemwiseOpParamN<3>& param, const TensorND& dst) { | |||
| auto size = param.size; | |||
| auto src0 = param[0]; | |||
| auto src1 = param[1]; | |||
| auto src2 = param[2]; | |||
| auto work = [src0, src1, src2, size, dst]() { | |||
| elemwise_multi_type::Fma3iXxf32xf32xiYOp<ctype, dt_int8> op; | |||
| auto i0 = tensor_iter_valonly<ctype>(src0).begin(); | |||
| auto i1 = tensor_iter_valonly<dt_float32>(src1).begin(); | |||
| auto i2 = tensor_iter_valonly<dt_float32>(src2).begin(); | |||
| auto dst_ptr = dst.ptr<dt_int8>(); | |||
| for (size_t i = 0; i < size; ++i) { | |||
| dst_ptr[i] = op(*i0, *i1, *i2); | |||
| ++i0; | |||
| ++i1; | |||
| ++i2; | |||
| } | |||
| }; | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR(work()); | |||
| } | |||
//! rounding right-shift with saturation: integer src0 shifted by an int8
//! shift amount (src1), result saturated to int8; dispatches on src0's
//! integer dtype to dispatch_round_shr_saturate_iXxi8xiX.
void ElemwiseMultiTypeImpl::on_round_shr_saturate_iXxi8xi8(
        const ElemwiseOpParamN<2>& param, const TensorND& dst) {
    switch (param[0].layout.dtype.enumv()) {
#define cb(t)                                                                  \
    case DTypeTrait<t>::enumv:                                                 \
        return dispatch_round_shr_saturate_iXxi8xiX<DTypeTrait<t>::ctype,      \
                                                    dt_int8>(param, dst);
        MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb)
#undef cb
        default:
            megdnn_throw("unsupported src dtype");
    }
}
| template <typename ctype, typename dst_ctype> | |||
| void ElemwiseMultiTypeImpl::dispatch_round_shr_saturate_iXxi8xiX( | |||
| const ElemwiseOpParamN<2>& param, const TensorND& dst) { | |||
| auto src0 = param[0]; | |||
| auto src1 = param[1]; | |||
| auto size = param.size; | |||
| auto work = [src0, src1, size, dst]() { | |||
| // This is needed as these iterators are captured as const value. | |||
| auto iA = tensor_iter_valonly<ctype>(src0).begin(); | |||
| auto iB = tensor_iter_valonly<dt_int8>(src1).begin(); | |||
| auto pD = dst.ptr<dst_ctype>(); | |||
| for (size_t i = 0; i < size; i++) { | |||
| *pD = elemwise_multi_type::round_shr_saturate<ctype, dst_ctype>(*iA, *iB); | |||
| ++iA; | |||
| ++iB; | |||
| ++pD; | |||
| } | |||
| }; | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR(work()); | |||
| } | |||
//! fused kernel: (src0 + src1) -> rounding multiply-high by src2 ->
//! rounding right-shift (saturating to int8) by src3 -> clamp to
//! [src4, src5], stored into an int8 dst.
template <typename ctype>
void ElemwiseMultiTypeImpl::dispatch_fuse_add_rmulh_round_shr_saturate(
        const ElemwiseOpParamN<6>& param, const TensorND& dst) {
    auto size = param.size;
    auto src0 = param[0];  // addend x
    auto src1 = param[1];  // addend y
    auto src2 = param[2];  // multiplier for round_mulh_saturate
    auto src3 = param[3];  // int8 right-shift amount
    auto src4 = param[4];  // int8 clamp lower bound
    auto src5 = param[5];  // int8 clamp upper bound
    auto work = [size, src0, src1, src2, src3, src4, src5, dst]() {
        auto i0 = tensor_iter_valonly<ctype>(src0).begin();
        auto i1 = tensor_iter_valonly<ctype>(src1).begin();
        auto i2 = tensor_iter_valonly<ctype>(src2).begin();
        auto ioff = tensor_iter_valonly<dt_int8>(src3).begin();
        auto imin = tensor_iter_valonly<dt_int8>(src4).begin();
        auto imax = tensor_iter_valonly<dt_int8>(src5).begin();
        auto dst_ptr = dst.ptr<dt_int8>();
        for (size_t i = 0; i < size; ++i) {
            auto res = elemwise_multi_type::round_shr_saturate<ctype, dt_int8>(
                    round_mulh_saturate<ctype>(*i0 + *i1, *i2), *ioff);
            // clamp: upper bound first, then lower bound
            res = std::min(res, *imax);
            res = std::max(res, *imin);
            dst_ptr[i] = res;
            ++i0;
            ++i1;
            ++i2;
            ++ioff;
            ++imin;
            ++imax;
        }
    };
    MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}
//! int16 instantiation of the fused add + rmulh + round-shift kernel.
void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8(
        const ElemwiseOpParamN<6>& param, const TensorND& dst) {
    dispatch_fuse_add_rmulh_round_shr_saturate<dt_int16>(param, dst);
}
//! int32 instantiation of the fused add + rmulh + round-shift kernel.
void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8(
        const ElemwiseOpParamN<6>& param, const TensorND& dst) {
    dispatch_fuse_add_rmulh_round_shr_saturate<dt_int32>(param, dst);
}
//! like on_round_shr_saturate_iXxi8xi8 but saturating the result to int16;
//! the cases are expanded explicitly for Int32 and Int16 sources instead of
//! via MEGDNN_FOREACH_COMPUTING_DTYPE_INT.
void ElemwiseMultiTypeImpl::on_round_shr_saturate_iXxi8xi16(
        const ElemwiseOpParamN<2>& param, const TensorND& dst) {
    switch (param[0].layout.dtype.enumv()) {
#define cb(t)                                                                  \
    case DTypeTrait<t>::enumv:                                                 \
        return dispatch_round_shr_saturate_iXxi8xiX<DTypeTrait<t>::ctype,      \
                                                    dt_int16>(param, dst);
        cb(::megdnn::dtype::Int32);
        cb(::megdnn::dtype::Int16);
#undef cb
        default:
            megdnn_throw("unsupported src dtype");
    }
}
| template <typename KernImpl, typename src_ctype, typename dst_ctype> | |||
| void ElemwiseMultiTypeImpl::dispatch_add_qint_op( | |||
| const ElemwiseOpParamN<1>& param, const TensorND& dst_tensor) { | |||