You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

rnn.cpp 2.6 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. /**
  2. * \file imperative/src/impl/ops/rnn.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/opr/dnn/rnn.h"
  12. #include "megbrain/imperative/ops/autogen.h"
  13. #include "../op_trait.h"
  14. namespace mgb::imperative {
  15. namespace rnn_cell {
  16. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  17. auto&& op = static_cast<const RNNCell&>(def);
  18. mgb_assert(inputs.size() == 6);
  19. return opr::RNNCell::make(
  20. inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], inputs[5],
  21. op.param());
  22. }
  23. OP_TRAIT_REG(RNNCell, RNNCell).apply_on_var_node(apply_on_var_node).fallback();
  24. } // namespace rnn_cell
  25. namespace lstm_cell {
  26. VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  27. auto&& op = static_cast<const LSTMCell&>(def);
  28. mgb_assert(inputs.size() == 7);
  29. auto* opr = opr::LSTMCell::make(
  30. inputs[0], inputs[1], inputs[2], inputs[3], inputs[4],
  31. inputs[5], inputs[6], op.param())
  32. .node()
  33. ->owner_opr();
  34. return {opr->output(0), opr->output(1), opr->output(2)};
  35. }
  36. OP_TRAIT_REG(LSTMCell, LSTMCell).apply_on_var_node(apply_on_var_node).fallback();
  37. } // namespace lstm_cell
  38. namespace rnn {
  39. VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  40. auto&& op = static_cast<const RNN&>(def);
  41. mgb_assert(inputs.size() == 3);
  42. auto* opr = opr::RNN::make(inputs[0], inputs[1], inputs[2], op.param())
  43. .node()
  44. ->owner_opr();
  45. return {opr->output(0), opr->output(1), opr->output(2)};
  46. }
  47. OP_TRAIT_REG(RNN, RNN).apply_on_var_node(apply_on_var_node).fallback();
  48. } // namespace rnn
  49. namespace lstm {
  50. VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  51. auto&& op = static_cast<const LSTM&>(def);
  52. mgb_assert(inputs.size() == 4);
  53. auto* opr = opr::LSTM::make(inputs[0], inputs[1], inputs[2], inputs[3], op.param())
  54. .node()
  55. ->owner_opr();
  56. return {opr->output(0), opr->output(1), opr->output(2), opr->output(3)};
  57. }
  58. OP_TRAIT_REG(LSTM, LSTM).apply_on_var_node(apply_on_var_node).fallback();
  59. } // namespace lstm
  60. } // namespace mgb::imperative