@@ -0,0 +1,20 @@
+using Newtonsoft.Json;
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.NumPy;
+
+namespace Tensorflow.Keras.ArgsDefinition
+{
+    public class BidirectionalArgs : AutoSerializeLayerArgs
+    {
+        [JsonProperty("layer")]
+        public ILayer Layer { get; set; }
+        [JsonProperty("merge_mode")]
+        public string? MergeMode { get; set; }
+        [JsonProperty("backward_layer")]
+        public ILayer BackwardLayer { get; set; }
+        public NDArray Weights { get; set; }
+    }
+}
@@ -5,5 +5,10 @@
         // TODO: maybe change the `RNNArgs` and implement this class.
         public bool UnitForgetBias { get; set; }
         public int Implementation { get; set; }
+
+        public LSTMArgs Clone()
+        {
+            return (LSTMArgs)MemberwiseClone();
+        }
     }
 }
@@ -40,5 +40,10 @@ namespace Tensorflow.Keras.ArgsDefinition
         public bool ZeroOutputForMask { get; set; } = false;
         [JsonProperty("recurrent_dropout")]
         public float RecurrentDropout { get; set; } = .0f;
+
+        public RNNArgs Clone()
+        {
+            return (RNNArgs)MemberwiseClone();
+        }
     }
 }
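`MemberwiseClone` makes a shallow copy, which is all the `Bidirectional` wrapper below needs: it clones the wrapped layer's args per direction and flips `GoBackwards` on the copy, leaving the original untouched. A minimal sketch of that pattern (illustrative only; `rnnArgs` stands in for any populated `RNNArgs`):

    // Hedged sketch: `rnnArgs` is assumed to exist and be populated.
    var backwardArgs = rnnArgs.Clone();              // shallow copy via MemberwiseClone
    backwardArgs.GoBackwards = !rnnArgs.GoBackwards; // mutate the copy only
    // rnnArgs.GoBackwards is unchanged; reference-typed members are shared.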
@@ -0,0 +1,24 @@
+using Newtonsoft.Json;
+using System;
+using System.Collections.Generic;
+using System.Runtime.CompilerServices;
+using System.Text;
+
+namespace Tensorflow.Keras.ArgsDefinition
+{
+    public class WrapperArgs : AutoSerializeLayerArgs
+    {
+        [JsonProperty("layer")]
+        public ILayer Layer { get; set; }
+
+        public WrapperArgs(ILayer layer)
+        {
+            Layer = layer;
+        }
+
+        public static implicit operator WrapperArgs(BidirectionalArgs args)
+            => new WrapperArgs(args.Layer);
+    }
+}
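The implicit operator above is what lets `Bidirectional`'s constructor hand its own `BidirectionalArgs` straight to the `Wrapper` base constructor (`: base(args)` in the hunks below). In isolation (hedged sketch; `someRnn` is a placeholder for an RNN layer instance):

    // Illustrative only; `someRnn` stands in for e.g. an LSTM layer.
    var bArgs = new BidirectionalArgs { Layer = someRnn };
    WrapperArgs wArgs = bArgs;   // implicit conversion wraps bArgs.Layer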
@@ -258,7 +258,19 @@ namespace Tensorflow.Keras.Layers
             float dropout = 0f,
             float recurrent_dropout = 0f,
             bool reset_after = true);
+
+        /// <summary>
+        /// Bidirectional wrapper for RNNs.
+        /// </summary>
+        /// <param name="layer">`keras.layers.RNN` instance, such as `keras.layers.LSTM` or `keras.layers.GRU`.</param>
+        /// <param name="merge_mode">Mode by which outputs of the forward and backward RNNs will be combined.
+        /// One of {"sum", "mul", "concat", "ave", null}. If null, the outputs are not combined and are returned as a list.</param>
+        /// <param name="weights">Initial weights to load in the wrapped layers; the first half goes to the forward layer, the second half to the backward layer.</param>
+        /// <param name="backward_layer">Optional `keras.layers.RNN` instance to be used to handle backwards input processing.
+        /// If not provided, the backward layer is generated from `layer`
+        /// automatically.</param>
+        /// <returns></returns>
+        public ILayer Bidirectional(
+                ILayer layer,
+                string merge_mode = "concat",
+                NDArray weights = null,
+                ILayer backward_layer = null);
+
         public ILayer Subtract();
     }
 }
@@ -908,6 +908,20 @@ namespace Tensorflow.Keras.Layers
                 ResetAfter = reset_after
             });
+
+        public ILayer Bidirectional(
+                ILayer layer,
+                string merge_mode = "concat",
+                NDArray weights = null,
+                ILayer backward_layer = null)
+            => new Bidirectional(new BidirectionalArgs
+            {
+                Layer = layer,
+                MergeMode = merge_mode,
+                Weights = weights,
+                BackwardLayer = backward_layer
+            });
+
         /// <summary>
         ///
         /// </summary>
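A hedged usage sketch of the new entry point; the calls and shapes mirror the `Bidirectional` unit test added at the end of this diff:

    // Batch of 32 sequences, 10 timesteps, 8 features.
    var inputs = tf.random.normal((32, 10, 8));
    var bi = tf.keras.layers.Bidirectional(
        keras.layers.LSTM(10, return_sequences: true));
    // Default merge_mode "concat": the forward and backward outputs are
    // concatenated on the last axis, giving (32, 10, 20).
    var outputs = bi.Apply(inputs);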
@@ -0,0 +1,33 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Saving;
+
+namespace Tensorflow.Keras.Layers
+{
+    /// <summary>
+    /// Abstract wrapper base class. Wrappers take another layer and augment it in various ways.
+    /// Do not use this class as a layer, it is only an abstract base class.
+    /// Two usable wrappers are the `TimeDistributed` and `Bidirectional` wrappers.
+    /// </summary>
+    public abstract class Wrapper : Layer
+    {
+        public ILayer _layer;
+
+        public Wrapper(WrapperArgs args) : base(args)
+        {
+            _layer = args.Layer;
+        }
+
+        public virtual void Build(KerasShapesWrapper input_shape)
+        {
+            if (!_layer.Built)
+            {
+                _layer.build(input_shape);
+            }
+            built = true;
+        }
+    }
+}
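`Wrapper` itself is abstract; a subclass supplies the behavior, as `Bidirectional` does in the next file. A minimal illustrative subclass (hypothetical, not part of this PR):

    // Hypothetical example only: the smallest possible Wrapper subclass.
    public class PassThroughWrapper : Wrapper
    {
        public PassThroughWrapper(WrapperArgs args) : base(args) { }
        // A real wrapper would override Call/build to augment `_layer`,
        // the way Bidirectional drives its forward and backward copies.
    }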
@@ -0,0 +1,276 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Tensorflow.Common.Types;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Saving;
+
+namespace Tensorflow.Keras.Layers
+{
+    /// <summary>
+    /// Bidirectional wrapper for RNNs.
+    /// </summary>
+    public class Bidirectional : Wrapper
+    {
+        BidirectionalArgs _args;
+        RNN _forward_layer;
+        RNN _backward_layer;
+        RNN _layer;
+        bool _support_masking = true;
+        int _num_constants = 0;
+        bool _return_state;
+        bool _stateful;
+        bool _return_sequences;
+        InputSpec _input_spec;
+        RNNArgs _layer_args_copy;
+
+        public Bidirectional(BidirectionalArgs args) : base(args)
+        {
+            _args = args;
+            // `Layer` is declared as ILayer, so the only invalid value is null.
+            if (_args.Layer is null)
+                throw new ValueError(
+                    "Please initialize `Bidirectional` layer with a " +
+                    $"`tf.keras.layers.Layer` instance. Received: {_args.Layer}");
+            // The backward layer, when given, must be an RNN: it is cast to RNN below.
+            if (_args.BackwardLayer is not null && _args.BackwardLayer is not RNN)
+                throw new ValueError(
+                    "`backward_layer` needs to be a `tf.keras.layers.RNN` " +
+                    $"instance. Received: {_args.BackwardLayer}");
+            if (!new List<string> { "sum", "mul", "ave", "concat", null }.Contains(_args.MergeMode))
+            {
+                throw new ValueError(
+                    $"Invalid merge mode. Received: {_args.MergeMode}. " +
+                    "Merge mode should be one of " +
+                    "{\"sum\", \"mul\", \"ave\", \"concat\", null}"
+                );
+            }
+            if (_args.Layer is RNN)
+            {
+                _layer = _args.Layer as RNN;
+            }
+            else
+            {
+                throw new ValueError(
+                    "Bidirectional only supports RNN instances such as LSTM or GRU");
+            }
+            _return_state = _layer.Args.ReturnState;
+            _return_sequences = _layer.Args.ReturnSequences;
+            _stateful = _layer.Args.Stateful;
+            _layer_args_copy = _layer.Args.Clone();
+            // We don't want to track `layer` since we're already tracking the two
+            // copies of it we actually run.
+            // TODO(Wanglongzhi2001), since the feature of setattr_tracking has not been implemented.
+            // _setattr_tracking = false;
+            // super().__init__(layer, **kwargs)
+            // _setattr_tracking = true;
+
+            // Recreate the forward layer from the original layer config, so that it
+            // will not carry over any state from the layer.
+            var actualType = _layer.GetType();
+            if (actualType == typeof(LSTM))
+            {
+                var arg = _layer_args_copy as LSTMArgs;
+                _forward_layer = new LSTM(arg);
+            }
+            // TODO(Wanglongzhi2001), add GRU if case.
+            else
+            {
+                _forward_layer = new RNN(_layer.Cell, _layer_args_copy);
+            }
+            //_forward_layer = _recreate_layer_from_config(_layer);
+            if (_args.BackwardLayer is null)
+            {
+                _backward_layer = _recreate_layer_from_config(_layer, go_backwards: true);
+            }
+            else
+            {
+                _backward_layer = _args.BackwardLayer as RNN;
+            }
+            _forward_layer.Name = "forward_" + _forward_layer.Name;
+            _backward_layer.Name = "backward_" + _backward_layer.Name;
+
+            _verify_layer_config();
+
+            void force_zero_output_for_mask(RNN layer)
+            {
+                layer.Args.ZeroOutputForMask = layer.Args.ReturnSequences;
+            }
+
+            force_zero_output_for_mask(_forward_layer);
+            force_zero_output_for_mask(_backward_layer);
+
+            if (_args.Weights is not null)
+            {
+                // First half of the weights goes to the forward layer, second
+                // half to the backward layer (cf. Keras: weights[:nw // 2] and
+                // weights[nw // 2:]).
+                var nw = len(_args.Weights);
+                _forward_layer.set_weights(_args.Weights[$":{nw / 2}"]);
+                _backward_layer.set_weights(_args.Weights[$"{nw / 2}:"]);
+            }
+
+            _input_spec = _layer.InputSpec;
+        }
+        private void _verify_layer_config()
+        {
+            if (_forward_layer.Args.GoBackwards == _backward_layer.Args.GoBackwards)
+            {
+                throw new ValueError(
+                    "Forward layer and backward layer should have different " +
+                    "`go_backwards` value. " +
+                    "forward_layer.go_backwards = " +
+                    $"{_forward_layer.Args.GoBackwards}, " +
+                    "backward_layer.go_backwards = " +
+                    $"{_backward_layer.Args.GoBackwards}");
+            }
+            if (_forward_layer.Args.Stateful != _backward_layer.Args.Stateful)
+            {
+                throw new ValueError(
+                    "Forward layer and backward layer are expected to have " +
+                    "the same value for attribute stateful, got " +
+                    $"{_forward_layer.Args.Stateful} for forward layer and " +
+                    $"{_backward_layer.Args.Stateful} for backward layer");
+            }
+            if (_forward_layer.Args.ReturnState != _backward_layer.Args.ReturnState)
+            {
+                throw new ValueError(
+                    "Forward layer and backward layer are expected to have " +
+                    "the same value for attribute return_state, got " +
+                    $"{_forward_layer.Args.ReturnState} for forward layer and " +
+                    $"{_backward_layer.Args.ReturnState} for backward layer");
+            }
+            if (_forward_layer.Args.ReturnSequences != _backward_layer.Args.ReturnSequences)
+            {
+                throw new ValueError(
+                    "Forward layer and backward layer are expected to have " +
+                    "the same value for attribute return_sequences, got " +
+                    $"{_forward_layer.Args.ReturnSequences} for forward layer and " +
+                    $"{_backward_layer.Args.ReturnSequences} for backward layer");
+            }
+        }
+
+        private RNN _recreate_layer_from_config(RNN layer, bool go_backwards = false)
+        {
+            // get_config() returns the layer's live args object, so clone it
+            // before mutating `GoBackwards`; otherwise the original layer's
+            // args would be flipped as well.
+            var config = (layer.get_config() as RNNArgs).Clone();
+            var cell = layer.Cell;
+            if (go_backwards)
+            {
+                config.GoBackwards = !config.GoBackwards;
+            }
+            // MemberwiseClone preserves the runtime type, so an LSTM's args
+            // can still be downcast to LSTMArgs here.
+            var actualType = layer.GetType();
+            if (actualType == typeof(LSTM))
+            {
+                var arg = config as LSTMArgs;
+                return new LSTM(arg);
+            }
+            else
+            {
+                return new RNN(cell, config);
+            }
+        }
+        public override void build(KerasShapesWrapper input_shape)
+        {
+            _buildInputShape = input_shape;
+            tf_with(ops.name_scope(_forward_layer.Name), scope =>
+            {
+                _forward_layer.build(input_shape);
+            });
+            tf_with(ops.name_scope(_backward_layer.Name), scope =>
+            {
+                _backward_layer.build(input_shape);
+            });
+            built = true;
+        }
+        protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
+        {
+            // `Bidirectional.call` implements the same API as the wrapped `RNN`.
+            Tensors forward_inputs;
+            Tensors backward_inputs;
+            Tensors forward_state;
+            Tensors backward_state;
+            // if isinstance(inputs, list) and len(inputs) > 1:
+            if (inputs.Length > 1)
+            {
+                // initial_states are keras tensors, which means they are passed
+                // in together with inputs as list. The initial_states need to be
+                // split into forward and backward section, and be fed to layers
+                // accordingly.
+                forward_inputs = new Tensors { inputs[0] };
+                backward_inputs = new Tensors { inputs[0] };
+                var pivot = (len(inputs) - _num_constants) / 2 + 1;
+                // add forward initial state
+                forward_inputs.Concat(new Tensors { inputs[$"1:{pivot}"] });
+                // cf. Keras: the simple branch is taken when there are NO constants.
+                if (_num_constants == 0)
+                    // add backward initial state
+                    backward_inputs.Concat(new Tensors { inputs[$"{pivot}:"] });
+                else
+                {
+                    // add backward initial state
+                    backward_inputs.Concat(new Tensors { inputs[$"{pivot}:{-_num_constants}"] });
+                    // add constants for forward and backward layers
+                    forward_inputs.Concat(new Tensors { inputs[$"{-_num_constants}:"] });
+                    backward_inputs.Concat(new Tensors { inputs[$"{-_num_constants}:"] });
+                }
+                forward_state = null;
+                backward_state = null;
+            }
+            else if (state is not null)
+            {
+                // initial_states are not keras tensors, e.g. eager tensors from np
+                // array. They are only passed in from kwarg initial_state, and
+                // should be passed to forward/backward layer via kwarg
+                // initial_state as well.
+                forward_inputs = inputs;
+                backward_inputs = inputs;
+                var half = len(state) / 2;
+                forward_state = state[$":{half}"];
+                backward_state = state[$"{half}:"];
+            }
+            else
+            {
+                forward_inputs = inputs;
+                backward_inputs = inputs;
+                forward_state = null;
+                backward_state = null;
+            }
+
+            var y = _forward_layer.Apply(forward_inputs, forward_state);
+            var y_rev = _backward_layer.Apply(backward_inputs, backward_state);
+
+            Tensors states = new();
+            if (_return_state)
+            {
+                // Collect the state tensors that follow the output
+                // (cf. Python: states = y[1:] + y_rev[1:]).
+                states = y["1:"] + y_rev["1:"];
+                y = y[0];
+                y_rev = y_rev[0];
+            }
+
+            if (_return_sequences)
+            {
+                // The backward output is in reversed time order; flip it back so
+                // it lines up with the forward output before merging.
+                int time_dim = _forward_layer.Args.TimeMajor ? 0 : 1;
+                y_rev = keras.backend.reverse(y_rev, time_dim);
+            }
+            Tensors output;
+            if (_args.MergeMode == "concat")
+                output = keras.backend.concatenate(new Tensors { y.Single(), y_rev.Single() });
+            else if (_args.MergeMode == "sum")
+                output = y.Single() + y_rev.Single();
+            else if (_args.MergeMode == "ave")
+                output = (y.Single() + y_rev.Single()) / 2;
+            else if (_args.MergeMode == "mul")
+                output = y.Single() * y_rev.Single();
+            else if (_args.MergeMode is null)
+                output = new Tensors { y.Single(), y_rev.Single() };
+            else
+                throw new ValueError(
+                    "Unrecognized value for `merge_mode`. " +
+                    $"Received: {_args.MergeMode}. " +
+                    "Expected values are [\"concat\", \"sum\", \"ave\", \"mul\"]");
+
+            if (_return_state)
+            {
+                // cf. Keras: `return output + states` when merge_mode is null,
+                // otherwise `return [output] + states` (all state tensors, not
+                // just the first one).
+                if (_args.MergeMode is null)
+                    return output.Concat(states).ToArray().ToTensors();
+                return new Tensor[] { output.Single() }.Concat(states).ToArray().ToTensors();
+            }
+            return output;
+        }
+    }
+}
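For a wrapped layer with U units (batch B, T timesteps, `return_sequences: true`; without it the T axis is dropped), the merge branch in `Call` above yields the following; a hedged summary, consistent with the unit test at the end of this diff:

    // merge_mode -> output of Bidirectional.Call:
    //   "concat": (B, T, 2U)  last axes of y and y_rev concatenated
    //   "sum":    (B, T, U)   y + y_rev
    //   "ave":    (B, T, U)   (y + y_rev) / 2
    //   "mul":    (B, T, U)   y * y_rev
    //   null:     two tensors, each (B, T, U)
    // With return_state, the forward and backward state tensors are appended.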
@@ -3,6 +3,7 @@ using Tensorflow.Keras.ArgsDefinition;
 using Tensorflow.Keras.Engine;
 using Tensorflow.Common.Types;
 using Tensorflow.Common.Extensions;
+using Tensorflow.Keras.Saving;

 namespace Tensorflow.Keras.Layers
 {
@@ -14,15 +15,15 @@ namespace Tensorflow.Keras.Layers
     /// </summary>
     public class LSTM : RNN
     {
-        LSTMArgs args;
+        LSTMArgs _args;
         InputSpec[] _state_spec;
         InputSpec _input_spec;
         bool _could_use_gpu_kernel;
+        public LSTMArgs Args { get => _args; }

         public LSTM(LSTMArgs args) :
             base(CreateCell(args), args)
         {
-            this.args = args;
+            _args = args;
             _input_spec = new InputSpec(ndim: 3);
             _state_spec = new[] { args.Units, args.Units }.Select(dim => new InputSpec(shape: (-1, dim))).ToArray();
             _could_use_gpu_kernel = args.Activation == keras.activations.Tanh
@@ -71,7 +72,7 @@ namespace Tensorflow.Keras.Layers
             var single_input = inputs.Single;
             var input_shape = single_input.shape;
-            var timesteps = args.TimeMajor ? input_shape[0] : input_shape[1];
+            var timesteps = _args.TimeMajor ? input_shape[0] : input_shape[1];

             _maybe_reset_cell_dropout_mask(Cell);
@@ -87,26 +88,26 @@ namespace Tensorflow.Keras.Layers
                 inputs,
                 initial_state,
                 constants: null,
-                go_backwards: args.GoBackwards,
+                go_backwards: _args.GoBackwards,
                 mask: mask,
-                unroll: args.Unroll,
+                unroll: _args.Unroll,
                 input_length: ops.convert_to_tensor(timesteps),
-                time_major: args.TimeMajor,
-                zero_output_for_mask: args.ZeroOutputForMask,
-                return_all_outputs: args.ReturnSequences
+                time_major: _args.TimeMajor,
+                zero_output_for_mask: _args.ZeroOutputForMask,
+                return_all_outputs: _args.ReturnSequences
             );

             Tensor output;
-            if (args.ReturnSequences)
+            if (_args.ReturnSequences)
             {
-                output = keras.backend.maybe_convert_to_ragged(false, outputs, (int)timesteps, args.GoBackwards);
+                output = keras.backend.maybe_convert_to_ragged(false, outputs, (int)timesteps, _args.GoBackwards);
             }
             else
             {
                 output = last_output;
             }

-            if (args.ReturnState)
+            if (_args.ReturnState)
             {
                 return new Tensor[] { output }.Concat(states).ToArray().ToTensors();
             }
@@ -115,5 +116,11 @@ namespace Tensorflow.Keras.Layers
                 return output;
             }
         }
+
+        public override IKerasConfig get_config()
+        {
+            return _args;
+        }
     }
 }
@@ -31,7 +31,9 @@ namespace Tensorflow.Keras.Layers
         protected IVariableV1 _kernel;
         protected IVariableV1 _bias;
         private IRnnCell _cell;
-        protected IRnnCell Cell
+
+        public RNNArgs Args { get => _args; }
+        public IRnnCell Cell
         {
             get
             {
@@ -570,10 +572,13 @@ namespace Tensorflow.Keras.Layers
             var input_shape = array_ops.shape(inputs);
             var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0];
             var dtype = input.dtype;
             Tensors init_state = Cell.GetInitialState(null, batch_size, dtype);
             return init_state;
         }
+
+        public override IKerasConfig get_config()
+        {
+            return _args;
+        }
     }
 }
@@ -5,6 +5,7 @@ using System.Linq;
 using System.Text;
 using System.Threading.Tasks;
 using Tensorflow.Common.Types;
+using Tensorflow.Keras.ArgsDefinition;
 using Tensorflow.Keras.Engine;
 using Tensorflow.Keras.Layers;
 using Tensorflow.Keras.Saving;
@@ -38,8 +39,6 @@ namespace Tensorflow.Keras.UnitTest.Layers
             var cells = new IRnnCell[] { tf.keras.layers.SimpleRNNCell(4), tf.keras.layers.SimpleRNNCell(5) };
             var stackedRNNCell = tf.keras.layers.StackedRNNCells(cells);
             var (output, state) = stackedRNNCell.Apply(inputs, states);
-            Console.WriteLine(output);
-            Console.WriteLine(state.shape);
             Assert.AreEqual((32, 5), output.shape);
             Assert.AreEqual((32, 4), state[0].shape);
         }
@@ -108,6 +107,7 @@ namespace Tensorflow.Keras.UnitTest.Layers
             var inputs = tf.random.normal((32, 10, 8));
             var cell = tf.keras.layers.SimpleRNNCell(10, dropout: 0.5f, recurrent_dropout: 0.5f);
             var rnn = tf.keras.layers.RNN(cell: cell);
+            var cfg = rnn.get_config();  // exercise the newly added get_config
             var output = rnn.Apply(inputs);

             Assert.AreEqual((32, 10), output.shape);
@@ -145,5 +145,14 @@ namespace Tensorflow.Keras.UnitTest.Layers
             Assert.AreEqual((32, 4), output.shape);
         }
+
+        [TestMethod]
+        public void Bidirectional()
+        {
+            var bi = tf.keras.layers.Bidirectional(keras.layers.LSTM(10, return_sequences: true));
+            var inputs = tf.random.normal((32, 10, 8));
+            var outputs = bi.Apply(inputs);
+            Assert.AreEqual((32, 10, 20), outputs.shape);
+        }
     }
 }