| @@ -13,17 +13,17 @@ namespace Tensorflow.Keras.Layers | |||
| /// </summary> | |||
| public class Bidirectional: Wrapper | |||
| { | |||
| BidirectionalArgs _args; | |||
| RNN _forward_layer; | |||
| RNN _backward_layer; | |||
| RNN _layer; | |||
| bool _support_masking = true; | |||
| int _num_constants = 0; | |||
| bool _support_masking = true; | |||
| bool _return_state; | |||
| bool _stateful; | |||
| bool _return_sequences; | |||
| InputSpec _input_spec; | |||
| BidirectionalArgs _args; | |||
| RNNArgs _layer_args_copy; | |||
| RNN _forward_layer; | |||
| RNN _backward_layer; | |||
| RNN _layer; | |||
| InputSpec _input_spec; | |||
| public Bidirectional(BidirectionalArgs args):base(args) | |||
| { | |||
| _args = args; | |||
| @@ -66,12 +66,16 @@ namespace Tensorflow.Keras.Layers | |||
| // Recreate the forward layer from the original layer config, so that it | |||
| // will not carry over any state from the layer. | |||
| var actualType = _layer.GetType(); | |||
| if (actualType == typeof(LSTM)) | |||
| if (_layer is LSTM) | |||
| { | |||
| var arg = _layer_args_copy as LSTMArgs; | |||
| _forward_layer = new LSTM(arg); | |||
| } | |||
| else if(_layer is SimpleRNN) | |||
| { | |||
| var arg = _layer_args_copy as SimpleRNNArgs; | |||
| _forward_layer = new SimpleRNN(arg); | |||
| } | |||
| // TODO(Wanglongzhi2001), add GRU if case. | |||
| else | |||
| { | |||
| @@ -154,12 +158,18 @@ namespace Tensorflow.Keras.Layers | |||
| { | |||
| config.GoBackwards = !config.GoBackwards; | |||
| } | |||
| var actualType = layer.GetType(); | |||
| if (actualType == typeof(LSTM)) | |||
| if (layer is LSTM) | |||
| { | |||
| var arg = config as LSTMArgs; | |||
| return new LSTM(arg); | |||
| } | |||
| else if(layer is SimpleRNN) | |||
| { | |||
| var arg = config as SimpleRNNArgs; | |||
| return new SimpleRNN(arg); | |||
| } | |||
| // TODO(Wanglongzhi2001), add GRU if case. | |||
| else | |||
| { | |||
| return new RNN(cell, config); | |||
| @@ -183,7 +193,6 @@ namespace Tensorflow.Keras.Layers | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| // `Bidirectional.call` implements the same API as the wrapped `RNN`. | |||
| Tensors forward_inputs; | |||
| Tensors backward_inputs; | |||
| Tensors forward_state; | |||