/**
 * \file dnn/src/common/rnn_cell.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
- #include "src/common/rnn_cell.h"
- #include "megdnn/oprs.h"
- #include "src/common/utils.h"
-
- namespace megdnn {
-
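// Deduce the output layout of a single cell step. The result has shape
// {batch_size, gate_hidden_size}: batch_size comes from hx, and
// gate_hidden_size from the leading dimension of weight_ih. For a vanilla
// RNN cell gate_hidden_size equals hidden_size; gated cells (e.g. LSTM)
// pack all gates into this dimension.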
void RNNCell::deduce_layout(
        const TensorLayout& input, const TensorLayout& weight_ih,
        const TensorLayout& bias_ih, const TensorLayout& hx,
        const TensorLayout& weight_hh, const TensorLayout& bias_hh, TensorLayout& dst) {
    megdnn_assert(hx.ndim == 2);
    size_t batch_size = hx.shape[0];
    size_t hidden_size = weight_hh.shape[1];
    size_t gate_hidden_size = weight_ih.shape[0];
    size_t input_size = weight_ih.shape[1];
    megdnn_assert(input.shape[1] == input_size);
    megdnn_assert(hx.shape[1] == hidden_size);
    megdnn_assert_eq_dtype(input, hx);

    dst = TensorLayout(TensorShape({batch_size, gate_hidden_size}), input.dtype);
}

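// Validate that dst matches the deduced layout and that the caller provided
// enough workspace before the backend-specific kernel runs.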
void RNNCell::check_exec(
        const TensorLayout& input, const TensorLayout& weight_ih,
        const TensorLayout& bias_ih, const TensorLayout& hx,
        const TensorLayout& weight_hh, const TensorLayout& bias_hh,
        const TensorLayout& dst, size_t workspace_in_bytes) {
    TensorLayout dst_expected;
    megdnn_assert_eq_dtype(input, dst);
    megdnn_assert_eq_dtype(hx, dst);
    deduce_layout(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst_expected);
    megdnn_assert_eq_layout(dst_expected, dst);

    auto required_workspace_in_bytes = get_workspace_in_bytes(
            input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

} // namespace megdnn

namespace megdnn {
namespace rnn_cell {

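// Workspace layout: a dst-sized temporary holding the input-to-hidden matmul
// result, followed by scratch space for whichever of the two matmuls needs
// more.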
size_t get_workspace_in_bytes(
        const TensorLayout& input, const TensorLayout& weight_ih,
        const TensorLayout& bias_ih, const TensorLayout& hx,
        const TensorLayout& weight_hh, const TensorLayout& bias_hh,
        const TensorLayout& dst, Handle* handle) {
    auto opr = handle->create_operator<MatrixMulForward>();
    opr->param().transposeB = true;
    // Both matmuls (input x weight_ih^T and hx x weight_hh^T) reuse the same
    // scratch area, so reserve the larger of the two requirements.
    return dst.span().dist_byte() +
           std::max(opr->get_workspace_in_bytes(input, weight_ih, dst),
                    opr->get_workspace_in_bytes(hx, weight_hh, dst));
}

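// Compute one cell step:
//   dst = act(input @ weight_ih^T + bias_ih + hx @ weight_hh^T + bias_hh)
// with the activation selected by nonline_mode.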
void exec(
        _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih,
        _megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
        _megdnn_tensor_out dst, _megdnn_workspace workspace,
        param::RNNCell::NonlineMode nonline_mode, Handle* handle) {
    // tmp aliases the head of the workspace and holds input @ weight_ih^T;
    // the rest of the workspace is handed to the matmul kernels as scratch.
    TensorND tmp{static_cast<void*>(workspace.raw_ptr), dst.layout};
    _megdnn_workspace new_workspace = {
            workspace.raw_ptr + dst.layout.span().dist_byte(),
            workspace.size - dst.layout.span().dist_byte()};
    auto opr = handle->create_operator<MatrixMulForward>();
    opr->param().transposeB = true;
    opr->exec(input, weight_ih, tmp, new_workspace);
    opr->exec(hx, weight_hh, dst, new_workspace);
    // accumulate the two matmul results and both bias terms in place
    auto add_opr = handle->create_operator<ElemwiseForward>();
    add_opr->param().mode = Elemwise::Param::Mode::ADD;
    add_opr->exec({dst, tmp}, dst);
    add_opr->exec({dst, bias_ih}, dst);
    add_opr->exec({dst, bias_hh}, dst);

    // apply the elementwise nonlinearity in place
    using NonlineMode = param::RNNCell::NonlineMode;

    switch (nonline_mode) {
#define cb(_mode)                                                    \
    case NonlineMode::_mode: {                                       \
        auto nonlinear = handle->create_operator<ElemwiseForward>(); \
        nonlinear->param().mode = Elemwise::Param::Mode::_mode;      \
        nonlinear->exec({dst}, dst);                                 \
        break;                                                       \
    }
        cb(RELU);
        cb(TANH);
#undef cb
        case NonlineMode::IDENTITY:
            break;
        default:
            megdnn_assert(false, "unsupported nonlinear mode for RNNCell");
    }
}

} // namespace rnn_cell
} // namespace megdnn
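
// A minimal usage sketch (not part of this file's API surface): a caller is
// expected to deduce the output layout, query the workspace size, and then
// run the cell. The names below (handle, opr, x, w_ih, b_ih, h, w_hh, b_hh)
// are hypothetical placeholders.
//
//     TensorLayout dst_layout;
//     opr->deduce_layout(x, w_ih, b_ih, h, w_hh, b_hh, dst_layout);
//     size_t ws = rnn_cell::get_workspace_in_bytes(
//             x, w_ih, b_ih, h, w_hh, b_hh, dst_layout, handle);
//     // ... allocate dst and a workspace of ws bytes, then call
//     // rnn_cell::exec(...).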