diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs index a6e16cc8..7e18ea81 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs @@ -40,6 +40,7 @@ namespace Tensorflow.Eager }*/ } + // Console.WriteLine($"RecordGradient: should_record={should_record}, op_name={op_name}"); if (!should_record) return should_record; Tensor[] op_outputs; diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index 9c8940f3..e07171da 100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Linq; using System.Text; +using Tensorflow.Framework.Models; using Tensorflow.Graphs; using static Tensorflow.Binding; @@ -14,6 +15,8 @@ namespace Tensorflow.Functions { public string Name => _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle)); IntPtr _handle; + public Tensor[] Outputs; + public TensorSpec[] OutputStructure; public ConcreteFunction(Func func, TF_DataType dtype) { @@ -43,23 +46,38 @@ namespace Tensorflow.Functions var input = tf.placeholder(dtype); var output = func(input); + OutputStructure = output.structure; + var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); _handle = graph.ToGraph(opers, new Operation[] { input }, - new Operation[] { }, + new Operation[] { output.variant_tensor.op }, null); } } - public Tensor Execute(Tensor arg) + public ConcreteFunction(Func func, + TF_DataType[] dtypes, TensorShape[] shapes) { - var result = tf.Runner.TFE_Execute(tf.Context, - tf.Context.DeviceName, - Name, - new[] { arg }, - null, - 1); - return result[0]; + string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; + + // IntPtr func_handle; + using (var graph = new FuncGraph(func_name)) + { + var input1 = tf.placeholder(dtypes[0], shape: shapes[0], name: "args"); + var input2 = tf.placeholder(dtypes[1], shape: shapes[1], name: "args"); + var input3 = tf.placeholder(dtypes[2], shape: shapes[2], name: "args"); + var outputs = func(input1, (input2, input3)); + + Outputs = new[] { outputs.Item1, outputs.Item2 }; + OutputStructure = new[] { outputs.Item1.ToTensorSpec(), outputs.Item2.ToTensorSpec() }; + + var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); + _handle = graph.ToGraph(opers, + new Operation[] { input1, input2, input3 }, + new Operation[] { outputs.Item1.op, outputs.Item2.op }, + null); + } } public void Dispose() diff --git a/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs b/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs index 5ef4214c..13a68530 100644 --- a/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs +++ b/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs @@ -35,7 +35,7 @@ namespace Tensorflow.Gradients if (!state.op_tape.find(op, out var trace)) continue; - Console.WriteLine($"ComputeGradient: {state.op_tape[op].op_type}"); + // Console.WriteLine($"ComputeGradient: {state.op_tape[op].op_type}"); state.op_tape.erase(op); var out_gradients = new List(trace.output_tensor_info.Length); diff --git a/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/DataHandler.cs b/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/DataHandler.cs index 35909c85..f988cef4 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/DataHandler.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/DataHandler.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Text; using Tensorflow.Keras.ArgsDefinition; +using static Tensorflow.Binding; namespace Tensorflow.Keras.Engine.DataAdapters { @@ -12,10 +13,29 @@ namespace Tensorflow.Keras.Engine.DataAdapters { DataHandlerArgs args; IDataAdapter _adapter; + IDatasetV2 _dataset; + int _inferred_steps; + int _current_step; + int _step_increment; + bool _insufficient_data; + int _steps_per_execution_value; + int _initial_epoch => args.InitialEpoch; + int _epochs => args.Epochs; + IVariableV1 _steps_per_execution; public DataHandler(DataHandlerArgs args) { this.args = args; + if(args.StepsPerExecution == null) + { + _steps_per_execution = tf.Variable(1); + _steps_per_execution_value = 1; + } + else + { + _steps_per_execution = args.StepsPerExecution; + _steps_per_execution_value = args.StepsPerExecution.numpy(); + } _adapter = new TensorLikeDataAdapter(new TensorLikeDataAdapterArgs { @@ -30,11 +50,64 @@ namespace Tensorflow.Keras.Engine.DataAdapters UseMultiprocessing = args.UseMultiprocessing, Model = args.Model }); + _dataset = _adapter.GetDataset(); + _inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset); + _current_step = 0; + _step_increment = args.StepsPerExecution.numpy() - 1; + _insufficient_data = false; } - Tensor _infer_steps(IDatasetV2 dataset) + int _infer_steps(int steps_per_epoch, IDatasetV2 dataset) { + if (steps_per_epoch > -1) + return steps_per_epoch; + + var adapter_steps = _adapter.GetSize(); + if (adapter_steps > -1) + return adapter_steps; + throw new NotImplementedException(""); } + + public IEnumerable<(int, OwnedIterator)> enumerate_epochs() + { + using var ownedIterator = new OwnedIterator(_dataset); + foreach (var epoch in range(_initial_epoch, _epochs)) + { + if (_insufficient_data) + break; + yield return (epoch, ownedIterator); + } + } + + public IEnumerable steps() + { + _current_step = 0; + while(_current_step < _inferred_steps) + { + if (_insufficient_data) + break; + + bool can_run_full_execution = _steps_per_execution_value == 1 + || _inferred_steps < 0 + || _inferred_steps - _current_step >= _steps_per_execution_value; + + if (can_run_full_execution) + { + _step_increment = _steps_per_execution_value - 1; + yield return _current_step; + _current_step += _steps_per_execution_value; + } + else + { + var steps_remaining = _inferred_steps - _current_step; + _steps_per_execution.assign(steps_remaining); + _step_increment = steps_remaining - 1; + yield return _current_step; + _current_step += steps_remaining; + _steps_per_execution.assign(_steps_per_execution_value); + } + } + } } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/IDataAdapter.cs b/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/IDataAdapter.cs index c50a05b8..41253824 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/IDataAdapter.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/IDataAdapter.cs @@ -18,5 +18,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters /// target labels /// bool CanHandle(Tensor x, Tensor y = null); + IDatasetV2 GetDataset(); + int GetSize(); } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs b/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs index c67ea3f1..9713694a 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs @@ -16,6 +16,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters int _batch_size; int num_samples; int num_full_batches; + IDatasetV2 _dataset; public TensorLikeDataAdapter(TensorLikeDataAdapterArgs args) { @@ -32,6 +33,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters indices_dataset = indices_dataset.repeat(); indices_dataset = indices_dataset.map(permutation).prefetch(1); indices_dataset = indices_dataset.flat_map(slice_batch_indices); + _dataset = slice_inputs(indices_dataset, args.X, args.Y); } Tensor permutation(Tensor tensor) @@ -53,13 +55,24 @@ namespace Tensorflow.Keras.Engine.DataAdapters var first_k_indices = array_ops.slice(indices, new int[] { 0 }, new int[] { num_in_full_batch }); first_k_indices = array_ops.reshape(first_k_indices, new int[] { num_full_batches, _batch_size }); var flat_dataset = tf.data.Dataset.from_tensor_slices(first_k_indices); - return flat_dataset; } - void slice_inputs(IDatasetV2 indices_dataset, Tensor x, Tensor y) + IDatasetV2 slice_inputs(IDatasetV2 indices_dataset, Tensor x, Tensor y) { - var dataset = tf.data.Dataset.from_tensor(x, y); + var dataset2 = tf.data.Dataset.from_tensor(x, y).repeat(); + var dataset = tf.data.Dataset.zip(indices_dataset, dataset2); + + dataset = dataset.map((batch, data) => + { + var x = gen_array_ops.gather_v2(data.Item1, batch, 0); + var y = gen_array_ops.gather_v2(data.Item2, batch, 0); + return (x, y); + }); + + dataset = dataset.with_options(new DatasetOptions { }); + + return dataset; } public bool CanHandle(Tensor x, Tensor y = null) @@ -70,5 +83,11 @@ namespace Tensorflow.Keras.Engine.DataAdapters void _process_tensorlike() { } + + public IDatasetV2 GetDataset() + => _dataset; + + public int GetSize() + => _size; } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/LossesContainer.cs b/src/TensorFlowNET.Core/Keras/Engine/LossesContainer.cs index 9ee4059d..8596f8f4 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/LossesContainer.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/LossesContainer.cs @@ -21,5 +21,20 @@ namespace Tensorflow.Keras.Engine _loss_metric = new Mean(name: "loss"); _built = false; } + + /// + /// Computes the overall loss. + /// + /// + /// + public void Apply(Tensor y_true, Tensor y_pred) + { + + } + + public void Build() + { + + } } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Core/Keras/Engine/Model.Fit.cs index a768a52b..a22aa44f 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Model.Fit.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Model.Fit.cs @@ -51,6 +51,21 @@ namespace Tensorflow.Keras.Engine Model = this, StepsPerExecution = _steps_per_execution }); + + stop_training = false; + _train_counter.assign(0); + + foreach(var (epoch, iterator) in data_handler.enumerate_epochs()) + { + // reset_metrics(); + // callbacks.on_epoch_begin(epoch) + // data_handler.catch_stop_iteration(); + foreach(var step in data_handler.steps()) + { + // callbacks.on_train_batch_begin(step) + step_function(iterator); + } + } } } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Model.Train.cs b/src/TensorFlowNET.Core/Keras/Engine/Model.Train.cs new file mode 100644 index 00000000..ebe4f710 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Engine/Model.Train.cs @@ -0,0 +1,30 @@ +using NumSharp; +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Engine +{ + public partial class Model + { + Tensor step_function(OwnedIterator iterator) + { + var data = iterator.next(); + train_step(data[0], data[1]); + throw new NotImplementedException(""); + } + + /// + /// The logic for one training step. + /// + /// + /// + Tensor train_step(Tensor x, Tensor y) + { + using var tape = tf.GradientTape(); + var y_pred = Apply(x, is_training: true); + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/Model.cs b/src/TensorFlowNET.Core/Keras/Engine/Model.cs index 7dd2a4e7..0337c467 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Model.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Model.cs @@ -33,11 +33,12 @@ namespace Tensorflow.Keras.Engine IVariableV1 _test_counter; IVariableV1 _predict_counter; bool _base_model_initialized; + bool stop_training; public Model(ModelArgs args) : base(args) { - + _init_batch_counters(); } void _configure_steps_per_execution(int steps_per_execution) diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index 20cb81a2..5ab4457b 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -64,6 +64,7 @@ namespace Tensorflow var inferred_from = new Dictionary(); var base_types = new List(); var types = new List(); + string _scope_name = scope; // Perform input type inference foreach (var input_arg in op_def.InputArg) @@ -241,7 +242,7 @@ namespace Tensorflow var op = g.create_op(op_type_name, inputs.ToArray(), output_types.ToArray(), - name: scope, + name: _scope_name, input_types: input_types.ToArray(), attrs: attr_protos, op_def: op_def); diff --git a/src/TensorFlowNET.Core/Operations/dataset_ops.cs b/src/TensorFlowNET.Core/Operations/dataset_ops.cs index 0ab90786..231a18f8 100644 --- a/src/TensorFlowNET.Core/Operations/dataset_ops.cs +++ b/src/TensorFlowNET.Core/Operations/dataset_ops.cs @@ -471,6 +471,42 @@ namespace Tensorflow throw new NotImplementedException(""); } + /// + /// Creates a dataset that applies `f` to the outputs of `input_dataset`. + /// + /// + /// + /// + /// + /// + /// + /// + /// + public Tensor parallel_map_dataset_v2(Tensor dataset, Tensor num_parallel_calls, ConcreteFunction f, + TF_DataType[] output_types, TensorShape[] output_shapes, + bool use_inter_op_parallelism = true, + string deterministic = "default", + bool preserve_cardinality = false, + string name = null) + { + if (tf.Context.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "ParallelMapDatasetV2", name, + null, + dataset, new Tensor[0], num_parallel_calls, + "f", f, + "output_types", output_types, + "output_shapes", output_shapes, + "use_inter_op_parallelism", use_inter_op_parallelism, + "deterministic", deterministic, + "preserve_cardinality", preserve_cardinality); + return results[0]; + } + + throw new NotImplementedException(""); + } + /// /// A container for an iterator resource. /// diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index a3b638d3..826de98a 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -739,9 +739,9 @@ namespace Tensorflow return tf_with(ops.name_scope(name, "Range", new { start, limit, delta }), scope => { name = scope; - var start1 = ops.convert_to_tensor(start, name: "start"); - var limit1 = ops.convert_to_tensor(limit, name: "limit"); - var delta1 = ops.convert_to_tensor(delta, name: "delta"); + var start1 = ops.convert_to_tensor(start, name: "start", dtype: dtype); + var limit1 = ops.convert_to_tensor(limit, name: "limit", dtype: dtype); + var delta1 = ops.convert_to_tensor(delta, name: "delta", dtype: dtype); return gen_math_ops.range(start1, limit1, delta1, name); }); diff --git a/src/TensorFlowNET.Core/Variables/IVariableV1.cs b/src/TensorFlowNET.Core/Variables/IVariableV1.cs index 36297a41..7d3fd5e7 100644 --- a/src/TensorFlowNET.Core/Variables/IVariableV1.cs +++ b/src/TensorFlowNET.Core/Variables/IVariableV1.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using NumSharp; using System; using System.Collections.Generic; @@ -31,24 +32,25 @@ namespace Tensorflow /// public interface IVariableV1 { - public string UniqueId { get; } - public string Name { get; } + string UniqueId { get; } + string Name { get; } /// /// Handle is ref type /// - public Tensor Handle { get; } - public string Device { get; } - public Operation Initializer { get; } - public Operation Op { get; } + Tensor Handle { get; } + string Device { get; } + Operation Initializer { get; } + Operation Op { get; } /// /// GraphElement is a copy of Handle /// - public Tensor GraphElement { get; } - public Graph Graph { get; } - public TF_DataType dtype { get; } - public TensorShape shape { get; } + Tensor GraphElement { get; } + Graph Graph { get; } + TF_DataType dtype { get; } + TensorShape shape { get; } Tensor assign_add(T delta, bool use_locking = false, string name = null, bool read_value = true); Tensor assign(T value, bool use_locking = false, string name = null, bool read_value = true); Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false); + NDArray numpy(); } } diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index f6de69ca..b14d381c 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -15,6 +15,7 @@ ******************************************************************************/ using Google.Protobuf; +using NumSharp; using System; using System.Collections.Generic; using System.Linq; @@ -424,5 +425,8 @@ namespace Tensorflow var _op = tf.OpDefLib._apply_op_helper("AssignAdd", name: name, args: new { variable, value, use_locking }); return _op; } + + public NDArray numpy() + => throw new RuntimeError("Graph mode can't use numpy()."); } }