/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/batch_to_space.h"

#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include "nnacl/arithmetic_common.h"

/**
 * Rearranges batch entries into spatial blocks for an NHWC tensor, without cropping.
 *
 * The input is addressed as [batch][in_h][in_w][in_c]; the batch entry read for output batch n and
 * block position (bh, bw) is (bh * block_w + bw) * out_n + n.  Output pixels are written back to
 * back in the order n, h, bh, w, bw, i.e. as [out_n][in_h * block_h][in_w * block_w][in_c].
 *
 * @param input     Source buffer (read-only).
 * @param output    Destination buffer; must not overlap input (memcpy is used).
 * @param in_shape  Input shape; elements 1, 2 and 3 are read as height, width and channels.
 * @param out_n     Number of output batches.
 * @param block     Block sizes: block[0] for height, block[1] for width.
 * @param data_size Size of one element in bytes.
 *
 * Nothing is copied if any extent, block size or data_size is not positive.
 */
void BatchToSpaceNoCropForNHWC(const void *input, void *output, const int *in_shape, int out_n, const int *block,
                               int data_size) {
  const int block_h = block[0];
  const int block_w = block[1];
  const int in_h = in_shape[1];
  const int in_w = in_shape[2];
  const int in_c = in_shape[3];

  /* With any non-positive extent there is nothing to copy; this also makes the size_t casts below safe. */
  if (out_n <= 0 || in_h <= 0 || in_w <= 0 || in_c <= 0 || block_h <= 0 || block_w <= 0 || data_size <= 0) {
    return;
  }

  const int8_t *src = (const int8_t *)input;
  int8_t *dst = (int8_t *)output;

  /* Widen before multiplying so large tensors do not overflow int arithmetic. */
  const size_t copy_size = (size_t)in_c * (size_t)data_size;
  const size_t in_stride_h = (size_t)in_w * (size_t)in_c;
  const size_t in_stride_n = in_stride_h * (size_t)in_h;
  const size_t stride_h = (size_t)block_w * (size_t)out_n;
  size_t output_offset = 0;

  for (int n = 0; n < out_n; ++n) {
    for (int h = 0; h < in_h; ++h) {
      const size_t h_offset = (size_t)h * in_stride_h;
      for (int bh = 0; bh < block_h; ++bh) {
        for (int w = 0; w < in_w; ++w) {
          const size_t w_offset = (size_t)w * (size_t)in_c;
          for (int bw = 0; bw < block_w; ++bw) {
            const size_t in_batch = (size_t)bh * stride_h + (size_t)bw * (size_t)out_n + (size_t)n;
            const size_t in_offset = in_stride_n * in_batch + w_offset + h_offset;
            memcpy(dst + output_offset, src + in_offset * (size_t)data_size, copy_size);
            output_offset += copy_size;
          }
        }
      }
    }
  }
}

/**
 * Rearranges batch entries into spatial blocks for an NHWC tensor and crops the result.
 *
 * Uses the same input addressing and output ordering as BatchToSpaceNoCropForNHWC, but only
 * pixels whose expanded row index lies in [crops[0], in_h * block_h - crops[1] - 1] and whose
 * expanded column index lies in [crops[2], in_w * block_w - crops[3] - 1] are written.
 *
 * @param input     Source buffer (read-only).
 * @param output    Destination buffer; must not overlap input (memcpy is used).
 * @param in_shape  Input shape; elements 1, 2 and 3 are read as height, width and channels.
 * @param out_n     Number of output batches.
 * @param block     Block sizes: block[0] for height, block[1] for width.
 * @param crops     Crop amounts: {top, bottom, left, right} in expanded-pixel units.
 * @param data_size Size of one element in bytes.
 *
 * Nothing is copied if any extent, block size or data_size is not positive, or if the crops
 * remove the whole height or width.
 */
void BatchToSpaceForNHWC(const void *input, void *output, const int *in_shape, int out_n, const int *block,
                         const int *crops, int data_size) {
  const int block_h = block[0];
  const int block_w = block[1];
  const int in_h = in_shape[1];
  const int in_w = in_shape[2];
  const int in_c = in_shape[3];

  /* Rejects zero block sizes before they are used as divisors, and keeps the size_t casts below safe. */
  if (out_n <= 0 || in_h <= 0 || in_w <= 0 || in_c <= 0 || block_h <= 0 || block_w <= 0 || data_size <= 0) {
    return;
  }

  /* Inclusive range of expanded row indices that survive the crop. */
  const int h_valid_begin = crops[0];
  const int h_valid_end = in_h * block_h - crops[1] - 1;
  /* Input rows that can contribute; the start is clamped so a negative crop cannot index before the buffer. */
  const int h_start = (crops[0] > 0) ? crops[0] / block_h : 0;
  const int h_end = MSMIN((in_h * block_h - crops[1]) / block_h + 1, in_h);

  /* Same for columns. */
  const int w_valid_begin = crops[2];
  const int w_valid_end = in_w * block_w - crops[3] - 1;
  const int w_start = (crops[2] > 0) ? crops[2] / block_w : 0;
  const int w_end = MSMIN((in_w * block_w - crops[3]) / block_w + 1, in_w);

  const int8_t *src = (const int8_t *)input;
  int8_t *dst = (int8_t *)output;

  /* Widen before multiplying so large tensors do not overflow int arithmetic. */
  const size_t copy_size = (size_t)in_c * (size_t)data_size;
  const size_t in_stride_h = (size_t)in_w * (size_t)in_c;
  const size_t in_stride_n = in_stride_h * (size_t)in_h;
  const size_t stride_h = (size_t)block_w * (size_t)out_n;
  size_t output_offset = 0;

  for (int n = 0; n < out_n; ++n) {
    for (int h = h_start; h < h_end; ++h) {
      const size_t h_offset = (size_t)h * in_stride_h;
      for (int bh = 0; bh < block_h; ++bh) {
        /* Kept signed: the bounds are ints and may be negative when the crop removes everything. */
        const int h_index = h * block_h + bh;
        if (h_index < h_valid_begin || h_index > h_valid_end) {
          continue;
        }
        for (int w = w_start; w < w_end; ++w) {
          const size_t w_offset = (size_t)w * (size_t)in_c;
          for (int bw = 0; bw < block_w; ++bw) {
            const int w_index = w * block_w + bw;
            if (w_index < w_valid_begin || w_index > w_valid_end) {
              continue;
            }
            const size_t in_batch = (size_t)bh * stride_h + (size_t)bw * (size_t)out_n + (size_t)n;
            const size_t in_offset = in_stride_n * in_batch + w_offset + h_offset;
            memcpy(dst + output_offset, src + in_offset * (size_t)data_size, copy_size);
            output_offset += copy_size;
          }
        }
      }
    }
  }
}