| @@ -177,7 +177,7 @@ int InnerProduct_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Optio | |||
| const int num_input = weight_data_size / num_output; | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1) | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input) | |||
| { | |||
| // gemm | |||
| int h = bottom_blob.h; | |||
| @@ -889,7 +889,7 @@ int InnerProduct_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const | |||
| { | |||
| const int num_input = weight_data_size / num_output; | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1) | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input) | |||
| { | |||
| // gemm | |||
| int h = bottom_blob.h; | |||
| @@ -1295,7 +1295,7 @@ int InnerProduct_arm::forward_int8_arm(const Mat& bottom_blob, Mat& top_blob, co | |||
| quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt_q); | |||
| } | |||
| if (bottom_blob_int8.dims == 2 && bottom_blob_int8.w == num_input && bottom_blob_int8.h * bottom_blob_int8.elempack > 1) | |||
| if (bottom_blob_int8.dims == 2 && bottom_blob_int8.w == num_input) | |||
| { | |||
| // gemm | |||
| Mat bottom_blob_int8_unpacked; | |||
| @@ -53,7 +53,7 @@ int InnerProduct_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, cons | |||
| { | |||
| const int num_input = weight_data_size / num_output; | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1) | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input) | |||
| { | |||
| // gemm | |||
| int h = bottom_blob.h; | |||
| @@ -53,7 +53,7 @@ int InnerProduct_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const | |||
| { | |||
| const int num_input = weight_data_size / num_output; | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1) | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input) | |||
| { | |||
| // gemm | |||
| int h = bottom_blob.h; | |||
| @@ -115,7 +115,7 @@ int InnerProduct::forward(const Mat& bottom_blob, Mat& top_blob, const Option& o | |||
| size_t elemsize = bottom_blob.elemsize; | |||
| int size = w * h; | |||
| if (bottom_blob.dims == 2 && w == num_input && h > 1) | |||
| if (bottom_blob.dims == 2 && w == num_input) | |||
| { | |||
| // gemm | |||
| top_blob.create(num_output, h, elemsize, opt.blob_allocator); | |||
| @@ -201,7 +201,7 @@ int InnerProduct::forward_int8(const Mat& bottom_blob, Mat& top_blob, const Opti | |||
| quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt_g); | |||
| } | |||
| if (bottom_blob.dims == 2 && w == num_input && h > 1) | |||
| if (bottom_blob.dims == 2 && w == num_input) | |||
| { | |||
| // gemm | |||
| top_blob.create(num_output, h, 4u, opt.blob_allocator); | |||
| @@ -137,7 +137,7 @@ int InnerProduct_loongarch::forward(const Mat& bottom_blob, Mat& top_blob, const | |||
| const int num_input = weight_data_size / num_output; | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1) | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input) | |||
| { | |||
| // gemm | |||
| int h = bottom_blob.h; | |||
| @@ -667,7 +667,7 @@ int InnerProduct_loongarch::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, | |||
| { | |||
| const int num_input = weight_data_size / num_output; | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1) | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input) | |||
| { | |||
| // gemm | |||
| int h = bottom_blob.h; | |||
| @@ -1168,7 +1168,7 @@ int InnerProduct_loongarch::forward_int8_loongarch(const Mat& bottom_blob, Mat& | |||
| quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt_q); | |||
| } | |||
| if (bottom_blob_int8.dims == 2 && bottom_blob_int8.w == num_input && bottom_blob_int8.h * bottom_blob_int8.elempack > 1) | |||
| if (bottom_blob_int8.dims == 2 && bottom_blob_int8.w == num_input) | |||
| { | |||
| // gemm | |||
| Mat bottom_blob_int8_unpacked; | |||
| @@ -137,7 +137,7 @@ int InnerProduct_mips::forward(const Mat& bottom_blob, Mat& top_blob, const Opti | |||
| const int num_input = weight_data_size / num_output; | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1) | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input) | |||
| { | |||
| // gemm | |||
| int h = bottom_blob.h; | |||
| @@ -667,7 +667,7 @@ int InnerProduct_mips::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, cons | |||
| { | |||
| const int num_input = weight_data_size / num_output; | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1) | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input) | |||
| { | |||
| // gemm | |||
| int h = bottom_blob.h; | |||
| @@ -1168,7 +1168,7 @@ int InnerProduct_mips::forward_int8_mips(const Mat& bottom_blob, Mat& top_blob, | |||
| quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt_q); | |||
| } | |||
| if (bottom_blob_int8.dims == 2 && bottom_blob_int8.w == num_input && bottom_blob_int8.h * bottom_blob_int8.elempack > 1) | |||
| if (bottom_blob_int8.dims == 2 && bottom_blob_int8.w == num_input) | |||
| { | |||
| // gemm | |||
| Mat bottom_blob_int8_unpacked; | |||
| @@ -173,7 +173,7 @@ int InnerProduct_riscv::forward(const Mat& bottom_blob, Mat& top_blob, const Opt | |||
| const int num_input = weight_data_size / num_output; | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1) | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input) | |||
| { | |||
| // gemm | |||
| int h = bottom_blob.h; | |||
| @@ -577,7 +577,7 @@ int InnerProduct_riscv::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, con | |||
| const int num_input = weight_data_size / num_output; | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1) | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input) | |||
| { | |||
| // gemm | |||
| int h = bottom_blob.h; | |||
| @@ -839,7 +839,7 @@ int InnerProduct_riscv::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, co | |||
| const int num_input = weight_data_size / num_output; | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1) | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input) | |||
| { | |||
| // gemm | |||
| int h = bottom_blob.h; | |||
| @@ -79,7 +79,7 @@ int InnerProduct_vulkan::create_pipeline(const Option& _opt) | |||
| convert_packing(bias_data, bias_data_packed, out_elempack, opt); | |||
| } | |||
| if (shape.dims == 2 && shape.w == num_input && shape.h > 1) | |||
| if (shape.dims == 2 && shape.w == num_input) | |||
| { | |||
| // gemm | |||
| int elempack = opt.use_shader_pack8 && shape.h % 8 == 0 ? 8 : shape.h % 4 == 0 ? 4 : 1; | |||
| @@ -427,7 +427,7 @@ int InnerProduct_vulkan::forward(const VkMat& bottom_blob, VkMat& top_blob, VkCo | |||
| int in_elempack = opt.use_shader_pack8 && num_input % 8 == 0 ? 8 : num_input % 4 == 0 ? 4 : 1; | |||
| int out_elempack = opt.use_shader_pack8 && num_output % 8 == 0 ? 8 : num_output % 4 == 0 ? 4 : 1; | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1) | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input) | |||
| { | |||
| // gemm | |||
| int h = bottom_blob.h; | |||
| @@ -587,7 +587,7 @@ int InnerProduct_vulkan::forward(const VkImageMat& bottom_blob, VkImageMat& top_ | |||
| int in_elempack = opt.use_shader_pack8 && num_input % 8 == 0 ? 8 : num_input % 4 == 0 ? 4 : 1; | |||
| int out_elempack = opt.use_shader_pack8 && num_output % 8 == 0 ? 8 : num_output % 4 == 0 ? 4 : 1; | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1) | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input) | |||
| { | |||
| // gemm | |||
| int h = bottom_blob.h; | |||
| @@ -118,7 +118,7 @@ int InnerProduct_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Optio | |||
| const int num_input = weight_data_size / num_output; | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1) | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input) | |||
| { | |||
| // gemm | |||
| int h = bottom_blob.h; | |||
| @@ -190,7 +190,7 @@ int InnerProduct_x86::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const | |||
| { | |||
| const int num_input = weight_data_size / num_output; | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input && bottom_blob.h * bottom_blob.elempack > 1) | |||
| if (bottom_blob.dims == 2 && bottom_blob.w == num_input) | |||
| { | |||
| // gemm | |||
| int h = bottom_blob.h; | |||
| @@ -309,7 +309,7 @@ int InnerProduct_x86::forward_int8_x86(const Mat& bottom_blob, Mat& top_blob, co | |||
| quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt_q); | |||
| } | |||
| if (bottom_blob_int8.dims == 2 && bottom_blob_int8.w == num_input && bottom_blob_int8.h * bottom_blob_int8.elempack > 1) | |||
| if (bottom_blob_int8.dims == 2 && bottom_blob_int8.w == num_input) | |||
| { | |||
| // gemm | |||
| Mat bottom_blob_int8_unpacked; | |||
| @@ -179,6 +179,8 @@ static int test_innerproduct_gemm(const ncnn::Mat& a, int outch, int bias) | |||
| static int test_innerproduct_4() | |||
| { | |||
| return 0 | |||
| || test_innerproduct_gemm(RandomMat(1, 1), 1, 1) | |||
| || test_innerproduct_gemm(RandomMat(48, 1), 11, 1) | |||
| || test_innerproduct_gemm(RandomMat(1, 5), 1, 1) | |||
| || test_innerproduct_gemm(RandomMat(3, 2), 2, 0) | |||
| || test_innerproduct_gemm(RandomMat(9, 8), 7, 1) | |||
| @@ -23,7 +23,7 @@ class Model(nn.Module): | |||
| self.linear_0 = nn.Linear(in_features=64, out_features=16, bias=False) | |||
| self.linear_1 = nn.Linear(in_features=16, out_features=3, bias=True) | |||
| def forward(self, x, y, z): | |||
| def forward(self, x, y, z, w): | |||
| x = self.linear_0(x) | |||
| x = self.linear_1(x) | |||
| @@ -33,7 +33,10 @@ class Model(nn.Module): | |||
| z = self.linear_0(z) | |||
| z = self.linear_1(z) | |||
| z = F.relu(z) | |||
| return x, y, z | |||
| w = self.linear_0(w) | |||
| w = self.linear_1(w) | |||
| return x, y, z, w | |||
| def test(): | |||
| net = Model().half().float() | |||
| @@ -43,22 +46,26 @@ def test(): | |||
| x = torch.rand(64) | |||
| y = torch.rand(12, 64) | |||
| z = torch.rand(1, 3, 12, 64) | |||
| w = torch.rand(1, 64) | |||
| a0, a1, a2 = net(x, y, z) | |||
| a = net(x, y, z, w) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (x, y, z)) | |||
| mod = torch.jit.trace(net, (x, y, z, w)) | |||
| mod.save("test_nn_Linear.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../../src/pnnx test_nn_Linear.pt inputshape=[64],[12,64],[1,3,12,64]") | |||
| os.system("../../src/pnnx test_nn_Linear.pt inputshape=[64],[12,64],[1,3,12,64],[1,64]") | |||
| # ncnn inference | |||
| import test_nn_Linear_ncnn | |||
| b0, b1, b2 = test_nn_Linear_ncnn.test_inference() | |||
| b = test_nn_Linear_ncnn.test_inference() | |||
| return torch.allclose(a0, b0, 1e-4, 1e-4) and torch.allclose(a1, b1, 1e-4, 1e-4) and torch.allclose(a2, b2, 1e-4, 1e-4) | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.allclose(a0, b0, 1e-4, 1e-4): | |||
| return False | |||
| return True | |||
| if __name__ == "__main__": | |||
| if test(): | |||