| @@ -2,6 +2,9 @@ | |||||
| { | { | ||||
| public class SimpleRNNArgs : RNNArgs | public class SimpleRNNArgs : RNNArgs | ||||
| { | { | ||||
| public float Dropout = 0f; | |||||
| public float RecurrentDropout = 0f; | |||||
| public int state_size; | |||||
| public int output_size; | |||||
| } | } | ||||
| } | } | ||||
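For orientation, the new members are plain public fields, so a configuration is just an object initializer (a usage sketch, not part of the diff; `Units` is assumed to be inherited from `RNNArgs`):

```csharp
// Hypothetical configuration of the extended args.
var rnnArgs = new SimpleRNNArgs
{
    Units = 4,                // cell width, assumed to come from RNNArgs
    Dropout = 0.2f,           // fraction of input units to drop
    RecurrentDropout = 0.1f   // fraction of recurrent state units to drop
};
```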
| @@ -27,5 +27,11 @@ namespace Tensorflow.Keras | |||||
| TF_DataType DType { get; } | TF_DataType DType { get; } | ||||
| int count_params(); | int count_params(); | ||||
| void adapt(Tensor data, int? batch_size = null, int? steps = null); | void adapt(Tensor data, int? batch_size = null, int? steps = null); | ||||
| Tensors Call(Tensors inputs, Tensor? mask = null, bool? training = null, Tensors? initial_state = null, Tensors? constants = null); | |||||
| StateSizeWrapper state_size { get; } | |||||
| int output_size { get; } | |||||
| } | } | ||||
| } | } | ||||
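`Call`, `state_size`, and `output_size` are added to the public `ILayer` surface, so every existing layer now has to provide them; as the hunks below show, most implementations are stubbed with `NotImplementedException` for now. A defensive call-site sketch (`someLayer` is a hypothetical `ILayer` variable):

```csharp
StateSizeWrapper size = null;
try
{
    size = someLayer.state_size;  // may throw on layers that haven't implemented it yet
}
catch (NotImplementedException)
{
    // Fall back, e.g. to reflection, as RNN's constructor does with GetProperty("state_size").
}
```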
| @@ -200,6 +200,14 @@ namespace Tensorflow.Keras.Layers | |||||
| bool return_sequences = false, | bool return_sequences = false, | ||||
| bool return_state = false); | bool return_state = false); | ||||
| public ILayer SimpleRNNCell( | |||||
| int units, | |||||
| string activation = "tanh", | |||||
| bool use_bias = true, | |||||
| string kernel_initializer = "glorot_uniform", | |||||
| string recurrent_initializer = "orthogonal", | |||||
| string bias_initializer = "zeros"); | |||||
| public ILayer Subtract(); | public ILayer Subtract(); | ||||
| } | } | ||||
| } | } | ||||
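Assuming the usual TF.NET entry points, the new factory is reached through `keras.layers` (a sketch; the defaults mirror the declaration above):

```csharp
// Hypothetical call once the member is implemented (see the Layers hunk below).
var cell = keras.layers.SimpleRNNCell(4, activation: "tanh", use_bias: true);
```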
| @@ -89,6 +89,8 @@ namespace Tensorflow | |||||
| protected bool built = false; | protected bool built = false; | ||||
| public bool Built => built; | public bool Built => built; | ||||
| StateSizeWrapper ILayer.state_size => throw new NotImplementedException(); | |||||
| public RnnCell(bool trainable = true, | public RnnCell(bool trainable = true, | ||||
| string name = null, | string name = null, | ||||
| TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
| @@ -174,5 +176,10 @@ namespace Tensorflow | |||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| } | } | ||||
| public Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -19,6 +19,7 @@ using System; | |||||
| using System.Collections; | using System.Collections; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Runtime.CompilerServices; | |||||
| namespace Tensorflow.Util | namespace Tensorflow.Util | ||||
| { | { | ||||
| @@ -213,6 +214,17 @@ namespace Tensorflow.Util | |||||
| public static bool is_nested(object obj) | public static bool is_nested(object obj) | ||||
| { | { | ||||
| // Refer to https://www.tensorflow.org/api_docs/python/tf/nest | |||||
| //if (obj is IList || obj is IDictionary || obj is ITuple) | |||||
| // return true; | |||||
| if (obj is IList || obj is IDictionary) | |||||
| return true; | |||||
| if (obj is NDArray || obj is Tensor || obj is string || obj.GetType().IsGenericType | |||||
| || obj is ISet<int> || obj is ISet<float> || obj is ISet<double>) | |||||
| return false; | |||||
| // NOTE: Type.IsNested is true for types declared inside another type; it does not indicate a nested data structure. | |||||
| if (obj.GetType().IsNested) return true; | |||||
| // Check if the object is an IEnumerable | // Check if the object is an IEnumerable | ||||
| if (obj is IEnumerable) | if (obj is IEnumerable) | ||||
| { | { | ||||
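Taken together, the new checks give `is_nested` tf.nest-like semantics: collections count as nested, tensors and strings are leaves. Expected behavior, under that assumption:

```csharp
Console.WriteLine(nest.is_nested(new List<int> { 1, 2, 3 }));  // True:  IList is nested
Console.WriteLine(nest.is_nested(tf.constant(1.0f)));          // False: Tensor is a leaf
Console.WriteLine(nest.is_nested("hello"));                    // False: strings are leaves
```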
| @@ -244,7 +256,13 @@ namespace Tensorflow.Util | |||||
| _flatten_recursive(structure, list); | _flatten_recursive(structure, list); | ||||
| return list; | return list; | ||||
| } | } | ||||
| // TODO(Wanglongzhi2001): ITuple requires .NET Standard 2.1, but the project currently targets 2.0. | |||||
| // To flatten a nested tuple, specify the element type of the tuple explicitly. | |||||
| //public static List<T> flatten<T>(ITuple structure) | |||||
| //{ | |||||
| // var list = FlattenTuple<T>(structure).ToList(); | |||||
| // return list; | |||||
| //} | |||||
| public static List<T> flatten<T>(IEnumerable<T> structure) | public static List<T> flatten<T>(IEnumerable<T> structure) | ||||
| { | { | ||||
| var list = new List<T>(); | var list = new List<T>(); | ||||
| @@ -272,9 +290,13 @@ namespace Tensorflow.Util | |||||
| case String str: | case String str: | ||||
| list.Add(obj); | list.Add(obj); | ||||
| break; | break; | ||||
| case NDArray nd: | |||||
| // This case handles both Tensor and NDArray, since NDArray derives from Tensor. | |||||
| case Tensor tensor: | |||||
| list.Add(obj); | list.Add(obj); | ||||
| break; | break; | ||||
| //case NDArray nd: | |||||
| // list.Add(obj); | |||||
| // break; | |||||
| case IEnumerable structure: | case IEnumerable structure: | ||||
| foreach (var child in structure) | foreach (var child in structure) | ||||
| _flatten_recursive((T)child, list); | _flatten_recursive((T)child, list); | ||||
| @@ -285,28 +307,26 @@ namespace Tensorflow.Util | |||||
| } | } | ||||
| } | } | ||||
| public static List<T> FlattenTupple<T>(object tuple) | |||||
| { | |||||
| List<T> items = new List<T>(); | |||||
| var type = tuple.GetType(); | |||||
| if (type.GetInterface("ITuple") == null) | |||||
| throw new ArgumentException("This is not a tuple!"); | |||||
| foreach (var property in type.GetProperties()) | |||||
| { | |||||
| var value = property.GetValue(tuple); | |||||
| if (property.PropertyType.GetInterface("ITuple") != null) | |||||
| { | |||||
| var subItems = FlattenTupple<T>(value); | |||||
| items.AddRange(subItems); | |||||
| } | |||||
| else | |||||
| { | |||||
| items.Add((T)value); | |||||
| } | |||||
| } | |||||
| return items; | |||||
| } | |||||
| private static IEnumerable<T> FlattenTuple<T>(object tuple) | |||||
| { | |||||
| // The recursive ITuple-based version below requires .NET Standard 2.1, which the project does not target yet. | |||||
| //if (tuple is ITuple t) | |||||
| //{ | |||||
| // for (int i = 0; i < t.Length; i++) | |||||
| // { | |||||
| // foreach (var item in FlattenTuple<T>(t[i])) | |||||
| // { | |||||
| // yield return item; | |||||
| // } | |||||
| // } | |||||
| //} | |||||
| // Fallback: treat the argument as a single leaf value. | |||||
| yield return (T)tuple; | |||||
| } | } | ||||
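Worth noting (an observation, not in the diff): until the ITuple branch can be enabled, the new `FlattenTuple` yields its argument as a single leaf, so nested tuples are not actually flattened:

```csharp
// With the netstandard2.0 fallback, the whole tuple comes back as one element.
foreach (var item in FlattenTuple<object>(Tuple.Create(1, 2)))
    Console.WriteLine(item);  // prints "(1, 2)" once, not 1 then 2
```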
| //# See the swig file (util.i) for documentation. | //# See the swig file (util.i) for documentation. | ||||
| //_same_namedtuples = _pywrap_tensorflow.SameNamedtuples | //_same_namedtuples = _pywrap_tensorflow.SameNamedtuples | ||||
| @@ -494,8 +514,12 @@ namespace Tensorflow.Util | |||||
| throw new ArgumentException("flat_sequence must not be null"); | throw new ArgumentException("flat_sequence must not be null"); | ||||
| // if not is_sequence(flat_sequence): | // if not is_sequence(flat_sequence): | ||||
| // raise TypeError("flat_sequence must be a sequence") | // raise TypeError("flat_sequence must be a sequence") | ||||
| if (!is_sequence(structure)) | |||||
| if (!is_nested(flat_sequence)) | |||||
| { | |||||
| throw new ArrayTypeMismatchException($"Attempted to pack value:\n {flat_sequence}\ninto a structure, " + | |||||
| $"but found incompatible type `{flat_sequence.GetType()}` instead."); | |||||
| } | |||||
| if (!is_nested(structure)) | |||||
| { | { | ||||
| if (len(flat) != 1) | if (len(flat) != 1) | ||||
| throw new ValueError($"Structure is a scalar but len(flat_sequence) == {len(flat)} > 1"); | throw new ValueError($"Structure is a scalar but len(flat_sequence) == {len(flat)} > 1"); | ||||
| @@ -614,331 +614,331 @@ namespace Tensorflow.Keras | |||||
| return nest.pack_sequence_as(inputs, inp); | return nest.pack_sequence_as(inputs, inp); | ||||
| } | } | ||||
| if (mask != null) | |||||
| { | |||||
| var mask_list = tf.unstack(mask); | |||||
| if (go_backwards) | |||||
| { | |||||
| mask_list = mask_list.Reverse().ToArray(); // LINQ Reverse returns a new sequence; reassign so the reversal takes effect | |||||
| } | |||||
| for (int i = 0; i < time_steps; i++) | |||||
| { | |||||
| // TODO(Wanglongzhi2001),deal with _get_input_tensor | |||||
| var inp = _get_input_tensor(i); | |||||
| var mask_t = mask_list[i]; | |||||
| // TODO | |||||
| var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants }); | |||||
| var tiled_mask_t = _expand_mask(mask_t, output); | |||||
| Tensors prev_output; | |||||
| if (successive_outputs == null) | |||||
| { | |||||
| prev_output = tf.zeros_like(output); | |||||
| } | |||||
| else | |||||
| { | |||||
| prev_output = successive_outputs[successive_outputs.Length - 1]; | |||||
| } | |||||
| output = tf.where(tiled_mask_t, output, prev_output); | |||||
| //var flat_states = nest.flatten(states); | |||||
| //var flat_new_states = nest.flatten(newStates); | |||||
| var flat_states = states.ToList(); | |||||
| var flat_new_states = newStates.ToList(); | |||||
| var tiledMaskT = flat_states | |||||
| .Select(s => _expand_mask(mask_t, s)) | |||||
| .ToArray(); | |||||
| List<Tensor> flat_final_states = new List<Tensor>(); | |||||
| foreach (var (m, s, ps) in Enumerable.Zip(tiledMaskT, flat_new_states, flat_states)) | |||||
| { | |||||
| flat_final_states.Add(tf.where(m, s, ps)); | |||||
| } | |||||
| states = (Tensors)nest.pack_sequence_as(states, flat_final_states); | |||||
| if (return_all_outputs) | |||||
| { | |||||
| successive_outputs.Add(output); | |||||
| successive_states.Add(states); | |||||
| } | |||||
| else | |||||
| { | |||||
| successive_outputs = new Tensors { output }; | |||||
| successive_states = new Tensors { states }; | |||||
| } | |||||
| } | |||||
| last_output = successive_outputs[successive_outputs.Length - 1]; | |||||
| new_states = successive_states[successive_states.Length - 1]; | |||||
| outputs = tf.stack(successive_outputs); | |||||
| if (zero_output_for_mask) | |||||
| { | |||||
| last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output)); | |||||
| outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs)); | |||||
| } | |||||
| else // mask is null | |||||
| { | |||||
| for (int i = 0; i < time_steps; i++) | |||||
| { | |||||
| var inp = _get_input_tensor(i); | |||||
| var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants }); | |||||
| states = newStates; | |||||
| if (return_all_outputs) | |||||
| { | |||||
| successive_outputs.Add(output); | |||||
| successive_states.Add(newStates); | |||||
| } | |||||
| else | |||||
| { | |||||
| successive_outputs = new Tensors { output }; | |||||
| successive_states = new Tensors { newStates }; | |||||
| } | |||||
| } | |||||
| last_output = successive_outputs[successive_outputs.Length - 1]; | |||||
| new_states = successive_states[successive_states.Length - 1]; | |||||
| outputs = tf.stack(successive_outputs); | |||||
| } | |||||
| } | |||||
| } | |||||
| else // unroll == false | |||||
| { | |||||
| var states = initial_states; | |||||
| // Create the input TensorArrays. If the inputs are nested tensors, they | |||||
| // are flattened first, and one TensorArray is created per flattened | |||||
| // tensor. | |||||
| var input_ta = new List<TensorArray>(); | |||||
| for (int i = 0; i < flatted_inptus.Count; i++) | |||||
| { | |||||
| input_ta.Add(tf.TensorArray(dtype: flatted_inptus[i].dtype, size: time_step_t)); | |||||
| } | |||||
| // Get the time step 0 input and compute the output for it; the output is | |||||
| // used to determine the dtype of the output TensorArray. Don't read from | |||||
| // input_ta because TensorArray's clear_after_read defaults to true. | |||||
| var inps = new Tensors(); | |||||
| foreach (var inp in flatted_inptus) | |||||
| { | |||||
| inps.Add(inp[0]); | |||||
| } | |||||
| var input_time_zero = nest.pack_sequence_as(inputs, inps); | |||||
| // output_time_zero is used to determine the cell output shape and its | |||||
| // dtype. The value itself is discarded. | |||||
| (output_time_zero, _) = step_function((Tensor)input_time_zero, new Tensors { initial_states, constants }); | |||||
| var output_ta_size = return_all_outputs ? time_step_t : tf.constant(1); | |||||
| var output_ta = new List<TensorArray>(); | |||||
| for (int i = 0; i < output_time_zero.ToList().Count; i++) | |||||
| { | |||||
| var Out = output_time_zero.ToList()[i]; | |||||
| output_ta.Add(tf.TensorArray(dtype: Out.dtype, size: output_ta_size, element_shape: Out.shape)); | |||||
| } | |||||
| var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time"); | |||||
| Func<Tensor, Tensor>? masking_fn; | |||||
| Func<Tensors, Tensors, Tensors, Tensors>? compute_masked_output = null; | |||||
| if (mask != null) | |||||
| { | |||||
| if (go_backwards) | |||||
| { | |||||
| mask = tf.reverse(mask, axis: new[] { 0 }); | |||||
| } | |||||
| var mask_ta = tf.TensorArray(dtype: TF_DataType.TF_BOOL, size: time_step_t); | |||||
| mask_ta = mask_ta.unstack(mask); | |||||
| masking_fn = (time) => | |||||
| { | |||||
| return mask_ta.read(time); | |||||
| }; | |||||
| compute_masked_output = (mask_t, flat_out, flat_mask) => | |||||
| { | |||||
| var tiled_mask_t = new Tensors(); | |||||
| foreach (var o in flat_out) | |||||
| { | |||||
| tiled_mask_t.Add(_expand_mask(mask_t, o, fixed_dim: mask_t.rank)); | |||||
| } | |||||
| Tensors res = new Tensors(); | |||||
| foreach (var (m, o, fm) in Enumerable.Zip(tiled_mask_t, flat_out, flat_mask)) | |||||
| { | |||||
| res.Add(tf.where(m, o, fm)); | |||||
| } | |||||
| return res; | |||||
| }; | |||||
| } | |||||
| // TODO(Wanglongzhi2001): what should input_length's type be (an integer or a single tensor)? | |||||
| else if (input_length is Tensor) | |||||
| { | |||||
| if (go_backwards) | |||||
| { | |||||
| var max_len = tf.reduce_max(input_length, axis: 0); | |||||
| var rev_input_length = tf.subtract(max_len - 1, input_length); | |||||
| masking_fn = (time) => | |||||
| { | |||||
| return tf.less(rev_input_length, time); | |||||
| }; | |||||
| } | |||||
| else | |||||
| { | |||||
| masking_fn = (time) => | |||||
| { | |||||
| return tf.greater(input_length, time); | |||||
| }; | |||||
| } | |||||
| compute_masked_output = (mask_t, flat_out, flat_mask) => | |||||
| { | |||||
| var res = new List<Tensor>(); | |||||
| foreach (var (o, zo) in zip(flat_out, flat_mask)) | |||||
| { | |||||
| res.Add(tf.where(mask_t, o, zo)); | |||||
| } | |||||
| return res; | |||||
| }; | |||||
| } | |||||
| else | |||||
| { | |||||
| masking_fn = null; | |||||
| } | |||||
| if (masking_fn != null) | |||||
| { | |||||
| // Mask for the T output will be based on the output of T - 1. In the | |||||
| // case T = 0, a zero-filled tensor is used. | |||||
| var flat_zero_output = new Tensors(); | |||||
| foreach (var o in nest.flatten(output_time_zero)) | |||||
| { | |||||
| flat_zero_output.Add(tf.zeros_like(o)); | |||||
| } | |||||
| (Tensor, List<TensorArray>, Tensors, Tensors) _step(Tensor time, List<TensorArray> output_ta_t, Tensors prev_output, Tensors states) | |||||
| { | |||||
| /* | |||||
| RNN step function. | |||||
| Args: | |||||
| time: Current timestep value. | |||||
| output_ta_t: TensorArray. | |||||
| prev_output: tuple of outputs from time - 1. | |||||
| *states: List of states. | |||||
| Returns: | |||||
| Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)` | |||||
| */ | |||||
| var current_input = input_ta.Select(x => x.read(time)).ToList(); | |||||
| // maybe set shape | |||||
| // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type | |||||
| current_input = (List<Tensor>)nest.pack_sequence_as(inputs, current_input); | |||||
| var mask_t = masking_fn(time); | |||||
| var (output, new_states) = step_function(current_input, new Tensors { states, constants }); | |||||
| // mask output | |||||
| //var flat_output = nest.flatten(output); | |||||
| var flat_output = output.ToList(); | |||||
| var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.ToList(); | |||||
| // TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type | |||||
| var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output); | |||||
| // mask states | |||||
| var flat_state = states.ToList(); | |||||
| var flat_new_state = new_states.ToList(); | |||||
| foreach (var (state, new_state) in zip(flat_state, flat_new_state)) | |||||
| { | |||||
| if (new_state is Tensor) | |||||
| { | |||||
| new_state.set_shape(state.shape); | |||||
| } | |||||
| } | |||||
| var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state); | |||||
| new_states = (Tensors)nest.pack_sequence_as(new_states, flat_final_state); | |||||
| var ta_index_to_write = return_all_outputs ? time : tf.constant(0); | |||||
| var Output_ta_t = new List<TensorArray>(); | |||||
| // TODO(Wanglongzhi2001),deal with zip output_ta_t | |||||
| foreach (var (ta, Out) in zip(output_ta_t, flat_new_output)) | |||||
| { | |||||
| Output_ta_t.Add(ta.write(ta_index_to_write, Out)); | |||||
| } | |||||
| //new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state); | |||||
| return (time + 1, Output_ta_t, flat_new_output, new_states); | |||||
| } | |||||
| Func<Tensor, Tensor> cond = (time) => (time < time_step_t); | |||||
| var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, flat_zero_output, states)); | |||||
| new_states = final_outputs.Item4; | |||||
| output_ta = final_outputs.Item2; | |||||
| } | |||||
| else | |||||
| { | |||||
| (Tensor, List<TensorArray>, Tensors) _step(Tensor time, List<TensorArray> output_ta_t, Tensors states) | |||||
| { | |||||
| var current_input = input_ta.Select(x => x.read(time)).ToList(); | |||||
| // maybe set shape | |||||
| // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type | |||||
| current_input = (List<Tensor>)nest.pack_sequence_as(inputs, current_input); | |||||
| var (output, new_states) = step_function(current_input, new Tensors { states, constants }); | |||||
| var flat_state = states.ToList(); | |||||
| var flat_new_state = new_states.ToList(); | |||||
| foreach (var (state, new_state) in zip(flat_state, flat_new_state)) | |||||
| { | |||||
| if (new_state is Tensor) | |||||
| { | |||||
| new_state.set_shape(state.shape); | |||||
| } | |||||
| } | |||||
| var flat_output = output.ToList(); | |||||
| var ta_index_to_write = return_all_outputs ? time : tf.constant(0); | |||||
| var Output_ta_t = new List<TensorArray>(); | |||||
| foreach (var (ta, out_) in zip(output_ta_t, flat_output)) | |||||
| { | |||||
| Output_ta_t.Add(ta.write(ta_index_to_write, out_)); | |||||
| } | |||||
| new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state); | |||||
| return (time + 1, Output_ta_t, new_states); | |||||
| } | |||||
| Func<Tensor, Tensor> cond = (time) => (time < time_step_t); | |||||
| var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, states)); | |||||
| new_states = final_outputs.Item3; | |||||
| output_ta = final_outputs.Item2; | |||||
| } | |||||
| //Tensors outputs = new Tensors(); | |||||
| foreach (var o in output_ta) | |||||
| { | |||||
| outputs.Add(o.stack()); | |||||
| } | |||||
| foreach (var o in outputs) | |||||
| { | |||||
| last_output.Add(o[-1]); | |||||
| } | |||||
| outputs = (Tensors)nest.pack_sequence_as(output_time_zero, outputs); | |||||
| last_output = (Tensors)nest.pack_sequence_as(output_time_zero, last_output); | |||||
| } | } | ||||
| Func<Tensor, Tensor> set_shape; | Func<Tensor, Tensor> set_shape; | ||||
| set_shape = (output_) => | set_shape = (output_) => | ||||
| @@ -968,5 +968,27 @@ namespace Tensorflow.Keras | |||||
| return (last_output, Outputs, new_states); | return (last_output, Outputs, new_states); | ||||
| } | } | ||||
| // Multiplies 2 tensors (and/or variables) and returns a tensor. | |||||
| // This operation corresponds to `numpy.dot(a, b, out=None)`. | |||||
| public Tensor Dot(Tensor x, Tensor y) | |||||
| { | |||||
| //if (x.ndim != 1 && (x.ndim > 2 || y.ndim > 2)) | |||||
| //{ | |||||
| // var x_shape = new List<int>(); | |||||
| // foreach (var (i,s) in zip(x.shape.as_int_list(), tf.unstack(tf.shape(x)))) | |||||
| // { | |||||
| // if (i != 0) | |||||
| // { | |||||
| // x_shape.append(i); | |||||
| // } | |||||
| // else | |||||
| // { | |||||
| // x_shape.append(s); | |||||
| // } | |||||
| // } | |||||
| //} | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
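The body above is still a stub. For 2-D inputs `numpy.dot` is plain matrix multiplication, so a minimal interim implementation could look like this (a sketch under that assumption; the commented-out shape juggling is only needed for rank > 2):

```csharp
public Tensor Dot(Tensor x, Tensor y)
{
    // Sketch: cover only the 2-D x 2-D case of numpy.dot semantics.
    if (x.ndim != 2 || y.ndim != 2)
        throw new NotImplementedException("Dot is only sketched for rank-2 tensors.");
    return tf.matmul(x, y);
}
```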
| } | } | ||||
| } | } | ||||
| @@ -325,7 +325,7 @@ namespace Tensorflow.Keras.Engine | |||||
| nodes_in_decreasing_depth.append(node); | nodes_in_decreasing_depth.append(node); | ||||
| } | } | ||||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
| protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
| { | { | ||||
| var tensor_dict = new Dictionary<long, Queue<Tensor>>(); | var tensor_dict = new Dictionary<long, Queue<Tensor>>(); | ||||
| // map input values | // map input values | ||||
| @@ -30,7 +30,7 @@ namespace Tensorflow.Keras.Engine | |||||
| if (!built) | if (!built) | ||||
| MaybeBuild(inputs); | MaybeBuild(inputs); | ||||
| var outputs = Call(inputs, state: state, training: training); | |||||
| var outputs = Call(inputs, initial_state: state, training: training); | |||||
| // memory leak | // memory leak | ||||
| // _set_connectivity_metadata_(inputs, outputs); | // _set_connectivity_metadata_(inputs, outputs); | ||||
| @@ -254,6 +254,10 @@ namespace Tensorflow.Keras.Engine | |||||
| /// </summary> | /// </summary> | ||||
| public Func<Tensors, Tensors>? ReplacedCall { get; set; } = null; | public Func<Tensors, Tensors>? ReplacedCall { get; set; } = null; | ||||
| public StateSizeWrapper state_size => throw new NotImplementedException(); | |||||
| public int output_size => throw new NotImplementedException(); | |||||
| public Layer(LayerArgs args) | public Layer(LayerArgs args) | ||||
| { | { | ||||
| Initialize(args); | Initialize(args); | ||||
| @@ -434,56 +438,61 @@ namespace Tensorflow.Keras.Engine | |||||
| public override void SetAttr(string name, object value) | public override void SetAttr(string name, object value) | ||||
| { | { | ||||
| // TODO(Rinne): deal with "_self_setattr_tracking". | |||||
| //// TODO(Rinne): deal with "_self_setattr_tracking". | |||||
| value = TrackableDataStructure.sticky_attribute_assignment(this, name, value); | |||||
| //value = TrackableDataStructure.sticky_attribute_assignment(this, name, value); | |||||
| foreach(var val in nest.flatten(value)) | |||||
| { | |||||
| if(val is Metric) | |||||
| { | |||||
| // TODO(Rinne): deal with metrics. | |||||
| } | |||||
| } | |||||
| // TODO(Rinne): deal with "_auto_track_sub_layers". | |||||
| foreach(var val in nest.flatten(value)) | |||||
| { | |||||
| if(val is not IVariableV1 variable) | |||||
| { | |||||
| continue; | |||||
| } | |||||
| if (variable.Trainable) | |||||
| { | |||||
| if (_trainable_weights.Contains(variable)) | |||||
| { | |||||
| continue; | |||||
| } | |||||
| _trainable_weights.Add(variable); | |||||
| } | |||||
| else | |||||
| { | |||||
| if (_non_trainable_weights.Contains(variable)) | |||||
| { | |||||
| continue; | |||||
| } | |||||
| _non_trainable_weights.Add(variable); | |||||
| } | |||||
| keras.backend.track_variable(variable); | |||||
| } | |||||
| //foreach(var val in nest.flatten(value)) | |||||
| //{ | |||||
| // if(val is Metric) | |||||
| // { | |||||
| // // TODO(Rinne): deal with metrics. | |||||
| // } | |||||
| //} | |||||
| //// TODO(Rinne): deal with "_auto_track_sub_layers". | |||||
| //foreach(var val in nest.flatten(value)) | |||||
| //{ | |||||
| // if(val is not IVariableV1 variable) | |||||
| // { | |||||
| // continue; | |||||
| // } | |||||
| // if (variable.Trainable) | |||||
| // { | |||||
| // if (_trainable_weights.Contains(variable)) | |||||
| // { | |||||
| // continue; | |||||
| // } | |||||
| // _trainable_weights.Add(variable); | |||||
| // } | |||||
| // else | |||||
| // { | |||||
| // if (_non_trainable_weights.Contains(variable)) | |||||
| // { | |||||
| // continue; | |||||
| // } | |||||
| // _non_trainable_weights.Add(variable); | |||||
| // } | |||||
| // keras.backend.track_variable(variable); | |||||
| //} | |||||
| //// Directly use the implementation of `Trackable`. | |||||
| //var t = this.GetType(); | |||||
| //var field_info = t.GetField(name); | |||||
| //if (field_info is not null) | |||||
| //{ | |||||
| // field_info.SetValue(this, value); | |||||
| //} | |||||
| //else | |||||
| //{ | |||||
| // CustomizedFields[name] = value; | |||||
| //} | |||||
| } | |||||
| // Directly use the implementation of `Trackable`. | |||||
| var t = this.GetType(); | |||||
| var field_info = t.GetField(name); | |||||
| if (field_info is not null) | |||||
| { | |||||
| field_info.SetValue(this, value); | |||||
| } | |||||
| else | |||||
| { | |||||
| CustomizedFields[name] = value; | |||||
| } | |||||
| Tensors ILayer.Call(Tensors inputs, Tensor mask, bool? training, Tensors initial_state, Tensors constants) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -143,7 +143,7 @@ namespace Tensorflow.Keras.Engine | |||||
| } | } | ||||
| } | } | ||||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
| protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
| { | { | ||||
| if (!_has_explicit_input_shape) | if (!_has_explicit_input_shape) | ||||
| { | { | ||||
| @@ -154,10 +154,10 @@ namespace Tensorflow.Keras.Engine | |||||
| { | { | ||||
| if (!built) | if (!built) | ||||
| _init_graph_network(this.inputs, outputs); | _init_graph_network(this.inputs, outputs); | ||||
| return base.Call(inputs, state, training); | |||||
| return base.Call(inputs, initial_state, training); | |||||
| } | } | ||||
| return base.Call(inputs, state, training); | |||||
| return base.Call(inputs, initial_state, training); | |||||
| } | } | ||||
| void _build_graph_network_for_inferred_shape(Shape input_shape, TF_DataType input_dtype) | void _build_graph_network_for_inferred_shape(Shape input_shape, TF_DataType input_dtype) | ||||
| @@ -83,7 +83,7 @@ namespace Tensorflow.Keras.Layers | |||||
| _buildInputShape = input_shape; | _buildInputShape = input_shape; | ||||
| } | } | ||||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
| protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
| { | { | ||||
| var inputs_shape = array_ops.shape(inputs); | var inputs_shape = array_ops.shape(inputs); | ||||
| var batch_size = inputs_shape[0]; | var batch_size = inputs_shape[0]; | ||||
| @@ -709,6 +709,23 @@ namespace Tensorflow.Keras.Layers | |||||
| ReturnState = return_state | ReturnState = return_state | ||||
| }); | }); | ||||
| public ILayer SimpleRNNCell( | |||||
| int units, | |||||
| string activation = "tanh", | |||||
| bool use_bias = true, | |||||
| string kernel_initializer = "glorot_uniform", | |||||
| string recurrent_initializer = "orthogonal", | |||||
| string bias_initializer = "zeros") | |||||
| => new SimpleRNNCell(new SimpleRNNArgs | |||||
| { | |||||
| Units = units, | |||||
| Activation = keras.activations.GetActivationFromName(activation), | |||||
| UseBias = use_bias, | |||||
| KernelInitializer = GetInitializerByName(kernel_initializer), | |||||
| RecurrentInitializer = GetInitializerByName(recurrent_initializer), | |||||
| // NOTE: the bias_initializer argument is accepted but not yet forwarded into the args. | |||||
| } | |||||
| ); | |||||
| /// <summary> | /// <summary> | ||||
| /// Long Short-Term Memory layer - Hochreiter 1997. | /// Long Short-Term Memory layer - Hochreiter 1997. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -0,0 +1,80 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using Tensorflow.Keras.ArgsDefinition; | |||||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | |||||
| using Tensorflow.Keras.Engine; | |||||
| namespace Tensorflow.Keras.Layers.Rnn | |||||
| { | |||||
| public class DropoutRNNCellMixin | |||||
| { | |||||
| public float dropout; | |||||
| public float recurrent_dropout; | |||||
| // Get the dropout mask for RNN cell's input. | |||||
| public Tensors get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1) | |||||
| { | |||||
| return _generate_dropout_mask( | |||||
| tf.ones_like(input), | |||||
| dropout, | |||||
| training, | |||||
| count); | |||||
| } | |||||
| // Get the recurrent dropout mask for RNN cell. | |||||
| public Tensors get_recurrent_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1) | |||||
| { | |||||
| return _generate_dropout_mask( | |||||
| tf.ones_like(input), | |||||
| recurrent_dropout, | |||||
| training, | |||||
| count); | |||||
| } | |||||
| public Tensors _create_dropout_mask(Tensors input, bool training, int count = 1) | |||||
| { | |||||
| return _generate_dropout_mask( | |||||
| tf.ones_like(input), | |||||
| dropout, | |||||
| training, | |||||
| count); | |||||
| } | |||||
| public Tensors _create_recurrent_dropout_mask(Tensors input, bool training, int count = 1) | |||||
| { | |||||
| return _generate_dropout_mask( | |||||
| tf.ones_like(input), | |||||
| recurrent_dropout, | |||||
| training, | |||||
| count); | |||||
| } | |||||
| public Tensors _generate_dropout_mask(Tensor ones, float rate, bool training, int count = 1) | |||||
| { | |||||
| Tensors dropped_inputs() | |||||
| { | |||||
| DropoutArgs args = new DropoutArgs(); | |||||
| args.Rate = rate; | |||||
| var DropoutLayer = new Dropout(args); | |||||
| var mask = DropoutLayer.Apply(ones, training: training); | |||||
| return mask; | |||||
| } | |||||
| if (count > 1) | |||||
| { | |||||
| Tensors results = new Tensors(); | |||||
| for (int i = 0; i < count; i++) | |||||
| { | |||||
| results.Add(dropped_inputs()); | |||||
| } | |||||
| return results; | |||||
| } | |||||
| return dropped_inputs(); | |||||
| } | |||||
| } | |||||
| } | |||||
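A quick usage sketch (hypothetical; it assumes the mixin is instantiated directly rather than inherited by a cell): the mask matches the input's shape, and with `count > 1` the helper returns one independent mask per gate.

```csharp
var cellDropout = new DropoutRNNCellMixin { dropout = 0.2f, recurrent_dropout = 0.1f };
var inputs = tf.ones(new Shape(2, 4));                                // (batch, features)
var masks = cellDropout.get_dropout_maskcell_for_cell(inputs, true);  // count defaults to 1
var dropped = inputs * masks[0];                                      // apply elementwise
```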
| @@ -19,7 +19,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| private RNNArgs args; | private RNNArgs args; | ||||
| private object input_spec = null; // or NoneValue?? | private object input_spec = null; // or NoneValue?? | ||||
| private object state_spec = null; | private object state_spec = null; | ||||
| private object _states = null; | |||||
| private Tensors _states = null; | |||||
| private object constants_spec = null; | private object constants_spec = null; | ||||
| private int _num_constants = 0; | private int _num_constants = 0; | ||||
| protected IVariableV1 kernel; | protected IVariableV1 kernel; | ||||
| @@ -44,19 +44,15 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| cell = args.Cell.AsT1; | cell = args.Cell.AsT1; | ||||
| } | } | ||||
| Type type = cell.GetType(); | Type type = cell.GetType(); | ||||
| MethodInfo methodInfo = type.GetMethod("Call"); | |||||
| if (methodInfo == null) | |||||
| MethodInfo callMethodInfo = type.GetMethod("Call"); | |||||
| if (callMethodInfo == null) | |||||
| { | { | ||||
| throw new ValueError(@"Argument `cell` or `cells`should have a `call` method. "); | throw new ValueError(@"Argument `cell` or `cells`should have a `call` method. "); | ||||
| } | } | ||||
| PropertyInfo propertyInfo = type.GetProperty("state_size"); | |||||
| if (propertyInfo == null) | |||||
| PropertyInfo state_size_info = type.GetProperty("state_size"); | |||||
| if (state_size_info == null) | |||||
| { | { | ||||
| throw new ValueError(@"The RNN cell should have a `state_size` attribute"); | throw new ValueError(@"The RNN cell should have a `state_size` attribute"); | ||||
| } | } | ||||
| @@ -80,7 +76,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| // States is a tuple consisting of the cells' state sizes, like (cell1.state_size, cell2.state_size, ...) | // States is a tuple consisting of the cells' state sizes, like (cell1.state_size, cell2.state_size, ...) | ||||
| // state_size can be a single integer, a list/tuple of integers, a TensorShape, or a list/tuple of TensorShapes | // state_size can be a single integer, a list/tuple of integers, a TensorShape, or a list/tuple of TensorShapes | ||||
| public object States | |||||
| public Tensors States | |||||
| { | { | ||||
| get | get | ||||
| { | { | ||||
| @@ -106,7 +102,6 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| // state_size is a array of ints or a positive integer | // state_size is a array of ints or a positive integer | ||||
| var state_size = cell.state_size; | var state_size = cell.state_size; | ||||
| // TODO(wanglongzhi2001): what type should flat_output_size be, a Shape or a Tensor? | // TODO(wanglongzhi2001): what type should flat_output_size be, a Shape or a Tensor? | ||||
| Func<Shape, Shape> _get_output_shape; | Func<Shape, Shape> _get_output_shape; | ||||
| _get_output_shape = (flat_output_size) => | _get_output_shape = (flat_output_size) => | ||||
| @@ -132,8 +127,10 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| return output_shape; | return output_shape; | ||||
| }; | }; | ||||
| Type type = cell.GetType(); | |||||
| PropertyInfo output_size_info = type.GetProperty("output_size"); | |||||
| Shape output_shape; | Shape output_shape; | ||||
| if (cell.output_size != 0) | |||||
| if (output_size_info != null) | |||||
| { | { | ||||
| output_shape = nest.map_structure(_get_output_shape, cell.output_size); | output_shape = nest.map_structure(_get_output_shape, cell.output_size); | ||||
| // TODO(wanglongzhi2001): should output_shape simply be a tuple, or a Shape? | // TODO(wanglongzhi2001): should output_shape simply be a tuple, or a Shape? | ||||
| @@ -160,6 +157,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| { | { | ||||
| return output_shape; | return output_shape; | ||||
| } | } | ||||
| } | } | ||||
| private Tensors compute_mask(Tensors inputs, Tensors mask) | private Tensors compute_mask(Tensors inputs, Tensors mask) | ||||
| @@ -184,8 +182,6 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| { | { | ||||
| return output_mask; | return output_mask; | ||||
| } | } | ||||
| } | } | ||||
| public override void build(KerasShapesWrapper input_shape) | public override void build(KerasShapesWrapper input_shape) | ||||
| @@ -247,14 +243,18 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | ||||
| { | { | ||||
| //var (inputs_padded, row_length) = BackendImpl.convert_inputs_if_ragged(inputs); | //var (inputs_padded, row_length) = BackendImpl.convert_inputs_if_ragged(inputs); | ||||
| //bool is_ragged_input = row_length != null; | |||||
| //_validate_args_if_ragged(is_ragged_input, mask); | |||||
| var (inputs_processed, initial_state_processed, constants_processed) = _process_inputs(inputs, initial_state, constants); | |||||
| // Ragged tensors are not accepted for the time being. | |||||
| int? row_length = null; | |||||
| bool is_ragged_input = false; | |||||
| _validate_args_if_ragged(is_ragged_input, mask); | |||||
| (inputs, initial_state, constants) = _process_inputs(inputs, initial_state, constants); | |||||
| _maybe_reset_cell_dropout_mask(cell); | _maybe_reset_cell_dropout_mask(cell); | ||||
| if (cell is StackedRNNCells) | if (cell is StackedRNNCells) | ||||
| { | { | ||||
| foreach (var cell in ((StackedRNNCells)cell).Cells) | |||||
| var stack_cell = cell as StackedRNNCells; | |||||
| foreach (var cell in stack_cell.Cells) | |||||
| { | { | ||||
| _maybe_reset_cell_dropout_mask(cell); | _maybe_reset_cell_dropout_mask(cell); | ||||
| } | } | ||||
| @@ -263,17 +263,16 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| if (mask != null) | if (mask != null) | ||||
| { | { | ||||
| // Time step masks must be the same for each input. | // Time step masks must be the same for each input. | ||||
| //mask = nest.flatten(mask)[0]; | |||||
| mask = mask[0]; | |||||
| mask = nest.flatten(mask)[0]; | |||||
| } | } | ||||
| Shape input_shape; | Shape input_shape; | ||||
| if (nest.is_nested(initial_state_processed)) | |||||
| if (nest.is_nested(inputs)) | |||||
| { | { | ||||
| // In the case of nested input, use the first element for shape check | // In the case of nested input, use the first element for shape check | ||||
| // input_shape = nest.flatten(inputs)[0].shape; | // input_shape = nest.flatten(inputs)[0].shape; | ||||
| input_shape = inputs[0].shape; | |||||
| // TODO(Wanglongzhi2001) | |||||
| input_shape = nest.flatten(inputs)[0].shape; | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| @@ -322,6 +321,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| // states = (states[0] if len(states) == 1 and is_tf_rnn_cell else states) | // states = (states[0] if len(states) == 1 and is_tf_rnn_cell else states) | ||||
| states = states.Length == 1 ? states[0] : states; | states = states.Length == 1 ? states[0] : states; | ||||
| var (output, new_states) = cell_call_fn(inputs, null, null, states, constants); | var (output, new_states) = cell_call_fn(inputs, null, null, states, constants); | ||||
| // TODO(Wanglongzhi2001): should cell_call_fn's return value be (Tensors, Tensors)? | |||||
| if (!nest.is_nested(new_states)) | if (!nest.is_nested(new_states)) | ||||
| { | { | ||||
| return (output, new Tensors { new_states }); | return (output, new Tensors { new_states }); | ||||
| @@ -351,7 +351,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| go_backwards: args.GoBackwards, | go_backwards: args.GoBackwards, | ||||
| mask: mask, | mask: mask, | ||||
| unroll: args.Unroll, | unroll: args.Unroll, | ||||
| input_length: row_length != null ? row_length : new Tensor(timesteps), | |||||
| input_length: row_length != null ? new Tensor(row_length) : new Tensor(timesteps), | |||||
| time_major: args.TimeMajor, | time_major: args.TimeMajor, | ||||
| zero_output_for_mask: args.ZeroOutputForMask, | zero_output_for_mask: args.ZeroOutputForMask, | ||||
| return_all_outputs: args.ReturnSequences); | return_all_outputs: args.ReturnSequences); | ||||
| @@ -387,24 +387,9 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| } | } | ||||
| } | } | ||||
| private (Tensors, Tensors, Tensors) _process_inputs(Tensor inputs, Tensors initial_state, Tensors constants) | |||||
| private (Tensors inputs, Tensors initial_state, Tensors constants) _process_inputs(Tensor inputs, Tensors initial_state, Tensors constants) | |||||
| { | { | ||||
| bool IsSequence(object obj) | |||||
| { | |||||
| // Check if the object is an IEnumerable | |||||
| if (obj is IEnumerable) | |||||
| { | |||||
| // If it is, check if it is a tuple | |||||
| if (!(obj is Tuple)) | |||||
| { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| // If it is not, return false | |||||
| return false; | |||||
| } | |||||
| if (IsSequence(input)) | |||||
| if (nest.is_sequence(inputs)) | |||||
| { | { | ||||
| if (_num_constants != 0) | if (_num_constants != 0) | ||||
| { | { | ||||
| @@ -413,6 +398,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| else | else | ||||
| { | { | ||||
| initial_state = inputs[new Slice(1, len(inputs) - _num_constants)]; | initial_state = inputs[new Slice(1, len(inputs) - _num_constants)]; | ||||
| constants = inputs[new Slice(len(inputs) - _num_constants, len(inputs))]; | |||||
| } | } | ||||
| if (len(initial_state) == 0) | if (len(initial_state) == 0) | ||||
| initial_state = null; | initial_state = null; | ||||
| @@ -421,32 +407,63 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| if (args.Stateful) | if (args.Stateful) | ||||
| { | { | ||||
| throw new NotImplementedException("argument stateful has not been implemented!"); | |||||
| if (initial_state != null) | |||||
| { | |||||
| var tmp = new List<Tensor>(); | |||||
| foreach (var s in nest.flatten(States)) | |||||
| { | |||||
| tmp.Add(tf.math.count_nonzero((Tensor)s)); | |||||
| } | |||||
| var non_zero_count = tf.add_n(tmp.ToArray()); | |||||
| //initial_state = tf.cond(non_zero_count > 0, () => States, () => initial_state); | |||||
| if((int)non_zero_count.numpy() > 0) | |||||
| { | |||||
| initial_state = States; | |||||
| } | |||||
| } | |||||
| else | |||||
| { | |||||
| initial_state = States; | |||||
| } | |||||
| } | } | ||||
| else if (initial_state == null) // build the default state only when none was provided | |||||
| { | |||||
| initial_state = get_initial_state(inputs); | |||||
| } | |||||
| return (inputs, initial_state, constants); | |||||
| if (initial_state.Length != States.Length) | |||||
| { | |||||
| throw new ValueError( | |||||
| $"Layer {this} expects {States.Length} state(s), " + | |||||
| $"but it received {initial_state.Length} " + | |||||
| $"initial state(s). Input received: {inputs}"); | |||||
| } | |||||
| return (inputs, initial_state, constants); | |||||
| } | } | ||||
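In effect, the stateful branch keeps the layer's cached `States` once they contain any non-zero entry. Extracted for clarity, the decision amounts to this (same TF.NET ops as above; `System.Linq` assumed):

```csharp
var counts = nest.flatten(States).Select(s => tf.math.count_nonzero((Tensor)s)).ToArray();
bool useCached = (int)tf.add_n(counts).numpy() > 0;  // any cached state non-trivial?
initial_state = useCached ? States : initial_state;
```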
| private void _validate_args_if_ragged(bool is_ragged_input, Tensors mask) | private void _validate_args_if_ragged(bool is_ragged_input, Tensors mask) | ||||
| { | { | ||||
| if (is_ragged_input) | |||||
| if (!is_ragged_input) | |||||
| { | { | ||||
| return; | |||||
| } | |||||
| if (args.Unroll) | |||||
| { | |||||
| throw new ValueError("The input received contains RaggedTensors and does " + | |||||
| "not support unrolling. Disable unrolling by passing " + | |||||
| "`unroll=False` in the RNN Layer constructor."); | |||||
| } | |||||
| if (mask != null) | |||||
| { | |||||
| throw new ValueError($"The mask that was passed in was {mask}, which " + | |||||
| "cannot be applied to RaggedTensor inputs. Please " + | |||||
| "make sure that there is no mask injected by upstream " + | |||||
| "layers."); | |||||
| } | } | ||||
| } | } | ||||
| void _maybe_reset_cell_dropout_mask(ILayer cell) | void _maybe_reset_cell_dropout_mask(ILayer cell) | ||||
| @@ -489,46 +506,77 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| } | } | ||||
| public RNN New(LayerRnnCell cell, | |||||
| bool return_sequences = false, | |||||
| bool return_state = false, | |||||
| bool go_backwards = false, | |||||
| bool stateful = false, | |||||
| bool unroll = false, | |||||
| bool time_major = false) | |||||
| => new RNN(new RNNArgs | |||||
| { | |||||
| Cell = cell, | |||||
| ReturnSequences = return_sequences, | |||||
| ReturnState = return_state, | |||||
| GoBackwards = go_backwards, | |||||
| Stateful = stateful, | |||||
| Unroll = unroll, | |||||
| TimeMajor = time_major | |||||
| }); | |||||
| public RNN New(IList<IRnnArgCell> cell, | |||||
| bool return_sequences = false, | |||||
| bool return_state = false, | |||||
| bool go_backwards = false, | |||||
| bool stateful = false, | |||||
| bool unroll = false, | |||||
| bool time_major = false) | |||||
| => new RNN(new RNNArgs | |||||
| { | |||||
| Cell = new StackedRNNCells(new StackedRNNCellsArgs { Cells = cell }), | |||||
| ReturnSequences = return_sequences, | |||||
| ReturnState = return_state, | |||||
| GoBackwards = go_backwards, | |||||
| Stateful = stateful, | |||||
| Unroll = unroll, | |||||
| TimeMajor = time_major | |||||
| }); | |||||
| // It seems the cell cannot be passed as an interface type. | |||||
| //public RNN New(IRnnArgCell cell, | |||||
| // bool return_sequences = false, | |||||
| // bool return_state = false, | |||||
| // bool go_backwards = false, | |||||
| // bool stateful = false, | |||||
| // bool unroll = false, | |||||
| // bool time_major = false) | |||||
| // => new RNN(new RNNArgs | |||||
| // { | |||||
| // Cell = cell, | |||||
| // ReturnSequences = return_sequences, | |||||
| // ReturnState = return_state, | |||||
| // GoBackwards = go_backwards, | |||||
| // Stateful = stateful, | |||||
| // Unroll = unroll, | |||||
| // TimeMajor = time_major | |||||
| // }); | |||||
| //public RNN New(List<IRnnArgCell> cell, | |||||
| // bool return_sequences = false, | |||||
| // bool return_state = false, | |||||
| // bool go_backwards = false, | |||||
| // bool stateful = false, | |||||
| // bool unroll = false, | |||||
| // bool time_major = false) | |||||
| // => new RNN(new RNNArgs | |||||
| // { | |||||
| // Cell = cell, | |||||
| // ReturnSequences = return_sequences, | |||||
| // ReturnState = return_state, | |||||
| // GoBackwards = go_backwards, | |||||
| // Stateful = stateful, | |||||
| // Unroll = unroll, | |||||
| // TimeMajor = time_major | |||||
| // }); | |||||
| protected Tensors get_initial_state(Tensor inputs) | |||||
| { | |||||
|     // Prefer the cell's own get_initial_state when it defines one. | |||||
|     Type type = cell.GetType(); | |||||
|     MethodInfo methodInfo = type.GetMethod("get_initial_state"); | |||||
|     if (nest.is_nested(inputs)) | |||||
|     { | |||||
|         // The input is a nested sequence. Use its first element | |||||
|         // to derive the batch size and dtype. | |||||
|         inputs = nest.flatten(inputs)[0]; | |||||
|     } | |||||
|     var input_shape = tf.shape(inputs); | |||||
|     var batch_size = args.TimeMajor ? input_shape[1] : input_shape[0]; | |||||
|     var dtype = inputs.dtype; | |||||
|     Tensor init_state; | |||||
|     if (methodInfo != null) | |||||
|     { | |||||
|         init_state = (Tensor)methodInfo.Invoke(cell, new object[] { null, batch_size, dtype }); | |||||
|     } | |||||
|     else | |||||
|     { | |||||
|         init_state = RNNUtils.generate_zero_filled_state(batch_size, cell.state_size, dtype); | |||||
|     } | |||||
|     return new List<Tensor> { init_state }; | |||||
| } | } | ||||
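To make the fallback path above concrete, a minimal sketch (shapes assumed from the SimpleRNNCell test further below; `cell` stands for a hypothetical 64-unit cell and the input is batch-major):

    // With TimeMajor == false the batch size is input_shape[0].
    var inputs = tf.random.normal(new Shape(4, 100));
    var batch_size = tf.shape(inputs)[0];                       // 4
    var init = RNNUtils.generate_zero_filled_state(
        batch_size, cell.state_size, inputs.dtype);             // zeros of shape (4, 64)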
| Tensor _generate_zero_filled_state_for_cell(LSTMCell cell, Tensor batch_size) | Tensor _generate_zero_filled_state_for_cell(LSTMCell cell, Tensor batch_size) | ||||
| @@ -0,0 +1,59 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Util; | |||||
| using OneOf; | |||||
| using Tensorflow.NumPy; | |||||
| namespace Tensorflow.Keras.Layers.Rnn | |||||
| { | |||||
| public class RNNUtils | |||||
| { | |||||
| public static Tensor generate_zero_filled_state(Tensor batch_size_tensor, StateSizeWrapper state_size, TF_DataType dtype = TF_DataType.TF_FLOAT) | |||||
| { | |||||
| // TF_DataType is a value type, so compare against DtInvalid rather than null. | |||||
| if (batch_size_tensor == null || dtype == TF_DataType.DtInvalid) | |||||
| { | |||||
|     throw new ValueError( | |||||
|         "batch_size and dtype cannot be None while constructing initial " + | |||||
|         $"state. Received: batch_size={batch_size_tensor}, dtype={dtype}"); | |||||
| } | |||||
| // Build a zero tensor of shape [batch_size, ...state_size]. | |||||
| Func<StateSizeWrapper, Tensor> create_zeros = (StateSizeWrapper unnested_state_size) => | |||||
| { | |||||
|     var flat_dims = unnested_state_size.state_size; | |||||
|     var init_state_size = batch_size_tensor.ToArray<int>().concat(flat_dims); | |||||
|     return tf.zeros(init_state_size, dtype: dtype); | |||||
| }; | |||||
| // TODO: dispatch through nest.map_structure(create_zeros, state_size) | |||||
| // once nested state sizes are supported. | |||||
| return create_zeros(state_size); | |||||
| } | |||||
| public static Tensor generate_zero_filled_state_for_cell(SimpleRNNCell cell, Tensors inputs, Tensor batch_size, TF_DataType dtype) | |||||
| { | |||||
| if (inputs != null) | |||||
| { | |||||
| batch_size = tf.shape(inputs)[0]; | |||||
| dtype = inputs.dtype; | |||||
| } | |||||
| return generate_zero_filled_state(batch_size, cell.state_size, dtype); | |||||
| } | |||||
| } | |||||
| } | |||||
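A hedged usage sketch for the new `RNNUtils` helpers (it assumes `StateSizeWrapper` simply carries the per-state dimensions, and `cell`/`inputs` are hypothetical):

    // Zero-filled state for batch 4 and state size 64 -> tf.zeros((4, 64)).
    var batchSize = tf.constant(new[] { 4 });
    var state = RNNUtils.generate_zero_filled_state(batchSize, cell.state_size, TF_DataType.TF_FLOAT);

    // Or let the helper derive batch size and dtype from the inputs themselves:
    var state2 = RNNUtils.generate_zero_filled_state_for_cell(cell, inputs, null, TF_DataType.DtInvalid);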
| @@ -4,6 +4,7 @@ using System.Text; | |||||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | using Tensorflow.Keras.ArgsDefinition.Rnn; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
| using Tensorflow.Util; | |||||
| namespace Tensorflow.Keras.Layers.Rnn | namespace Tensorflow.Keras.Layers.Rnn | ||||
| { | { | ||||
| @@ -13,10 +14,23 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| IVariableV1 kernel; | IVariableV1 kernel; | ||||
| IVariableV1 recurrent_kernel; | IVariableV1 recurrent_kernel; | ||||
| IVariableV1 bias; | IVariableV1 bias; | ||||
| DropoutRNNCellMixin DRCMixin; | |||||
| public SimpleRNNCell(SimpleRNNArgs args) : base(args) | public SimpleRNNCell(SimpleRNNArgs args) : base(args) | ||||
| { | { | ||||
| this.args = args; | this.args = args; | ||||
| if (args.Units <= 0) | |||||
| { | |||||
| throw new ValueError( | |||||
| $"units must be a positive integer, got {args.Units}"); | |||||
| } | |||||
| this.args.Dropout = Math.Min(1f, Math.Max(0f, this.args.Dropout)); | |||||
| this.args.RecurrentDropout = Math.Min(1f, Math.Max(0f, this.args.RecurrentDropout)); | |||||
| this.args.state_size = this.args.Units; | |||||
| this.args.output_size = this.args.Units; | |||||
| DRCMixin = new DropoutRNNCellMixin(); | |||||
| DRCMixin.dropout = this.args.Dropout; | |||||
| DRCMixin.recurrent_dropout = this.args.RecurrentDropout; | |||||
| } | } | ||||
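A small construction sketch showing the validation and clamping performed by this constructor (values are illustrative; `Units` is assumed to come from the base `RNNArgs`):

    var cellArgs = new SimpleRNNArgs
    {
        Units = 64,
        Dropout = 1.5f,           // clamped to 1.0f by the constructor
        RecurrentDropout = -0.2f  // clamped to 0.0f
    };
    var cell = new SimpleRNNCell(cellArgs);   // state_size == output_size == 64
    // new SimpleRNNCell(new SimpleRNNArgs { Units = 0 }) would throw a ValueError.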
| public override void build(KerasShapesWrapper input_shape) | public override void build(KerasShapesWrapper input_shape) | ||||
| @@ -44,7 +58,81 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | ||||
| { | { | ||||
| Tensor states = initial_state[0]; | |||||
| var prev_output = nest.is_nested(states) ? states[0] : states; | |||||
| var dp_mask = DRCMixin.get_dropout_maskcell_for_cell(inputs, training ?? false); | |||||
| var rec_dp_mask = DRCMixin.get_recurrent_dropout_maskcell_for_cell(prev_output, training ?? false); | |||||
| // h = matmul(inputs, kernel), applying the input dropout mask when present. | |||||
| Tensor h; | |||||
| var ranks = inputs.rank; | |||||
| if (dp_mask != null) | |||||
| { | |||||
|     if (ranks > 2) | |||||
|     { | |||||
|         h = tf.linalg.tensordot(tf.multiply(inputs, dp_mask), kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } }); | |||||
|     } | |||||
|     else | |||||
|     { | |||||
|         h = math_ops.matmul(tf.multiply(inputs, dp_mask), kernel.AsTensor()); | |||||
|     } | |||||
| } | |||||
| else | |||||
| { | |||||
|     if (ranks > 2) | |||||
|     { | |||||
|         h = tf.linalg.tensordot(inputs, kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } }); | |||||
|     } | |||||
|     else | |||||
|     { | |||||
|         h = math_ops.matmul(inputs, kernel.AsTensor()); | |||||
|     } | |||||
| } | |||||
| if (bias != null) | |||||
| { | |||||
|     h = tf.nn.bias_add(h, bias); | |||||
| } | |||||
| if (rec_dp_mask != null) | |||||
| { | |||||
|     prev_output = tf.multiply(prev_output, rec_dp_mask); | |||||
| } | |||||
| // output = h + matmul(prev_output, recurrent_kernel), then the activation. | |||||
| ranks = prev_output.rank; | |||||
| Tensor output; | |||||
| if (ranks > 2) | |||||
| { | |||||
|     output = h + tf.linalg.tensordot(prev_output, recurrent_kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } }); | |||||
| } | |||||
| else | |||||
| { | |||||
|     output = h + math_ops.matmul(prev_output, recurrent_kernel.AsTensor()); | |||||
| } | |||||
| if (args.Activation != null) | |||||
| { | |||||
|     output = args.Activation.Apply(output); | |||||
| } | |||||
| if (nest.is_nested(states)) | |||||
| { | |||||
|     return (output, new Tensors { output }); | |||||
| } | |||||
| return (output, output); | |||||
| } | |||||
| public Tensor get_initial_state(Tensors inputs, Tensor batch_size, TF_DataType dtype) | |||||
| { | |||||
| return RNNUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size, dtype); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -89,7 +89,7 @@ namespace Tensorflow.Hub | |||||
| } | } | ||||
| } | } | ||||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
| protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
| { | { | ||||
| _check_trainability(); | _check_trainability(); | ||||
| @@ -1,60 +0,0 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using System.Collections.Generic; | |||||
| using Tensorflow.Keras.Callbacks; | |||||
| using Tensorflow.Keras.Engine; | |||||
| using static Tensorflow.KerasApi; | |||||
| namespace Tensorflow.Keras.UnitTest.Callbacks | |||||
| { | |||||
| [TestClass] | |||||
| public class EarlystoppingTest | |||||
| { | |||||
| [TestMethod] | |||||
| // Because loading weight variables into the model has not been implemented yet, | |||||
| // avoid setting patience too large: the restored weights will simply be those of the last epoch. | |||||
| public void Earlystopping() | |||||
| { | |||||
| var layers = keras.layers; | |||||
| var model = keras.Sequential(new List<ILayer> | |||||
| { | |||||
| layers.Rescaling(1.0f / 255, input_shape: (32, 32, 3)), | |||||
| layers.Conv2D(32, 3, padding: "same", activation: keras.activations.Relu), | |||||
| layers.MaxPooling2D(), | |||||
| layers.Flatten(), | |||||
| layers.Dense(128, activation: keras.activations.Relu), | |||||
| layers.Dense(10) | |||||
| }); | |||||
| model.summary(); | |||||
| model.compile(optimizer: keras.optimizers.RMSprop(1e-3f), | |||||
| loss: keras.losses.SparseCategoricalCrossentropy(from_logits: true), | |||||
| metrics: new[] { "acc" }); | |||||
| var num_epochs = 3; | |||||
| var batch_size = 8; | |||||
| var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data(); | |||||
| x_train = x_train / 255.0f; | |||||
| // define a CallbackParams first; the parameters you pass must at least contain Model and Epochs. | |||||
| CallbackParams callback_parameters = new CallbackParams | |||||
| { | |||||
| Model = model, | |||||
| Epochs = num_epochs, | |||||
| }; | |||||
| // define your early stopping callback | |||||
| ICallback earlystop = new EarlyStopping(callback_parameters, "accuracy"); | |||||
| // define a callback list, then add the early stopping callback to it. | |||||
| var callbacks = new List<ICallback>(); | |||||
| callbacks.add(earlystop); | |||||
| model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)], batch_size, num_epochs, callbacks: callbacks); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -144,6 +144,17 @@ namespace Tensorflow.Keras.UnitTest.Layers | |||||
| Assert.AreEqual(expected_output, actual_output); | Assert.AreEqual(expected_output, actual_output); | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void SimpleRNNCell() | |||||
| { | |||||
| var h0 = new Tensors { tf.zeros(new Shape(4, 64)) }; | |||||
| var x = tf.random.normal(new Shape(4, 100)); | |||||
| var cell = keras.layers.SimpleRNNCell(64); | |||||
| var (y, h1) = cell.Apply(inputs: x, state: h0); | |||||
| Assert.AreEqual((4, 64), y.shape); | |||||
| Assert.AreEqual((4, 64), h1[0].shape); | |||||
| } | |||||
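Beyond the single step above, a hedged sketch of driving the cell across time by feeding its state back in (it assumes tensor indexing `xs[t]` slices out step `t`):

    var cell = keras.layers.SimpleRNNCell(64);
    var state = new Tensors { tf.zeros(new Shape(4, 64)) };
    var xs = tf.random.normal(new Shape(10, 4, 100));   // 10 steps, batch 4, 100 features
    for (int t = 0; t < 10; t++)
    {
        var (y, h) = cell.Apply(inputs: xs[t], state: state);
        state = h;                                      // recurrent state feedback
    }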
| [TestMethod, Ignore("WIP")] | [TestMethod, Ignore("WIP")] | ||||
| public void SimpleRNN() | public void SimpleRNN() | ||||
| { | { | ||||
| @@ -67,4 +67,8 @@ | |||||
| </None> | </None> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | |||||
| <Folder Include="Callbacks\" /> | |||||
| </ItemGroup> | |||||
| </Project> | </Project> | ||||