| @@ -18,6 +18,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
| public int Inferredsteps => _inferred_steps; | public int Inferredsteps => _inferred_steps; | ||||
| int _current_step; | int _current_step; | ||||
| int _step_increment; | int _step_increment; | ||||
| public int StepIncrement => _step_increment; | |||||
| bool _insufficient_data; | bool _insufficient_data; | ||||
| int _steps_per_execution_value; | int _steps_per_execution_value; | ||||
| int _initial_epoch => args.InitialEpoch; | int _initial_epoch => args.InitialEpoch; | ||||
| @@ -73,7 +74,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
| _dataset = _adapter.GetDataset(); | _dataset = _adapter.GetDataset(); | ||||
| _inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset); | _inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset); | ||||
| _current_step = 0; | _current_step = 0; | ||||
| _step_increment = args.StepsPerExecution.numpy() - 1; | |||||
| _step_increment = _steps_per_execution_value - 1; | |||||
| _insufficient_data = false; | _insufficient_data = false; | ||||
| } | } | ||||
| @@ -14,6 +14,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
| int _batch_size; | int _batch_size; | ||||
| int num_samples; | int num_samples; | ||||
| int num_full_batches; | int num_full_batches; | ||||
| int _partial_batch_size; | |||||
| public TensorLikeDataAdapter(DataAdapterArgs args) | public TensorLikeDataAdapter(DataAdapterArgs args) | ||||
| { | { | ||||
| @@ -22,9 +23,9 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
| num_samples = args.X.shape[0]; | num_samples = args.X.shape[0]; | ||||
| var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize; | var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize; | ||||
| _batch_size = batch_size; | _batch_size = batch_size; | ||||
| _size = Convert.ToInt32(Math.Floor(num_samples / (batch_size + 0f))); | |||||
| _size = num_samples < batch_size ? num_samples % batch_size : num_samples / batch_size; | |||||
| num_full_batches = num_samples / batch_size; | num_full_batches = num_samples / batch_size; | ||||
| var _partial_batch_size = num_samples % batch_size; | |||||
| _partial_batch_size = num_samples % batch_size; | |||||
| var indices_dataset = tf.data.Dataset.range(1); | var indices_dataset = tf.data.Dataset.range(1); | ||||
| indices_dataset = indices_dataset.repeat(args.Epochs); | indices_dataset = indices_dataset.repeat(args.Epochs); | ||||
| @@ -57,6 +58,15 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
| var first_k_indices = array_ops.slice(indices, new int[] { 0 }, new int[] { num_in_full_batch }); | var first_k_indices = array_ops.slice(indices, new int[] { 0 }, new int[] { num_in_full_batch }); | ||||
| first_k_indices = array_ops.reshape(first_k_indices, new int[] { num_full_batches, _batch_size }); | first_k_indices = array_ops.reshape(first_k_indices, new int[] { num_full_batches, _batch_size }); | ||||
| var flat_dataset = tf.data.Dataset.from_tensor_slices(first_k_indices); | var flat_dataset = tf.data.Dataset.from_tensor_slices(first_k_indices); | ||||
| if (_partial_batch_size > 0) | |||||
| { | |||||
| var array = array_ops.slice(indices, | |||||
| new[] { constant_op.constant(num_in_full_batch)}, | |||||
| new[] { constant_op.constant(_partial_batch_size)}); | |||||
| var index_remainder = tf.data.Dataset.from_tensor(array); | |||||
| flat_dataset = flat_dataset.concatenate(index_remainder); | |||||
| } | |||||
| return flat_dataset; | return flat_dataset; | ||||
| } | } | ||||
| @@ -340,7 +340,7 @@ namespace Tensorflow.Keras.Engine | |||||
| tf.Logger.Debug($"Depth {depth}: {node.Layer}: {node.Layer.Name}"); | tf.Logger.Debug($"Depth {depth}: {node.Layer}: {node.Layer.Name}"); | ||||
| var outputs = node.Layer.Apply(layer_inputs, is_training: training); | var outputs = node.Layer.Apply(layer_inputs, is_training: training); | ||||
| foreach (var output in outputs.Where(x => x != null)) | foreach (var output in outputs.Where(x => x != null)) | ||||
| tf.Logger.Debug($"Depth {depth}: {node.Layer}: {node.Layer.Name} {output.TensorShape}"); | |||||
| tf.Logger.Information($"Depth {depth}: {node.Layer}: {node.Layer.Name} {output.TensorShape}"); | |||||
| // Update tensor_dict for next input | // Update tensor_dict for next input | ||||
| foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs)) | foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs)) | ||||
| tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y)); | tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y)); | ||||
| @@ -95,7 +95,7 @@ namespace Tensorflow.Keras.Engine | |||||
| foreach (var step in data_handler.steps()) | foreach (var step in data_handler.steps()) | ||||
| { | { | ||||
| // callbacks.on_train_batch_begin(step) | // callbacks.on_train_batch_begin(step) | ||||
| var results = step_function(iterator); | |||||
| var results = train_step_function(iterator); | |||||
| var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}")); | var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}")); | ||||
| Console.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}"); | Console.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}"); | ||||
| } | } | ||||
| @@ -1,7 +1,10 @@ | |||||
| using NumSharp; | using NumSharp; | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine.DataAdapters; | using Tensorflow.Keras.Engine.DataAdapters; | ||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
| { | { | ||||
| @@ -21,7 +24,7 @@ namespace Tensorflow.Keras.Engine | |||||
| /// <param name="workers"></param> | /// <param name="workers"></param> | ||||
| /// <param name="use_multiprocessing"></param> | /// <param name="use_multiprocessing"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public Tensor predict(Tensor x, | |||||
| public Tensors predict(Tensor x, | |||||
| int batch_size = -1, | int batch_size = -1, | ||||
| int verbose = 0, | int verbose = 0, | ||||
| int steps = -1, | int steps = -1, | ||||
| @@ -43,7 +46,35 @@ namespace Tensorflow.Keras.Engine | |||||
| StepsPerExecution = _steps_per_execution | StepsPerExecution = _steps_per_execution | ||||
| }); | }); | ||||
| throw new NotImplementedException(""); | |||||
| Tensors outputs = null; | |||||
| _predict_counter.assign(0); | |||||
| // callbacks.on_predict_begin() | |||||
| foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | |||||
| { | |||||
| foreach(var step in data_handler.steps()) | |||||
| { | |||||
| // callbacks.on_predict_batch_begin(step) | |||||
| var batch_outputs = run_predict_step(iterator); | |||||
| outputs = batch_outputs; | |||||
| var end_step = step + data_handler.StepIncrement; | |||||
| // callbacks.on_predict_batch_end(end_step, {'outputs': batch_outputs}) | |||||
| } | |||||
| } | |||||
| // callbacks.on_predict_end() | |||||
| return outputs; | |||||
| } | |||||
| Tensors run_predict_step(OwnedIterator iterator) | |||||
| { | |||||
| var data = iterator.next(); | |||||
| var outputs = predict_step(data[0]); | |||||
| tf_with(ops.control_dependencies(new object[0]), ctl => _predict_counter.assign_add(1)); | |||||
| return outputs; | |||||
| } | |||||
| Tensors predict_step(Tensor data) | |||||
| { | |||||
| return Apply(data, is_training: false); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -8,7 +8,7 @@ namespace Tensorflow.Keras.Engine | |||||
| { | { | ||||
| public partial class Model | public partial class Model | ||||
| { | { | ||||
| IEnumerable<(string, Tensor)> step_function(OwnedIterator iterator) | |||||
| IEnumerable<(string, Tensor)> train_step_function(OwnedIterator iterator) | |||||
| { | { | ||||
| var data = iterator.next(); | var data = iterator.next(); | ||||
| var outputs = train_step(data[0], data[1]); | var outputs = train_step(data[0], data[1]); | ||||
| @@ -45,5 +45,11 @@ namespace Tensorflow.Keras.Layers | |||||
| { | { | ||||
| TargetShape = target_shape | TargetShape = target_shape | ||||
| }); | }); | ||||
| public Reshape Reshape(object[] target_shape) | |||||
| => new Reshape(new ReshapeArgs | |||||
| { | |||||
| TargetShapeObjects = target_shape | |||||
| }); | |||||
| } | } | ||||
| } | } | ||||
| @@ -142,9 +142,6 @@ namespace Tensorflow.Keras.Layers | |||||
| if (use_fused_avg_updates) | if (use_fused_avg_updates) | ||||
| exponential_avg_factor = 1.0f - momentum; | exponential_avg_factor = 1.0f - momentum; | ||||
| var beta = this.beta; | |||||
| var gamma = this.gamma; | |||||
| Func<Tensor[]> _fused_batch_norm_training = () => | Func<Tensor[]> _fused_batch_norm_training = () => | ||||
| { | { | ||||
| return tf.nn.fused_batch_norm( | return tf.nn.fused_batch_norm( | ||||
| @@ -21,11 +21,15 @@ namespace Tensorflow.Keras.Layers | |||||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) | protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) | ||||
| { | { | ||||
| var shape_tensor = array_ops.shape(inputs); | |||||
| var shape = new List<int> { inputs.shape[0] }; | |||||
| shape.AddRange(args.TargetShape.dims); | |||||
| var shapes = new List<object>(); | |||||
| shapes.Add(array_ops.shape(inputs)[0]); | |||||
| if (args.TargetShapeObjects != null) | |||||
| shapes.AddRange(args.TargetShapeObjects); | |||||
| if (args.TargetShape != null) | |||||
| args.TargetShape.dims.ToList().ForEach(x => shapes.Add(x)); | |||||
| var shape = ops.convert_to_tensor(shapes); | |||||
| var result = array_ops.reshape(inputs, shape.ToArray()); | |||||
| var result = array_ops.reshape(inputs, shape); | |||||
| if (!tf.Context.executing_eagerly()) | if (!tf.Context.executing_eagerly()) | ||||
| result.set_shape(ComputeOutputShape(inputs.shape)); | result.set_shape(ComputeOutputShape(inputs.shape)); | ||||
| return result; | return result; | ||||
| @@ -33,14 +37,16 @@ namespace Tensorflow.Keras.Layers | |||||
| public override TensorShape ComputeOutputShape(TensorShape input_shape) | public override TensorShape ComputeOutputShape(TensorShape input_shape) | ||||
| { | { | ||||
| if (input_shape.dims[0] == -1) | |||||
| if (input_shape.dims[1..].Contains(-1)) | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| else | |||||
| { | { | ||||
| input_shape = input_shape.dims[0]; | input_shape = input_shape.dims[0]; | ||||
| var output_shape = input_shape.concatenate(args.TargetShape.dims); | var output_shape = input_shape.concatenate(args.TargetShape.dims); | ||||
| return output_shape; | return output_shape; | ||||
| } | } | ||||
| else | |||||
| throw new NotImplementedException(""); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -6,7 +6,7 @@ | |||||
| <LangVersion>8.0</LangVersion> | <LangVersion>8.0</LangVersion> | ||||
| <RootNamespace>Tensorflow.Keras</RootNamespace> | <RootNamespace>Tensorflow.Keras</RootNamespace> | ||||
| <Platforms>AnyCPU;x64</Platforms> | <Platforms>AnyCPU;x64</Platforms> | ||||
| <Version>0.3.0</Version> | |||||
| <Version>0.4.0</Version> | |||||
| <Authors>Haiping Chen</Authors> | <Authors>Haiping Chen</Authors> | ||||
| <Product>Keras for .NET</Product> | <Product>Keras for .NET</Product> | ||||
| <Copyright>Apache 2.0, Haiping Chen 2020</Copyright> | <Copyright>Apache 2.0, Haiping Chen 2020</Copyright> | ||||
| @@ -20,7 +20,8 @@ | |||||
| * Support Conv2D functional API. | * Support Conv2D functional API. | ||||
| * Support BatchNormalization layer. | * Support BatchNormalization layer. | ||||
| * Building keras model in subclass, functional and sequential api | * Building keras model in subclass, functional and sequential api | ||||
| * Implemented backward_function.</PackageReleaseNotes> | |||||
| * Implemented backward_function. | |||||
| * Support model.load_weights.</PackageReleaseNotes> | |||||
| <Description>Keras for .NET | <Description>Keras for .NET | ||||
| Keras is an API designed for human beings, not machines. Keras follows best practices for reducing cognitive load: it offers consistent & simple APIs, it minimizes the number of user actions required for common use cases, and it provides clear & actionable error messages.</Description> | Keras is an API designed for human beings, not machines. Keras follows best practices for reducing cognitive load: it offers consistent & simple APIs, it minimizes the number of user actions required for common use cases, and it provides clear & actionable error messages.</Description> | ||||
| @@ -31,8 +32,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac | |||||
| <RepositoryType>Git</RepositoryType> | <RepositoryType>Git</RepositoryType> | ||||
| <SignAssembly>true</SignAssembly> | <SignAssembly>true</SignAssembly> | ||||
| <AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> | <AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> | ||||
| <AssemblyVersion>0.3.0.0</AssemblyVersion> | |||||
| <FileVersion>0.3.0.0</FileVersion> | |||||
| <AssemblyVersion>0.4.0.0</AssemblyVersion> | |||||
| <FileVersion>0.4.0.0</FileVersion> | |||||
| <PackageLicenseFile>LICENSE</PackageLicenseFile> | <PackageLicenseFile>LICENSE</PackageLicenseFile> | ||||
| </PropertyGroup> | </PropertyGroup> | ||||
| @@ -48,7 +49,6 @@ Keras is an API designed for human beings, not machines. Keras follows best prac | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" /> | <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" /> | ||||
| <PackageReference Include="Newtonsoft.Json" Version="12.0.3" /> | <PackageReference Include="Newtonsoft.Json" Version="12.0.3" /> | ||||
| <PackageReference Include="NumSharp.Lite" Version="0.1.10" /> | |||||
| <PackageReference Include="SciSharp.Keras.HDF5" Version="1.1.10.500" /> | <PackageReference Include="SciSharp.Keras.HDF5" Version="1.1.10.500" /> | ||||
| <PackageReference Include="SharpZipLib" Version="1.3.1" /> | <PackageReference Include="SharpZipLib" Version="1.3.1" /> | ||||
| </ItemGroup> | </ItemGroup> | ||||