/**
 * \file src/opr/impl/dnn/rnn.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/opr/dnn/rnn.h"

#include "../internal/megdnn_opr_wrapper.inl"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/internal/out_shape_by_sym_var.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"

using namespace mgb;
using namespace opr;

/* ================= RNNCell ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RNNCellForward);

RNNCellForward::RNNCellForward(
        VarNode* input, VarNode* weight_ih, VarNode* bias_ih, VarNode* hx,
        VarNode* weight_hh, VarNode* bias_hh, const Param& param,
        const OperatorNodeConfig& config)
        : Super{input->owner_graph(), config, "rnn_cell",
                {input, weight_ih, bias_ih, hx, weight_hh, bias_hh}} {
    init_megdnn_opr(*this, param);
    add_input({input, weight_ih, bias_ih, hx, weight_hh, bias_hh});
}

SymbolVar RNNCellForward::make(
        SymbolVar input, SymbolVar weight_ih, SymbolVar bias_ih, SymbolVar hx,
        SymbolVar weight_hh, SymbolVar bias_hh, const Param& param,
        const OperatorNodeConfig& config) {
    return input.insert_single_output_opr<RNNCellForward>(
            input.node(), weight_ih.node(), bias_ih.node(), hx.node(),
            weight_hh.node(), bias_hh.node(), param, config);
}

#if MGB_ENABLE_GRAD
VarNode* rnnCellBackward(
        const SymbolVar& input, const SymbolVar& weight_ih, const SymbolVar& hx,
        const SymbolVar& weight_hh, const SymbolVar& out,
        RNNCell::NonlineMode nonlineMode, size_t wrt_idx, const SymbolVar& og) {
    SymbolVar tmp;
    // backprop through the activation first
    using NonlineMode = RNNCell::NonlineMode;
    using Mode = Elemwise::Mode;
    switch (nonlineMode) {
        case NonlineMode::IDENTITY:
            tmp = og;
            break;
        case NonlineMode::TANH:
            tmp = Elemwise::make({out, og}, Mode::TANH_GRAD);
            break;
        case NonlineMode::RELU:
            tmp = Elemwise::make({out, og}, Mode::SWITCH_GT0);
            break;
        default:
            mgb_throw(GraphError, "Activation method not supported");
    }
    // now the gradient w.r.t. the pre-activation is in tmp
    if (wrt_idx == 2 || wrt_idx == 5)
        return tmp.node();  // bias_ih / bias_hh
    SymbolVar result;
    // for C = A * B^T: dA = dC * B, dB = dC^T * A
    if (wrt_idx == 0) {  // input
        result = MatrixMul::make(
                tmp, weight_ih, {false, false});  // transposeA=false, transposeB=false
    } else if (wrt_idx == 1) {  // weight_ih
        result = MatrixMul::make(tmp, input, {true, false});
    } else if (wrt_idx == 3) {  // hx
        result = MatrixMul::make(tmp, weight_hh, {false, false});
    } else if (wrt_idx == 4) {  // weight_hh
        result = MatrixMul::make(tmp, hx, {true, false});
    }
    return result.node();
}

MGB_IMPL_OPR_GRAD(RNNCell) {
    SymbolVar input(opr.input(0)), weight_ih(opr.input(1)), hx(opr.input(3)),
            weight_hh(opr.input(4));
    SymbolVar out(opr.output(0)), og{out_grad.at(0)};
    return rnnCellBackward(
            input, weight_ih, hx, weight_hh, out, opr.param().nonlineMode, wrt_idx,
            og);
}
#endif

/* ================= LSTMCell ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(LSTMCell);

LSTMCellForward::LSTMCellForward(
        VarNode* input, VarNode* weight_ih, VarNode* bias_ih, VarNode* hx,
        VarNode* weight_hh, VarNode* bias_hh, VarNode* cx, const Param& param,
        const OperatorNodeConfig& config)
        : Super{input->owner_graph(), config, "lstm_cell",
                {input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx}} {
    init_megdnn_opr(*this, param);
    add_input({input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx});
}

SymbolVar LSTMCellForward::make(
        SymbolVar input, SymbolVar weight_ih, SymbolVar bias_ih, SymbolVar hx,
        SymbolVar weight_hh, SymbolVar bias_hh, SymbolVar cx, const Param& param,
        const OperatorNodeConfig& config) {
    return input.insert_single_output_opr<LSTMCellForward>(
            input.node(), weight_ih.node(), bias_ih.node(), hx.node(),
            weight_hh.node(), bias_hh.node(), cx.node(), param, config);
}
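// Sketch of the backward rule implemented below. Going by how the `gates`
// output is split and activated there (order {i, f, o, g}, each of width
// gate_hidden_size), the forward cell is assumed to compute:
//     i = sigmoid(i_pre), f = sigmoid(f_pre), o = sigmoid(o_pre), g = tanh(g_pre)
//     c_out = f * cx + i * g
//     h_out = o * tanh(c_out)
// so with incoming grads h_og (w.r.t. h_out) and c_og (w.r.t. c_out):
//     d o_pre = o * (1 - o) * h_og * tanh(c_out)
//     dc      = c_og + (1 - tanh(c_out)^2) * h_og * o
//     d f_pre = f * (1 - f) * dc * cx
//     d i_pre = i * (1 - i) * dc * g
//     d g_pre = (1 - g^2) * dc * i
//     d cx    = dc * f
// The concatenated pre-activation grads are then pushed through
// rnnCellBackward with NonlineMode::IDENTITY to obtain the grads w.r.t.
// input, weights, biases and hx.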
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(LSTMCell) {
    SymbolVar input(opr.input(0)), weight_ih(opr.input(1)), hx(opr.input(3)),
            weight_hh(opr.input(4)), cx(opr.input(6));
    SymbolVar h_out(opr.output(0)), c_out(opr.output(1)), gates(opr.output(2)),
            h_og{out_grad.at(0)}, c_og{out_grad.at(1)}, tmp;
    size_t ghs = gates.shape()[1] / 4;  // gate_hidden_size
    SymbolVarArray gates_array = Split::make(
            gates, Split::Options::make_partition(gates, 1, {ghs, ghs, ghs, ghs}));
    mgb_assert(gates_array.size() == 4);
    using Mode = Elemwise::Mode;
    const SymbolVar &i(Elemwise::make({gates_array.at(0)}, Mode::SIGMOID)),
            f(Elemwise::make({gates_array.at(1)}, Mode::SIGMOID)),
            o(Elemwise::make({gates_array.at(2)}, Mode::SIGMOID)),
            g(Elemwise::make({gates_array.at(3)}, Mode::TANH));
    SymbolVar i_grad, f_grad, o_grad, g_grad;
    SymbolVar tanh_c_out = Elemwise::make({c_out}, Mode::TANH);
    o_grad = Elemwise::make({o, h_og * tanh_c_out}, Mode::SIGMOID_GRAD);
    c_og = c_og + Elemwise::make({tanh_c_out, h_og * o}, Mode::TANH_GRAD);
    f_grad = Elemwise::make({f, c_og * cx}, Mode::SIGMOID_GRAD);
    i_grad = Elemwise::make({i, c_og * g}, Mode::SIGMOID_GRAD);
    g_grad = Elemwise::make({g, c_og * i}, Mode::TANH_GRAD);
    SymbolVar rnn_cell_grad = Concat::make({i_grad, f_grad, o_grad, g_grad}, {-1});
    SymbolVar result;
    if (wrt_idx < 6) {
        using NonlineMode = RNNCell::NonlineMode;
        result = rnnCellBackward(
                input, weight_ih, hx, weight_hh, gates, NonlineMode::IDENTITY,
                wrt_idx, rnn_cell_grad);
    } else {  // cx
        result = c_og * f;
    }
    return result.node();
}
#endif

/* ================= RNN ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RNN);
MEGDNN_OPR_INIT3(RNNForward, "rnn_fwd");

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(RNN) {
    mgb_assert(
            opr.param().fwd_mode == RNN::Param::FwdMode::TRAINING,
            "RNN can only take grad in training mode");
    SymbolVarArray grads = RNNBackward::make(
            opr.input(0), opr.output(0), opr.input(1), out_grad.at(0),
            out_grad.at(1), opr.input(2), opr.output(2), opr.param());
    // grads are ordered to match the inputs: {input, hx, flatten_weights}
    VarNodeArray ret(opr.input().size(), nullptr);
    for (size_t i = 0; i < ret.size(); ++i) {
        ret[i] = grads[i].node();
    }
    return ret;
}
#endif

/* ================= RNNBackward ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RNNBackward);

RNNBackward::RNNBackward(
        VarNode* x, VarNode* y, VarNode* hx, VarNode* dy, VarNode* dhy,
        VarNode* flatten_weights, VarNode* reserve_space, const Param& param,
        const OperatorNodeConfig& config)
        : Super({x->owner_graph(), config, "rnn_bwd",
                 {x, y, hx, dy, dhy, flatten_weights, reserve_space}},
                0, true) {
    init_megdnn_opr(*this, param);
    add_input({x, y, hx, dy, dhy, flatten_weights, reserve_space});
}
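// Going by the static shape/dtype inference below, the three outputs of
// RNNBackward mirror x (input(0)), hx (input(2)) and flatten_weights
// (input(5)), i.e. make() is expected to return {dx, dhx, dw}; the grad rule
// of RNN above relies on exactly this ordering.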
SymbolVarArray RNNBackward::make(
        SymbolVar x, SymbolVar y, SymbolVar hx, SymbolVar dy, SymbolVar dhy,
        SymbolVar flatten_weights, SymbolVar reserve_space, const Param& param,
        const OperatorNodeConfig& config) {
    auto&& out = x.node()
                         ->owner_graph()
                         ->insert_opr(std::make_unique<RNNBackward>(
                                 x.node(), y.node(), hx.node(), dy.node(),
                                 dhy.node(), flatten_weights.node(),
                                 reserve_space.node(), param, config))
                         ->output();
    SymbolVarArray ret(out.size());
    for (size_t i = 0; i < ret.size(); ++i) {
        ret[i] = out[i];
    }
    return ret;
}

RNNBackward::Super::NodeProp* RNNBackward::do_make_node_prop() const {
    auto ret = Super::do_make_node_prop();
    ret->add_dep_type_existing_var(
            input(6),  // reserve_space
            NodeProp::DepType::VALUE_ALLOW_EMPTY);
    return ret;
}

void RNNBackward::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
    mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(0)));
    mgr.register_shape_infer(output(1), ShapeInferDesc::make_identity(input(2)));
    mgr.register_shape_infer(output(2), ShapeInferDesc::make_identity(input(5)));
    this->init_output_static_infer_desc_workspace(
            intl::AutoAddWorkspaceNeedLimitGetter<megdnn::RNNBackward>::val);
}

void RNNBackward::init_output_dtype() {
    output(0)->dtype(input(0)->dtype());
    output(1)->dtype(input(2)->dtype());
    output(2)->dtype(input(5)->dtype());
}

/* ================= LSTM ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(LSTM);

LSTMForward::LSTMForward(
        VarNode* input, VarNode* hx, VarNode* cx, VarNode* flatten_weights,
        const Param& param, const OperatorNodeConfig& config)
        : Super{input->owner_graph(), config, "lstm",
                {input, hx, cx, flatten_weights}} {
    init_megdnn_opr(*this, param);
    add_input({input, hx, cx, flatten_weights});
}

SymbolVar LSTMForward::make(
        SymbolVar input, SymbolVar hx, SymbolVar cx, SymbolVar flatten_weights,
        const Param& param, const OperatorNodeConfig& config) {
    return input.insert_single_output_opr<LSTMForward>(
            input.node(), hx.node(), cx.node(), flatten_weights.node(), param,
            config);
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(LSTM) {
    SymbolVarArray grads = LSTMBackward::make(
            opr.input(0), opr.output(0), opr.input(1), opr.input(2), out_grad.at(0),
            out_grad.at(1), out_grad.at(2), opr.input(3), opr.output(3),
            opr.param());
    return grads.at(wrt_idx).node();  // ordered as {input, hx, cx, weights}
}
#endif

/* ================= LSTMBackward ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(LSTMBackward);

LSTMBackward::LSTMBackward(
        VarNode* x, VarNode* y, VarNode* hx, VarNode* cx, VarNode* dy, VarNode* dhy,
        VarNode* dcy, VarNode* flatten_weights, VarNode* reserve_space,
        const Param& param, const OperatorNodeConfig& config)
        : Super({x->owner_graph(), config, "lstm_bwd",
                 {x, y, hx, cx, dy, dhy, dcy, flatten_weights, reserve_space}},
                1, true) {
    init_megdnn_opr(*this, param);
    add_input({x, y, hx, cx, dy, dhy, dcy, flatten_weights, reserve_space});
}

SymbolVarArray LSTMBackward::make(
        SymbolVar x, SymbolVar y, SymbolVar hx, SymbolVar cx, SymbolVar dy,
        SymbolVar dhy, SymbolVar dcy, SymbolVar flatten_weights,
        SymbolVar reserve_space, const Param& param,
        const OperatorNodeConfig& config) {
    auto&& out = x.node()
                         ->owner_graph()
                         ->insert_opr(std::make_unique<LSTMBackward>(
                                 x.node(), y.node(), hx.node(), cx.node(),
                                 dy.node(), dhy.node(), dcy.node(),
                                 flatten_weights.node(), reserve_space.node(),
                                 param, config))
                         ->output();
    SymbolVarArray ret(out.size());
    for (size_t i = 0; i < ret.size(); ++i) {
        ret[i] = out[i];
    }
    return ret;
}

LSTMBackward::Super::NodeProp* LSTMBackward::do_make_node_prop() const {
    auto ret = Super::do_make_node_prop();
    ret->add_dep_type_existing_var(
            input(8),  // reserve_space
            NodeProp::DepType::VALUE_ALLOW_EMPTY);
    return ret;
}
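// As with RNNBackward, the output order follows the forward-side inputs whose
// shapes/dtypes are mirrored below: output(0) ~ x, output(1) ~ hx,
// output(2) ~ cx, output(3) ~ flatten_weights, i.e. {dx, dhx, dcx, dw}, which
// is the ordering MGB_IMPL_OPR_GRAD(LSTM) indexes with wrt_idx.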
void LSTMBackward::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
    mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(0)));
    mgr.register_shape_infer(output(1), ShapeInferDesc::make_identity(input(2)));
    mgr.register_shape_infer(output(2), ShapeInferDesc::make_identity(input(3)));
    mgr.register_shape_infer(output(3), ShapeInferDesc::make_identity(input(7)));
    this->init_output_static_infer_desc_workspace(
            intl::AutoAddWorkspaceNeedLimitGetter<megdnn::LSTMBackward>::val);
}

void LSTMBackward::init_output_dtype() {
    output(0)->dtype(input(0)->dtype());
    output(1)->dtype(input(2)->dtype());
    output(2)->dtype(input(3)->dtype());
    output(3)->dtype(input(7)->dtype());
}