Browse Source

implement concat vector

tags/20171017
nihuini 8 years ago
parent
commit
0eee9aac21
1 changed files with 37 additions and 0 deletions
  1. +37
    -0
      src/layer/concat.cpp

+ 37
- 0
src/layer/concat.cpp View File

@@ -24,6 +24,43 @@ Concat::Concat()

int Concat::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs) const
{
int dims = bottom_blobs[0].dims;

if (dims == 1)
{
// concat vector
// total length
int top_w = 0;
for (size_t b=0; b<bottom_blobs.size(); b++)
{
const Mat& bottom_blob = bottom_blobs[b];
top_w += bottom_blob.w;
}

Mat& top_blob = top_blobs[0];
top_blob.create(top_w);
if (top_blob.empty())
return -100;

float* outptr = top_blob;
for (size_t b=0; b<bottom_blobs.size(); b++)
{
const Mat& bottom_blob = bottom_blobs[b];

int w = bottom_blob.w;

const float* ptr = bottom_blob;
for (int i=0; i<w; i++)
{
outptr[i] = ptr[i];
}

outptr += w;
}

return 0;
}

int w = bottom_blobs[0].w;
int h = bottom_blobs[0].h;



Loading…
Cancel
Save