Browse Source

Support torch.cumsum (#4505)

tags/20230223
Fangjun Kuang GitHub 3 years ago
parent
commit
92e75105c9
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 539 additions and 0 deletions
  1. +15
    -0
      docs/developer-guide/operators.md
  2. +1
    -0
      src/CMakeLists.txt
  3. +171
    -0
      src/layer/cumulativesum.cpp
  4. +37
    -0
      src/layer/cumulativesum.h
  5. +1
    -0
      tests/CMakeLists.txt
  6. +70
    -0
      tests/test_cumulativesum.cpp
  7. +2
    -0
      tools/pnnx/src/CMakeLists.txt
  8. +43
    -0
      tools/pnnx/src/pass_level2/torch_cumsum.cpp
  9. +57
    -0
      tools/pnnx/src/pass_ncnn/torch_cumsum.cpp
  10. +1
    -0
      tools/pnnx/tests/CMakeLists.txt
  11. +1
    -0
      tools/pnnx/tests/ncnn/CMakeLists.txt
  12. +70
    -0
      tools/pnnx/tests/ncnn/test_torch_cumsum.py
  13. +70
    -0
      tools/pnnx/tests/test_torch_cumsum.py

+ 15
- 0
docs/developer-guide/operators.md View File

@@ -15,6 +15,7 @@
* [ConvolutionDepthWise1D](#convolutiondepthwise1d)
* [ConvolutionDepthWise3D](#convolutiondepthwise3d)
* [Crop](#crop)
* [CumulativeSum](#cumulativesum)
* [Deconvolution](#deconvolution)
* [Deconvolution1D](#deconvolution1d)
* [Deconvolution3D](#deconvolution3d)
@@ -449,6 +450,20 @@ y = crop(x)
| 10 | ends | array | [ ] | |
| 11 | axes | array | [ ] | |

# CumulativeSum

If axis < 0, we use axis = x.dims + axis

It implements https://pytorch.org/docs/stable/generated/torch.cumsum.html

* one_blob_only
* support_inplace

| param id | name | type | default | description |
| --------- | ------------- | ----- | --------- | ----------------- |
| 0 | axis | int | 0 | |


# Deconvolution
```
x2 = deconv(x, weight, kernel, stride, dilation) + bias


+ 1
- 0
src/CMakeLists.txt View File

@@ -160,6 +160,7 @@ ncnn_add_layer(GLU)
ncnn_add_layer(Fold)
ncnn_add_layer(Unfold)
ncnn_add_layer(GridSample)
ncnn_add_layer(CumulativeSum)

if(NCNN_VULKAN)
ncnn_add_shader(${CMAKE_CURRENT_SOURCE_DIR}/convert_ycbcr.comp)


+ 171
- 0
src/layer/cumulativesum.cpp View File

@@ -0,0 +1,171 @@
// Copyright (c) 2023 Xiaomi Corp. (author: Fangjun Kuang)
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this
// file except in compliance with the License. You may obtain a copy of the
// License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.

#include "cumulativesum.h"

namespace ncnn {

CumulativeSum::CumulativeSum()
{
one_blob_only = true;
support_inplace = true;
}

int CumulativeSum::load_param(const ParamDict& pd)
{
axis = pd.get(0, 0);

return 0;
}

int CumulativeSum::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
int dims = bottom_top_blob.dims;
int positive_axis = axis < 0 ? dims + axis : axis;

if (dims == 1)
{ // ignore axis
int w = bottom_top_blob.w;

float* ptr = bottom_top_blob;

for (int i = 1; i < w; ++i)
{
ptr[i] = ptr[i] + ptr[i - 1];
}

return 0;
} // if (dims == 1)

if (dims == 2 && positive_axis == 0)
{
// sum over rows
int w = bottom_top_blob.w;
int h = bottom_top_blob.h;

for (int i = 1; i < h; ++i)
{
const float* prev_row = bottom_top_blob.row(i - 1);
float* this_row = bottom_top_blob.row(i);

for (int k = 0; k < w; ++k)
{
this_row[k] = this_row[k] + prev_row[k];
}
}

return 0;
} // if (dims == 2 && positive_axis == 0)

if (dims == 2 && positive_axis == 1)
{
// sum over columns
int w = bottom_top_blob.w;
int h = bottom_top_blob.h;

#pragma omp parallel for num_threads(opt.num_threads)
for (int i = 0; i < h; ++i)
{
float* ptr = bottom_top_blob.row(i);

for (int k = 1; k < w; ++k)
{
ptr[k] = ptr[k] + ptr[k - 1];
}
}

return 0;
} // if (dims == 2 && positive_axis == 1)

if (dims == 3 && positive_axis == 0)
{
// sum over channels
int w = bottom_top_blob.w;
int h = bottom_top_blob.h;
int c = bottom_top_blob.c;

int size = w * h;

for (int i = 1; i < c; ++i)
{
const float* prev = bottom_top_blob.channel(i - 1);
float* cur = bottom_top_blob.channel(i);

for (int k = 0; k < size; ++k)
{
cur[k] = cur[k] + prev[k];
}
}

return 0;
} // if (dims == 3 && positive_axis == 0)

if (dims == 3 && positive_axis == 1)
{
// sum over rows within each channel

int w = bottom_top_blob.w;
int h = bottom_top_blob.h;
int c = bottom_top_blob.c;

#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < c; ++q)
{
Mat this_channel = bottom_top_blob.channel(q);

for (int i = 1; i < h; ++i)
{
const float* prev_row = this_channel.row(i - 1);
float* this_row = this_channel.row(i);

for (int k = 0; k < w; ++k)
{
this_row[k] = this_row[k] + prev_row[k];
}
}
}

return 0;
} // if (dims == 3 && positive_axis == 1)

if (dims == 3 && positive_axis == 2)
{
// sum over columns within each channel

int w = bottom_top_blob.w;
int h = bottom_top_blob.h;
int c = bottom_top_blob.c;

#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < c; ++q)
{
Mat this_channel = bottom_top_blob.channel(q);

for (int i = 0; i < h; ++i)
{
float* ptr = this_channel.row(i);
for (int k = 1; k < w; ++k)
{
ptr[k] = ptr[k] + ptr[k - 1];
}
}
}

return 0;
} // if (dims == 3 && positive_axis == 2)

return -100;
}

} // namespace ncnn

+ 37
- 0
src/layer/cumulativesum.h View File

@@ -0,0 +1,37 @@
// Copyright (c) 2023 Xiaomi Corp. (author: Fangjun Kuang)
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this
// file except in compliance with the License. You may obtain a copy of the
// License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.

#ifndef LAYER_CUMULATIVESUM_H
#define LAYER_CUMULATIVESUM_H

#include "layer.h"

namespace ncnn {

class CumulativeSum : public Layer
{
public:
CumulativeSum();

virtual int load_param(const ParamDict& pd);

virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const;

public:
int axis;
};

} // namespace ncnn

#endif // LAYER_CUMULATIVESUM_H

+ 1
- 0
tests/CMakeLists.txt View File

@@ -81,6 +81,7 @@ ncnn_add_layer_test(ConvolutionDepthWise)
ncnn_add_layer_test(ConvolutionDepthWise1D)
ncnn_add_layer_test(ConvolutionDepthWise3D)
ncnn_add_layer_test(Crop)
ncnn_add_layer_test(CumulativeSum)
ncnn_add_layer_test(Deconvolution)
ncnn_add_layer_test(Deconvolution1D)
ncnn_add_layer_test(Deconvolution3D)


+ 70
- 0
tests/test_cumulativesum.cpp View File

@@ -0,0 +1,70 @@
// Copyright (c) 2023 Xiaomi Corp. (author: Fangjun Kuang)
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "layer/cumulativesum.h"
#include "testutil.h"

static int test_cumulativesum(const ncnn::Mat& a, int axis)
{
ncnn::ParamDict pd;
pd.set(0, axis);

std::vector<ncnn::Mat> weights(0);

int ret = test_layer<ncnn::CumulativeSum>("CumulativeSum", pd, weights, a);
if (ret != 0)
{
fprintf(stderr, "test_cumulativesum failed a.dims=%d a=(%d %d %d) axis=%d\n", a.dims, a.w, a.h, a.c, axis);
}

return ret;
}

static int test_cumulativesum_1d()
{
return 0
|| test_cumulativesum(RandomMat(6), 0)
|| test_cumulativesum(RandomMat(10), 0)
|| test_cumulativesum(RandomMat(10), -1)
|| test_cumulativesum(RandomMat(10), -2)
|| test_cumulativesum(RandomMat(101), 0);
}

static int test_cumulativesum_2d()
{
return 0
|| test_cumulativesum(RandomMat(6, 8), 0)
|| test_cumulativesum(RandomMat(20, 103), 1)
|| test_cumulativesum(RandomMat(106, 50), -1)
|| test_cumulativesum(RandomMat(106, 50), -2);
}

static int test_cumulativesum_3d()
{
return 0
|| test_cumulativesum(RandomMat(10, 6, 8), 0)
|| test_cumulativesum(RandomMat(303, 20, 103), 1)
|| test_cumulativesum(RandomMat(106, 50, 99), 2)
|| test_cumulativesum(RandomMat(303, 200, 103), -1)
|| test_cumulativesum(RandomMat(303, 200, 103), -2)
|| test_cumulativesum(RandomMat(303, 200, 103), -2);
}

int main()
{
SRAND(7767517);

return 0
|| test_cumulativesum_1d()
|| test_cumulativesum_2d()
|| test_cumulativesum_3d();
}

+ 2
- 0
tools/pnnx/src/CMakeLists.txt View File

@@ -206,6 +206,7 @@ set(pnnx_pass_level2_SRCS
pass_level2/torch_clamp.cpp
pass_level2/torch_clone.cpp
pass_level2/torch_complex.cpp
pass_level2/torch_cumsum.cpp
pass_level2/torch_dequantize.cpp
pass_level2/torch_einsum.cpp
pass_level2/torch_empty.cpp
@@ -497,6 +498,7 @@ set(pnnx_pass_ncnn_SRCS
pass_ncnn/torch_bmm.cpp
pass_ncnn/torch_clamp.cpp
pass_ncnn/torch_clone.cpp
pass_ncnn/torch_cumsum.cpp
pass_ncnn/torch_flatten.cpp
pass_ncnn/torch_logsumexp.cpp
pass_ncnn/torch_matmul.cpp


+ 43
- 0
tools/pnnx/src/pass_level2/torch_cumsum.cpp View File

@@ -0,0 +1,43 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
// 2023 Xiaomi Corp. (author: Fangjun Kuang)
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class torch_cumsum : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 5
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 dim
prim::Constant op_1 0 1 dtype value=*
aten::cumsum op_2 3 1 input dim dtype out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torch.cumsum";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_cumsum, 20)

} // namespace pnnx

+ 57
- 0
tools/pnnx/src/pass_ncnn/torch_cumsum.cpp View File

@@ -0,0 +1,57 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
// 2023 Xiaomi Corp. (author: Fangjun Kuang)
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_ncnn.h"

namespace pnnx {

namespace ncnn {

class torch_cumsum : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
torch.cumsum op_0 1 1 input out dim=%dim
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "CumulativeSum";
}

const char* name_str() const
{
return "cumsum";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
const int dim = captured_params.at("dim").i;

op->params["0"] = dim;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_cumsum, 20)

} // namespace ncnn

} // namespace pnnx

+ 1
- 0
tools/pnnx/tests/CMakeLists.txt View File

@@ -188,6 +188,7 @@ pnnx_add_test(torch_cat)
pnnx_add_test(torch_chunk)
pnnx_add_test(torch_clone)
pnnx_add_test(torch_complex)
pnnx_add_test(torch_cumsum)
pnnx_add_test(torch_einsum)
pnnx_add_test(torch_eq)
pnnx_add_test(torch_flatten)


+ 1
- 0
tools/pnnx/tests/ncnn/CMakeLists.txt View File

@@ -140,6 +140,7 @@ pnnx_ncnn_add_test(torch_bmm)
pnnx_ncnn_add_test(torch_cat)
pnnx_ncnn_add_test(torch_chunk)
pnnx_ncnn_add_test(torch_clone)
pnnx_ncnn_add_test(torch_cumsum)
pnnx_ncnn_add_test(torch_einsum)
pnnx_ncnn_add_test(torch_logsumexp)
pnnx_ncnn_add_test(torch_matmul)


+ 70
- 0
tools/pnnx/tests/ncnn/test_torch_cumsum.py View File

@@ -0,0 +1,70 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
# 2023 Xiaomi Corp. (author: Fangjun Kuang)
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z):
# x - 3d
# y - 2d
# z - 1d
x0 = torch.cumsum(x, dim=0)
x1 = torch.cumsum(x, dim=1)
x2 = torch.cumsum(x, dim=2)

y0 = torch.cumsum(y, dim=0)
y1 = torch.cumsum(y, dim=1)

z0 = torch.cumsum(z, dim=0)
return x0, x1, x2, y0, y1, z0

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(2, 3, 16)
y = torch.rand(5, 9)
z = torch.rand(3)

a = net(x, y, z)

# export torchscript
mod = torch.jit.trace(net, (x, y, z))
mod.save("test_torch_cumsum.pt")

# torchscript to pnnx
import os
os.system("../../src/pnnx test_torch_cumsum.pt inputshape=[2,3,16],[5,9],[3]")

# ncnn inference
import test_torch_cumsum_ncnn
b = test_torch_cumsum_ncnn.test_inference()

for a0, b0 in zip(a, b):
if not torch.allclose(a0, b0, 1e-4, 1e-4):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 70
- 0
tools/pnnx/tests/test_torch_cumsum.py View File

@@ -0,0 +1,70 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
# 2023 Xiaomi Corp. (author: Fangjun Kuang)
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z):
# x - 3d
# y - 2d
# z - 1d
x0 = torch.cumsum(x, dim=0)
x1 = torch.cumsum(x, dim=1)
x2 = torch.cumsum(x, dim=2)

y0 = torch.cumsum(y, dim=0)
y1 = torch.cumsum(y, dim=1)

z0 = torch.cumsum(z, dim=0)
return x0, x1, x2, y0, y1, z0

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(2, 3, 16)
y = torch.rand(5, 9)
z = torch.rand(14)

a = net(x, y, z)

# export torchscript
mod = torch.jit.trace(net, (x, y, z))
mod.save("test_torch_cumsum.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_torch_cumsum.pt inputshape=[2,3,16],[5,9],[14]")

# pnnx inference
import test_torch_cumsum_pnnx
b = test_torch_cumsum_pnnx.test_inference()

for a0, b0 in zip(a, b):
if not torch.equal(a0, b0):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

Loading…
Cancel
Save