/**
 * \file dnn/src/fallback/tile/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "src/fallback/tile/opr_impl.h"

#include <numeric>
#include "src/naive/handle.h"
#include <cstring>

#include "src/common/tile_repeat_helper.h"
#include "src/common/utils.h"

namespace megdnn {
namespace fallback {

/**
 * \brief Workspace size, in bytes, needed by exec() for the given layouts.
 *
 * Simply forwards to get_workspace_in_bytes_fwd().
 *
 * \param src layout of the input tensor
 * \param dst layout of the output tensor
 * \return the value reported by get_workspace_in_bytes_fwd(src, dst)
 */
size_t TileImpl::get_workspace_in_bytes(const TensorLayout &src,
                                        const TensorLayout &dst) {
    return get_workspace_in_bytes_fwd(src, dst);
}

/**
 * \brief Tile \p src_ into \p dst_ according to param().times.
 *
 * The layouts are first validated with check_exec() and reduced with
 * simplify_shape(). If no simplified axis has a repeat count other than 1,
 * the input is copied to the output with a single memcpy. Otherwise the axes
 * are processed one at a time, from the last to the first, each non-trivial
 * axis being expanded by tile_or_repeat_single_axis().
 *
 * All work touching tensor memory is submitted through
 * MEGDNN_DISPATCH_CPU_KERN_OPR rather than being run inline.
 *
 * NOTE(review): the element type is hard-coded to float (sizeof(float),
 * float pointers); presumably only float32 tensors are expected to reach this
 * implementation -- confirm against the callers.
 *
 * \param src_ input tensor
 * \param dst_ output tensor
 * \param workspace scratch memory; split below into two buffers of
 *      dst.total_nr_elems() floats each
 */
void TileImpl::exec(_megdnn_tensor_in src_, _megdnn_tensor_out dst_,
                    _megdnn_workspace workspace) {
    check_exec(src_.layout, dst_.layout, workspace.size);

    TensorShape src, dst, times;
    simplify_shape(src_.layout, dst_.layout, param().times, src, dst, times);

    auto nr_reduces = count_not_ones_in_shape(times);
    if (nr_reduces == 0) {
        // No axis is actually repeated: the output is a plain copy of the
        // input.
        MEGDNN_DISPATCH_CPU_KERN_OPR(
                std::memcpy(dst_.raw_ptr, src_.raw_ptr,
                            sizeof(float) * dst.total_nr_elems()));
        return;
    }

    // Everything is captured by value so that the kernel does not refer to
    // locals of this function when it is eventually run by the dispatcher.
    auto kern = [=]() {
        auto ndim = times.ndim;

        // Two scratch buffers, each as large as the final output.
        WorkspaceBundle workspaces(workspace.raw_ptr,
                                   {dst.total_nr_elems() * sizeof(float),
                                    dst.total_nr_elems() * sizeof(float)});
        auto workspace0 = static_cast<float *>(workspaces.get(0));
        auto workspace1 = static_cast<float *>(workspaces.get(1));

        // Source / destination of the current single-axis step and the
        // bookkeeping value threaded through the *_tile_repeat_state helpers
        // (presumably they alternate between the two scratch buffers -- see
        // tile_repeat_helper.h).
        float *current = nullptr;
        float *next = nullptr;
        size_t state = 0;
        init_tile_repeat_state(src_.ptr<float>(), dst_.ptr<float>(),
                               workspace0, workspace1, current, next, state,
                               nr_reduces);

        // Walk the axes from the innermost (last) to the outermost (first).
        // i is the one-past index so that the unsigned counter never wraps.
        for (size_t i = ndim; i > 0; --i) {
            size_t j = i - 1;
            if (times.shape[j] != 1) {
                // m = sshape[0]*...*sshape[i-2]
                auto m = std::accumulate(src.shape, src.shape + j, 1_z,
                                         SafeMultiplies<size_t>());
                // n = sshape[i-1]*dshape[i]*...
                auto n = std::accumulate(dst.shape + i, dst.shape + ndim, 1_z,
                                         SafeMultiplies<size_t>()) *
                         src.shape[j];
                // forward is repeat (m, n) to (m*times, n)
                tile_or_repeat_single_axis(current, next, m, n, times[j]);
                update_tile_repeat_state(src_.ptr<float>(), dst_.ptr<float>(),
                                         workspace0, workspace1, current, next,
                                         state, nr_reduces);
            }
        }
    };
    MEGDNN_DISPATCH_CPU_KERN_OPR(kern());
}

} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen