@@ -13,17 +13,17 @@ namespace Tensorflow.Keras.Layers
     /// </summary>
     public class Bidirectional: Wrapper
     {
-        BidirectionalArgs _args;
-        RNN _forward_layer;
-        RNN _backward_layer;
-        RNN _layer;
-        bool _support_masking = true;
         int _num_constants = 0;
+        bool _support_masking = true;
         bool _return_state;
         bool _stateful;
         bool _return_sequences;
-        InputSpec _input_spec;
+        BidirectionalArgs _args;
         RNNArgs _layer_args_copy;
+        RNN _forward_layer;
+        RNN _backward_layer;
+        RNN _layer;
+        InputSpec _input_spec;
         public Bidirectional(BidirectionalArgs args):base(args)
         {
             _args = args;
@@ -66,12 +66,16 @@ namespace Tensorflow.Keras.Layers
             // Recreate the forward layer from the original layer config, so that it
             // will not carry over any state from the layer.
-            var actualType = _layer.GetType();
-            if (actualType == typeof(LSTM))
+            if (_layer is LSTM)
             {
                 var arg = _layer_args_copy as LSTMArgs;
                 _forward_layer = new LSTM(arg);
             }
+            else if(_layer is SimpleRNN)
+            {
+                var arg = _layer_args_copy as SimpleRNNArgs;
+                _forward_layer = new SimpleRNN(arg);
+            }
             // TODO(Wanglongzhi2001), add GRU if case.
             else
             {
@@ -154,12 +158,18 @@ namespace Tensorflow.Keras.Layers
             {
                 config.GoBackwards = !config.GoBackwards;
             }
-            var actualType = layer.GetType();
-            if (actualType == typeof(LSTM))
+            if (layer is LSTM)
             {
                 var arg = config as LSTMArgs;
                 return new LSTM(arg);
             }
+            else if(layer is SimpleRNN)
+            {
+                var arg = config as SimpleRNNArgs;
+                return new SimpleRNN(arg);
+            }
+            // TODO(Wanglongzhi2001), add GRU if case.
             else
             {
                 return new RNN(cell, config);
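One small observation on the `config as LSTMArgs` / `config as SimpleRNNArgs` downcasts paired with the `is` checks: C# pattern matching can fold the test and the cast into a single step, which avoids a possible silent null from `as` if the branch condition and the cast ever get out of sync. A sketch with hypothetical stand-in args classes, not the library's:

```csharp
using System;

// Hypothetical stand-in args types, only to illustrate the cast pattern.
class RNNArgs { public bool GoBackwards { get; set; } }
class LSTMArgs : RNNArgs { }
class SimpleRNNArgs : RNNArgs { }

class CastDemo
{
    static string Describe(RNNArgs config)
    {
        // `is` pattern matching performs the type test and the cast in one step,
        // so there is no separate `as` expression that could quietly produce null.
        if (config is LSTMArgs lstmArgs)
            return $"LSTM args, GoBackwards={lstmArgs.GoBackwards}";
        if (config is SimpleRNNArgs simpleArgs)
            return $"SimpleRNN args, GoBackwards={simpleArgs.GoBackwards}";
        return "generic RNN args";
    }

    static void Main()
    {
        Console.WriteLine(Describe(new SimpleRNNArgs { GoBackwards = true }));
        // Prints: SimpleRNN args, GoBackwards=True
    }
}
```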
@@ -183,7 +193,6 @@ namespace Tensorflow.Keras.Layers
         protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
         {
             // `Bidirectional.call` implements the same API as the wrapped `RNN`.
             Tensors forward_inputs;
             Tensors backward_inputs;
             Tensors forward_state;
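For context on the `forward_*`/`backward_*` tensors declared at the start of `Call`: a bidirectional wrapper runs the wrapped layer once over the sequence as given and once over it reversed, then merges the two outputs (concatenation being the usual default). A toy, self-contained sketch of that data flow using plain arrays and a placeholder delegate instead of a real RNN:

```csharp
using System;
using System.Linq;

class BidirectionalSketch
{
    static void Main()
    {
        // Placeholder "RNN": collapses a sequence to one position-weighted sum.
        Func<float[], float> runRnn = xs => xs.Select((x, i) => x * (i + 1)).Sum();

        float[] sequence = { 1f, 2f, 3f };

        // Forward pass reads the sequence as given; backward pass reads it reversed.
        float forwardOut = runRnn(sequence);
        float backwardOut = runRnn(sequence.Reverse().ToArray());

        // Merge the two directions; concatenating the outputs is the usual default.
        float[] merged = { forwardOut, backwardOut };
        Console.WriteLine(string.Join(", ", merged));   // 14, 10
    }
}
```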