#include "src/common/rnn_cell.h"

#include <algorithm>
#include <string>

#include "megdnn/oprs.h"
#include "src/common/utils.h"

namespace megdnn {

/*!
 * \brief Deduce the output layout of one RNN cell step.
 *
 * The result is 2-D: (hx.shape[0], weight_ih.shape[0]), i.e.
 * (batch_size, gate_hidden_size), with the dtype of \p input. The bias and
 * weight_hh layouts do not influence the result.
 */
void RNNCell::deduce_layout(
        const TensorLayout& input, const TensorLayout& weight_ih,
        const TensorLayout& /*bias_ih*/, const TensorLayout& hx,
        const TensorLayout& /*weight_hh*/, const TensorLayout& /*bias_hh*/,
        TensorLayout& dst) {
    size_t batch_size = hx.shape[0];
    size_t gate_hidden_size = weight_ih.shape[0];
    dst = TensorLayout(TensorShape({batch_size, gate_hidden_size}), input.dtype);
}

/*!
 * \brief Validate operand layouts, dtypes and workspace size before exec.
 *
 * Requires input, hx, weight_ih, weight_hh and dst to be 2-D with consistent
 * batch / hidden dimensions, weight_hh to be square, both biases to share
 * their leading dimension, input and hx to share dst's dtype, dst to equal
 * the layout produced by deduce_layout(), and \p workspace_in_bytes to be at
 * least get_workspace_in_bytes(). Shape-check failures report every operand
 * layout.
 */
void RNNCell::check_exec(
        const TensorLayout& input, const TensorLayout& weight_ih,
        const TensorLayout& bias_ih, const TensorLayout& hx,
        const TensorLayout& weight_hh, const TensorLayout& bias_hh,
        const TensorLayout& dst, size_t workspace_in_bytes) {
    TensorLayout dst_expected;
    // Formats all operand layouts; used as the message of the shape checks.
    auto errmsg = [&]() {
        std::string msg;
        msg.append("input=");
        msg.append(input.to_string());
        msg.append(", weight_ih=");
        msg.append(weight_ih.to_string());
        msg.append(", bias_ih=");
        msg.append(bias_ih.to_string());
        msg.append(", hx=");
        msg.append(hx.to_string());
        msg.append(", weight_hh=");
        msg.append(weight_hh.to_string());
        msg.append(", bias_hh=");
        msg.append(bias_hh.to_string());
        msg.append(", dst=");
        msg.append(dst.to_string());
        return msg;
    };
#define ASSERT_BRIEF(_content) megdnn_assert(_content, "%s", errmsg().c_str());

    ASSERT_BRIEF(input.ndim == 2)
    ASSERT_BRIEF(hx.ndim == 2)
    // shape[1] of both weights and of dst is compared below, so require rank 2
    // before those entries are read.
    ASSERT_BRIEF(weight_ih.ndim == 2)
    ASSERT_BRIEF(weight_hh.ndim == 2)
    ASSERT_BRIEF(dst.ndim == 2)
    ASSERT_BRIEF(hx.shape[0] == input.shape[0])  // batch
    ASSERT_BRIEF(input.shape[1] == weight_ih.shape[1])
    ASSERT_BRIEF(hx.shape[0] == dst.shape[0])  // batch
    ASSERT_BRIEF(hx.shape[1] == dst.shape[1])
    ASSERT_BRIEF(hx.shape[1] == weight_ih.shape[0])  // hidden_size
    ASSERT_BRIEF(weight_ih.shape[0] == weight_hh.shape[0])
    ASSERT_BRIEF(weight_hh.shape[0] == weight_hh.shape[1])
    ASSERT_BRIEF(bias_ih.shape[0] == bias_hh.shape[0])
#undef ASSERT_BRIEF

    megdnn_assert_eq_dtype(input, dst);
    megdnn_assert_eq_dtype(hx, dst);
    deduce_layout(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst_expected);
    megdnn_assert_eq_layout(dst_expected, dst);

    auto required_workspace_in_bytes = get_workspace_in_bytes(
            input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

}  // namespace megdnn

namespace megdnn {
namespace rnn_cell {

/*!
 * \brief Workspace needed by rnn_cell::exec.
 *
 * exec places one dst-sized scratch tensor (for input * weight_ih^T) at the
 * start of the workspace and hands the remainder to the matmul operator. The
 * matmul is run twice with the same remainder, so the larger of the two
 * matmul requirements is reserved.
 */
size_t get_workspace_in_bytes(
        const TensorLayout& input, const TensorLayout& weight_ih,
        const TensorLayout& /*bias_ih*/, const TensorLayout& hx,
        const TensorLayout& weight_hh, const TensorLayout& /*bias_hh*/,
        const TensorLayout& dst, Handle* handle) {
    auto opr = handle->create_operator<MatrixMulForward>();
    opr->param().transposeB = true;
    return dst.span().dist_byte() +
           std::max(
                   opr->get_workspace_in_bytes(hx, weight_hh, dst),
                   opr->get_workspace_in_bytes(input, weight_ih, dst));
}

/*!
 * \brief Run one RNN cell step.
 *
 * Computes dst = act(input * weight_ih^T + hx * weight_hh^T + bias_ih + bias_hh),
 * where act is selected by \p nonline_mode (RELU, TANH or IDENTITY). Both
 * products use MatrixMul with transposeB set; the additions and activation use
 * Elemwise operators writing in place into dst.
 *
 * \param workspace must hold at least get_workspace_in_bytes() bytes.
 */
void exec(
        _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih,
        _megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
        _megdnn_tensor_out dst, _megdnn_workspace workspace,
        param::RNNCell::NonlineMode nonline_mode, Handle* handle) {
    const size_t tmp_bytes = dst.layout.span().dist_byte();
    // Guard the subtraction below against size_t underflow when the caller
    // skipped check_exec or under-allocated the workspace.
    megdnn_assert(
            workspace.size >= tmp_bytes,
            "rnn_cell workspace too small: got %zu bytes, need at least %zu",
            workspace.size, tmp_bytes);

    // Scratch tensor at the front of the workspace receives input * weight_ih^T.
    TensorND tmp{static_cast<void*>(workspace.raw_ptr), dst.layout};
    // The rest of the workspace is scratch space for the matmul operator.
    Workspace matmul_workspace{workspace.raw_ptr + tmp_bytes, workspace.size - tmp_bytes};

    auto matmul = handle->create_operator<MatrixMulForward>();
    matmul->param().transposeB = true;
    matmul->exec(input, weight_ih, tmp, matmul_workspace);
    matmul->exec(hx, weight_hh, dst, matmul_workspace);

    auto add_opr = handle->create_operator<ElemwiseForward>();
    add_opr->param().mode = Elemwise::Param::Mode::ADD;
    add_opr->exec({dst, tmp}, dst);
    add_opr->exec({dst, bias_ih}, dst);
    add_opr->exec({dst, bias_hh}, dst);

    // activation
    using NonlineMode = param::RNNCell::NonlineMode;
    switch (nonline_mode) {
#define cb(_mode)                                                    \
    case NonlineMode::_mode: {                                       \
        auto nonlinear = handle->create_operator<ElemwiseForward>(); \
        nonlinear->param().mode = Elemwise::Param::Mode::_mode;      \
        nonlinear->exec({dst}, dst);                                 \
        break;                                                       \
    }
        cb(RELU);
        cb(TANH);
#undef cb
        case NonlineMode::IDENTITY:
            break;
        default:
            megdnn_assert(
                    false, "unsupported RNNCell nonline mode: %d",
                    static_cast<int>(nonline_mode));
    }
}

}  // namespace rnn_cell
}  // namespace megdnn