|
- /**
- * \file dnn/src/common/concat_split.cpp
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- */
- #include "megdnn/oprs.h"
-
- #include "src/common/utils.h"
-
- #include <numeric>
-
- namespace megdnn {
-
- ConcatSplitBase::ConcatSplitBase(Handle* handle)
- : OperatorBase(handle),
- m_get_layout([](const TensorND& tensor) { return tensor.layout; }),
- m_get_shape([](const TensorLayout& layout) { return TensorShape(layout); }) {}
-
- void ConcatSplitBase::check_layout_common(
- const TensorLayoutArray& srcs, const TensorLayout& dst) {
- // ensure same data type
- for (auto&& src : srcs) {
- megdnn_assert(src.dtype == dst.dtype);
- }
- // ensure all layouts are contiguous
- for (auto&& src : srcs) {
- megdnn_assert_contiguous(src);
- }
- megdnn_assert_contiguous(dst);
- // ensure all layouts have the same ndim
- auto ndim = dst.ndim;
- for (auto&& src : srcs) {
- megdnn_assert_eq_size_t(src.ndim, ndim);
- }
- // ensure param().axis is correct
- megdnn_assert(
- param().axis < static_cast<int32_t>(ndim), "param().axis=%u, ndim=%zu",
- param().axis, ndim);
- // ensure shape size for each axis is correct
- for (size_t i = 0; i < ndim; ++i) {
- if (i == static_cast<size_t>(param().axis)) {
- size_t sum = 0_z;
- for (auto&& src : srcs)
- sum += src.shape[i];
- megdnn_assert_eq_size_t(sum, dst.shape[i]);
- } else {
- for (auto&& src : srcs) {
- megdnn_assert(src.shape[i] == dst.shape[i]);
- megdnn_assert_eq_size_t(src.shape[i], dst.shape[i]);
- }
- }
- }
- }
-
- void ConcatSplitBase::get_ABC(
- const TensorShapeArray& srcs, size_t& A, size_t* B, size_t& C) {
- auto axis = param().axis;
- auto shape_arr = srcs[0].shape;
- auto ndim = srcs[0].ndim;
- A = std::accumulate(shape_arr, shape_arr + axis, 1_z, SafeMultiplies<size_t>());
- for (size_t i = 0u; i < srcs.size(); ++i) {
- B[i] = srcs[i].shape[axis];
- }
- C = std::accumulate(
- shape_arr + (axis + 1), shape_arr + ndim, 1_z, SafeMultiplies<size_t>());
- }
-
- void ConcatForward::deduce_layout(const TensorLayoutArray& srcs, TensorLayout& dst) {
- dst = srcs[0];
- auto i = param().axis;
- dst.shape[i] = 0u;
- for (auto&& src : srcs) {
- dst.shape[i] += src.shape[i];
- }
- dst.init_contiguous_stride();
- }
-
- void ConcatForward::check_exec(
- const TensorLayoutArray& srcs, const TensorLayout& dst,
- size_t workspace_in_bytes) {
- check_layout_common(srcs, dst);
- auto required_workspace_in_bytes = get_workspace_in_bytes(srcs, dst);
- megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
- }
-
- void SplitForward::check_exec(
- const TensorLayout& src, const TensorLayoutArray& dsts,
- size_t workspace_in_bytes) {
- check_layout_common(dsts, src);
- auto required_workspace_in_bytes = get_workspace_in_bytes(src, dsts);
- megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
- }
-
- } // namespace megdnn
- // vim: syntax=cpp.doxygen
|