| @@ -40,6 +40,7 @@ namespace Tensorflow.Eager | |||||
| }*/ | }*/ | ||||
| } | } | ||||
| // Console.WriteLine($"RecordGradient: should_record={should_record}, op_name={op_name}"); | |||||
| if (!should_record) return should_record; | if (!should_record) return should_record; | ||||
| Tensor[] op_outputs; | Tensor[] op_outputs; | ||||
| @@ -2,6 +2,7 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Framework.Models; | |||||
| using Tensorflow.Graphs; | using Tensorflow.Graphs; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -14,6 +15,8 @@ namespace Tensorflow.Functions | |||||
| { | { | ||||
| public string Name => _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle)); | public string Name => _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle)); | ||||
| IntPtr _handle; | IntPtr _handle; | ||||
| public Tensor[] Outputs; | |||||
| public TensorSpec[] OutputStructure; | |||||
| public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype) | public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype) | ||||
| { | { | ||||
| @@ -43,23 +46,38 @@ namespace Tensorflow.Functions | |||||
| var input = tf.placeholder(dtype); | var input = tf.placeholder(dtype); | ||||
| var output = func(input); | var output = func(input); | ||||
| OutputStructure = output.structure; | |||||
| var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | ||||
| _handle = graph.ToGraph(opers, | _handle = graph.ToGraph(opers, | ||||
| new Operation[] { input }, | new Operation[] { input }, | ||||
| new Operation[] { }, | |||||
| new Operation[] { output.variant_tensor.op }, | |||||
| null); | null); | ||||
| } | } | ||||
| } | } | ||||
/// <summary>
/// Traces <paramref name="func"/> inside a fresh <see cref="FuncGraph"/> using three
/// placeholder inputs and keeps the resulting function handle.
/// </summary>
/// <param name="func">
/// Function to trace. It receives the first placeholder and a tuple made of the
/// second and third placeholders, and returns a pair of tensors.
/// </param>
/// <param name="dtypes">Data types of the three placeholders, in order; only the first three entries are read.</param>
/// <param name="shapes">Shapes of the three placeholders, in order; only the first three entries are read.</param>
/// <exception cref="ArgumentNullException">
/// <paramref name="func"/>, <paramref name="dtypes"/> or <paramref name="shapes"/> is null.
/// </exception>
/// <exception cref="ArgumentException">
/// <paramref name="dtypes"/> or <paramref name="shapes"/> has fewer than three entries.
/// </exception>
public ConcreteFunction(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> func,
    TF_DataType[] dtypes, TensorShape[] shapes)
{
    // Validate up front so a bad argument is reported by name instead of surfacing as a
    // NullReferenceException / IndexOutOfRangeException halfway through graph construction.
    const int inputCount = 3;
    if (func == null)
        throw new ArgumentNullException(nameof(func));
    if (dtypes == null)
        throw new ArgumentNullException(nameof(dtypes));
    if (shapes == null)
        throw new ArgumentNullException(nameof(shapes));
    if (dtypes.Length < inputCount)
        throw new ArgumentException($"Expected at least {inputCount} data types.", nameof(dtypes));
    if (shapes.Length < inputCount)
        throw new ArgumentException($"Expected at least {inputCount} shapes.", nameof(shapes));

    string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";

    using (var graph = new FuncGraph(func_name))
    {
        // All three placeholders are requested under the same name "args".
        var input1 = tf.placeholder(dtypes[0], shape: shapes[0], name: "args");
        var input2 = tf.placeholder(dtypes[1], shape: shapes[1], name: "args");
        var input3 = tf.placeholder(dtypes[2], shape: shapes[2], name: "args");

        var outputs = func(input1, (input2, input3));
        Outputs = new[] { outputs.Item1, outputs.Item2 };
        OutputStructure = new[] { outputs.Item1.ToTensorSpec(), outputs.Item2.ToTensorSpec() };

        // Every node recorded in the graph becomes part of the function body.
        var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
        _handle = graph.ToGraph(opers,
            new Operation[] { input1, input2, input3 },
            new Operation[] { outputs.Item1.op, outputs.Item2.op },
            null);
    }
}
| public void Dispose() | public void Dispose() | ||||
| @@ -35,7 +35,7 @@ namespace Tensorflow.Gradients | |||||
| if (!state.op_tape.find(op, out var trace)) | if (!state.op_tape.find(op, out var trace)) | ||||
| continue; | continue; | ||||
| Console.WriteLine($"ComputeGradient: {state.op_tape[op].op_type}"); | |||||
| // Console.WriteLine($"ComputeGradient: {state.op_tape[op].op_type}"); | |||||
| state.op_tape.erase(op); | state.op_tape.erase(op); | ||||
| var out_gradients = new List<Tensor>(trace.output_tensor_info.Length); | var out_gradients = new List<Tensor>(trace.output_tensor_info.Length); | ||||
| @@ -2,6 +2,7 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Keras.Engine.DataAdapters | namespace Tensorflow.Keras.Engine.DataAdapters | ||||
| { | { | ||||
| @@ -12,10 +13,29 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
| { | { | ||||
| DataHandlerArgs args; | DataHandlerArgs args; | ||||
| IDataAdapter _adapter; | IDataAdapter _adapter; | ||||
| IDatasetV2 _dataset; | |||||
| int _inferred_steps; | |||||
| int _current_step; | |||||
| int _step_increment; | |||||
| bool _insufficient_data; | |||||
| int _steps_per_execution_value; | |||||
| int _initial_epoch => args.InitialEpoch; | |||||
| int _epochs => args.Epochs; | |||||
| IVariableV1 _steps_per_execution; | |||||
| public DataHandler(DataHandlerArgs args) | public DataHandler(DataHandlerArgs args) | ||||
| { | { | ||||
| this.args = args; | this.args = args; | ||||
| if(args.StepsPerExecution == null) | |||||
| { | |||||
| _steps_per_execution = tf.Variable(1); | |||||
| _steps_per_execution_value = 1; | |||||
| } | |||||
| else | |||||
| { | |||||
| _steps_per_execution = args.StepsPerExecution; | |||||
| _steps_per_execution_value = args.StepsPerExecution.numpy(); | |||||
| } | |||||
| _adapter = new TensorLikeDataAdapter(new TensorLikeDataAdapterArgs | _adapter = new TensorLikeDataAdapter(new TensorLikeDataAdapterArgs | ||||
| { | { | ||||
| @@ -30,11 +50,64 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
| UseMultiprocessing = args.UseMultiprocessing, | UseMultiprocessing = args.UseMultiprocessing, | ||||
| Model = args.Model | Model = args.Model | ||||
| }); | }); | ||||
_dataset = _adapter.GetDataset();
_inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
_current_step = 0;
// Use the value resolved earlier in this constructor: args.StepsPerExecution is
// allowed to be null (that case falls back to 1), so dereferencing it again here
// would throw a NullReferenceException.
_step_increment = _steps_per_execution_value - 1;
_insufficient_data = false;
| } | } | ||||
/// <summary>
/// Resolves how many steps make up one epoch.
/// </summary>
/// <param name="steps_per_epoch">Explicit step count; used as-is when it is zero or greater.</param>
/// <param name="dataset">Dataset being iterated (not consulted by the current implementation).</param>
/// <returns>
/// <paramref name="steps_per_epoch"/> when non-negative, otherwise the adapter's size when
/// that is non-negative.
/// </returns>
/// <exception cref="NotImplementedException">Neither source yields a non-negative count.</exception>
int _infer_steps(int steps_per_epoch, IDatasetV2 dataset)
{
    if (steps_per_epoch >= 0)
    {
        return steps_per_epoch;
    }

    int sizeFromAdapter = _adapter.GetSize();
    return sizeFromAdapter >= 0
        ? sizeFromAdapter
        : throw new NotImplementedException("");
}
/// <summary>
/// Lazily yields one (epoch index, iterator) pair per epoch, from the initial epoch up to
/// the configured epoch count. The same iterator instance is handed out for every epoch
/// and is disposed once enumeration finishes.
/// </summary>
/// <returns>A sequence of epoch numbers paired with the shared dataset iterator.</returns>
public IEnumerable<(int, OwnedIterator)> enumerate_epochs()
{
    using (var iterator = new OwnedIterator(_dataset))
    {
        foreach (var epoch in range(_initial_epoch, _epochs))
        {
            // Stop producing epochs as soon as the data is flagged as insufficient.
            if (_insufficient_data)
            {
                yield break;
            }

            yield return (epoch, iterator);
        }
    }
}
/// <summary>
/// Lazily yields the starting step index of each execution within the current epoch.
/// The step counter advances by the steps-per-execution value, except for a final
/// shorter execution when fewer steps remain; in that case the steps-per-execution
/// variable is temporarily assigned the remaining count and restored afterwards.
/// </summary>
/// <returns>The step index at which each execution begins.</returns>
public IEnumerable<int> steps()
{
    _current_step = 0;

    while (_current_step < _inferred_steps && !_insufficient_data)
    {
        int remaining = _inferred_steps - _current_step;

        // A shortened execution is only needed when running several steps per
        // execution, the step count is known, and fewer than a full batch of steps remain.
        bool needs_partial_execution = _steps_per_execution_value != 1
            && _inferred_steps >= 0
            && remaining < _steps_per_execution_value;

        if (needs_partial_execution)
        {
            _steps_per_execution.assign(remaining);
            _step_increment = remaining - 1;
            yield return _current_step;
            _current_step += remaining;
            // Put the configured value back once the short execution has been consumed.
            _steps_per_execution.assign(_steps_per_execution_value);
            continue;
        }

        _step_increment = _steps_per_execution_value - 1;
        yield return _current_step;
        _current_step += _steps_per_execution_value;
    }
}
| } | } | ||||
| } | } | ||||
| @@ -18,5 +18,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
| /// <param name="y">target labels</param> | /// <param name="y">target labels</param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| bool CanHandle(Tensor x, Tensor y = null); | bool CanHandle(Tensor x, Tensor y = null); | ||||
| IDatasetV2 GetDataset(); | |||||
| int GetSize(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -16,6 +16,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
| int _batch_size; | int _batch_size; | ||||
| int num_samples; | int num_samples; | ||||
| int num_full_batches; | int num_full_batches; | ||||
| IDatasetV2 _dataset; | |||||
| public TensorLikeDataAdapter(TensorLikeDataAdapterArgs args) | public TensorLikeDataAdapter(TensorLikeDataAdapterArgs args) | ||||
| { | { | ||||
| @@ -32,6 +33,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
| indices_dataset = indices_dataset.repeat(); | indices_dataset = indices_dataset.repeat(); | ||||
| indices_dataset = indices_dataset.map(permutation).prefetch(1); | indices_dataset = indices_dataset.map(permutation).prefetch(1); | ||||
| indices_dataset = indices_dataset.flat_map(slice_batch_indices); | indices_dataset = indices_dataset.flat_map(slice_batch_indices); | ||||
| _dataset = slice_inputs(indices_dataset, args.X, args.Y); | |||||
| } | } | ||||
| Tensor permutation(Tensor tensor) | Tensor permutation(Tensor tensor) | ||||
| @@ -53,13 +55,24 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
| var first_k_indices = array_ops.slice(indices, new int[] { 0 }, new int[] { num_in_full_batch }); | var first_k_indices = array_ops.slice(indices, new int[] { 0 }, new int[] { num_in_full_batch }); | ||||
| first_k_indices = array_ops.reshape(first_k_indices, new int[] { num_full_batches, _batch_size }); | first_k_indices = array_ops.reshape(first_k_indices, new int[] { num_full_batches, _batch_size }); | ||||
| var flat_dataset = tf.data.Dataset.from_tensor_slices(first_k_indices); | var flat_dataset = tf.data.Dataset.from_tensor_slices(first_k_indices); | ||||
| return flat_dataset; | return flat_dataset; | ||||
| } | } | ||||
/// <summary>
/// Builds the dataset of input batches: zips the batch-index dataset with a repeated
/// dataset of the full inputs, then gathers the rows selected by each index batch.
/// </summary>
/// <param name="indices_dataset">Dataset yielding the indices that make up each batch.</param>
/// <param name="x">Input data.</param>
/// <param name="y">Target data.</param>
/// <returns>The mapped dataset with default <see cref="DatasetOptions"/> applied.</returns>
IDatasetV2 slice_inputs(IDatasetV2 indices_dataset, Tensor x, Tensor y)
{
    var repeated_inputs = tf.data.Dataset.from_tensor(x, y).repeat();
    var result = tf.data.Dataset.zip(indices_dataset, repeated_inputs);

    result = result.map((batch_indices, tensors) =>
    {
        // Pick the rows of each tensor (axis 0) addressed by the current index batch.
        var batch_x = gen_array_ops.gather_v2(tensors.Item1, batch_indices, 0);
        var batch_y = gen_array_ops.gather_v2(tensors.Item2, batch_indices, 0);
        return (batch_x, batch_y);
    });

    result = result.with_options(new DatasetOptions());
    return result;
}
| public bool CanHandle(Tensor x, Tensor y = null) | public bool CanHandle(Tensor x, Tensor y = null) | ||||
| @@ -70,5 +83,11 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
| void _process_tensorlike() | void _process_tensorlike() | ||||
| { | { | ||||
| } | } | ||||
/// <summary>
/// Returns the dataset held by this adapter.
/// </summary>
public IDatasetV2 GetDataset()
{
    return _dataset;
}
/// <summary>
/// Returns the value of the adapter's <c>_size</c> field.
/// </summary>
public int GetSize()
{
    return _size;
}
| } | } | ||||
| } | } | ||||
| @@ -21,5 +21,20 @@ namespace Tensorflow.Keras.Engine | |||||
| _loss_metric = new Mean(name: "loss"); | _loss_metric = new Mean(name: "loss"); | ||||
| _built = false; | _built = false; | ||||
| } | } | ||||
/// <summary>
/// Entry point for computing the overall loss.
/// The body is currently empty, so calling this has no effect.
/// </summary>
/// <param name="y_true">Target values (unused for now).</param>
/// <param name="y_pred">Predicted values (unused for now).</param>
public void Apply(Tensor y_true, Tensor y_pred)
{
    // Intentionally empty: the loss computation has not been written yet.
}
/// <summary>
/// Build hook for this container. The body is currently empty, so calling it has no effect.
/// </summary>
public void Build()
{
    // Intentionally empty: nothing is built yet.
}
| } | } | ||||
| } | } | ||||
| @@ -51,6 +51,21 @@ namespace Tensorflow.Keras.Engine | |||||
| Model = this, | Model = this, | ||||
| StepsPerExecution = _steps_per_execution | StepsPerExecution = _steps_per_execution | ||||
| }); | }); | ||||
| stop_training = false; | |||||
| _train_counter.assign(0); | |||||
| foreach(var (epoch, iterator) in data_handler.enumerate_epochs()) | |||||
| { | |||||
| // reset_metrics(); | |||||
| // callbacks.on_epoch_begin(epoch) | |||||
| // data_handler.catch_stop_iteration(); | |||||
| foreach(var step in data_handler.steps()) | |||||
| { | |||||
| // callbacks.on_train_batch_begin(step) | |||||
| step_function(iterator); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,30 @@ | |||||
| using NumSharp; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Keras.Engine | |||||
| { | |||||
public partial class Model
{
    /// <summary>
    /// Pulls the next element from <paramref name="iterator"/> and passes its first two
    /// entries to <c>train_step</c>.
    /// </summary>
    /// <param name="iterator">Iterator supplying the training data.</param>
    /// <returns>Nothing yet: the method currently always ends by throwing.</returns>
    /// <exception cref="NotImplementedException">Always thrown; the step is not finished.</exception>
    Tensor step_function(OwnedIterator iterator)
    {
        var batch = iterator.next();
        var features = batch[0];
        var targets = batch[1];
        train_step(features, targets);
        throw new NotImplementedException("");
    }

    /// <summary>
    /// The logic for one training step. At present it only opens a gradient tape and runs
    /// the forward pass in training mode before throwing.
    /// </summary>
    /// <param name="x">Input batch forwarded to <c>Apply</c>.</param>
    /// <param name="y">Target batch (not used yet).</param>
    /// <returns>Nothing yet: the method currently always ends by throwing.</returns>
    /// <exception cref="NotImplementedException">Always thrown after the forward pass.</exception>
    Tensor train_step(Tensor x, Tensor y)
    {
        using (var tape = tf.GradientTape())
        {
            var predictions = Apply(x, is_training: true);
            throw new NotImplementedException("");
        }
    }
}
| } | |||||
| @@ -33,11 +33,12 @@ namespace Tensorflow.Keras.Engine | |||||
| IVariableV1 _test_counter; | IVariableV1 _test_counter; | ||||
| IVariableV1 _predict_counter; | IVariableV1 _predict_counter; | ||||
| bool _base_model_initialized; | bool _base_model_initialized; | ||||
| bool stop_training; | |||||
/// <summary>
/// Creates the model from <paramref name="args"/> (forwarded to the base constructor)
/// and then calls <c>_init_batch_counters</c>.
/// </summary>
/// <param name="args">Model construction arguments.</param>
public Model(ModelArgs args)
    : base(args)
{
    this._init_batch_counters();
}
| void _configure_steps_per_execution(int steps_per_execution) | void _configure_steps_per_execution(int steps_per_execution) | ||||
| @@ -64,6 +64,7 @@ namespace Tensorflow | |||||
| var inferred_from = new Dictionary<string, object>(); | var inferred_from = new Dictionary<string, object>(); | ||||
| var base_types = new List<TF_DataType>(); | var base_types = new List<TF_DataType>(); | ||||
| var types = new List<TF_DataType>(); | var types = new List<TF_DataType>(); | ||||
| string _scope_name = scope; | |||||
| // Perform input type inference | // Perform input type inference | ||||
| foreach (var input_arg in op_def.InputArg) | foreach (var input_arg in op_def.InputArg) | ||||
| @@ -241,7 +242,7 @@ namespace Tensorflow | |||||
| var op = g.create_op(op_type_name, | var op = g.create_op(op_type_name, | ||||
| inputs.ToArray(), | inputs.ToArray(), | ||||
| output_types.ToArray(), | output_types.ToArray(), | ||||
| name: scope, | |||||
| name: _scope_name, | |||||
| input_types: input_types.ToArray(), | input_types: input_types.ToArray(), | ||||
| attrs: attr_protos, | attrs: attr_protos, | ||||
| op_def: op_def); | op_def: op_def); | ||||
| @@ -471,6 +471,42 @@ namespace Tensorflow | |||||
| throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
| } | } | ||||
/// <summary>
/// Creates a dataset that applies `f` to the outputs of `input_dataset`.
/// Only eager execution is supported; the op is dispatched as "ParallelMapDatasetV2".
/// </summary>
/// <param name="dataset">Input dataset tensor.</param>
/// <param name="num_parallel_calls">Tensor passed as the op's parallel-call count.</param>
/// <param name="f">Function passed as the "f" attribute.</param>
/// <param name="output_types">Passed as the "output_types" attribute.</param>
/// <param name="output_shapes">Passed as the "output_shapes" attribute.</param>
/// <param name="use_inter_op_parallelism">Passed as the "use_inter_op_parallelism" attribute.</param>
/// <param name="deterministic">Passed as the "deterministic" attribute.</param>
/// <param name="preserve_cardinality">Passed as the "preserve_cardinality" attribute.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The first result of the executed op.</returns>
/// <exception cref="NotImplementedException">Thrown when not executing eagerly.</exception>
public Tensor parallel_map_dataset_v2(Tensor dataset, Tensor num_parallel_calls, ConcreteFunction f,
    TF_DataType[] output_types, TensorShape[] output_shapes,
    bool use_inter_op_parallelism = true,
    string deterministic = "default",
    bool preserve_cardinality = false,
    string name = null)
{
    // Graph-mode construction is not available for this op.
    if (!tf.Context.executing_eagerly())
    {
        throw new NotImplementedException("");
    }

    var op_results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
        "ParallelMapDatasetV2", name,
        null,
        dataset, new Tensor[0], num_parallel_calls,
        "f", f,
        "output_types", output_types,
        "output_shapes", output_shapes,
        "use_inter_op_parallelism", use_inter_op_parallelism,
        "deterministic", deterministic,
        "preserve_cardinality", preserve_cardinality);

    return op_results[0];
}
| /// <summary> | /// <summary> | ||||
| /// A container for an iterator resource. | /// A container for an iterator resource. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -739,9 +739,9 @@ namespace Tensorflow | |||||
| return tf_with(ops.name_scope(name, "Range", new { start, limit, delta }), scope => | return tf_with(ops.name_scope(name, "Range", new { start, limit, delta }), scope => | ||||
| { | { | ||||
| name = scope; | name = scope; | ||||
| var start1 = ops.convert_to_tensor(start, name: "start"); | |||||
| var limit1 = ops.convert_to_tensor(limit, name: "limit"); | |||||
| var delta1 = ops.convert_to_tensor(delta, name: "delta"); | |||||
| var start1 = ops.convert_to_tensor(start, name: "start", dtype: dtype); | |||||
| var limit1 = ops.convert_to_tensor(limit, name: "limit", dtype: dtype); | |||||
| var delta1 = ops.convert_to_tensor(delta, name: "delta", dtype: dtype); | |||||
| return gen_math_ops.range(start1, limit1, delta1, name); | return gen_math_ops.range(start1, limit1, delta1, name); | ||||
| }); | }); | ||||
| @@ -14,6 +14,7 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using NumSharp; | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| @@ -31,24 +32,25 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
public interface IVariableV1
{
    // ---- Identification ----
    string UniqueId { get; }
    string Name { get; }

    // ---- Graph wiring ----

    /// <summary>
    /// Handle is ref type
    /// </summary>
    Tensor Handle { get; }

    /// <summary>
    /// GraphElement is a copy of Handle
    /// </summary>
    Tensor GraphElement { get; }

    Graph Graph { get; }
    Operation Initializer { get; }
    Operation Op { get; }
    string Device { get; }

    // ---- Element type and shape ----
    TF_DataType dtype { get; }
    TensorShape shape { get; }

    // ---- Operations ----
    Tensor assign<T>(T value, bool use_locking = false, string name = null, bool read_value = true);
    Tensor assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true);
    Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false);

    /// <summary>
    /// Returns the variable's value as an <see cref="NDArray"/>.
    /// Implementations may not support this (NOTE(review): the graph-mode variable throws).
    /// </summary>
    NDArray numpy();
}
| } | } | ||||
| @@ -15,6 +15,7 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using Google.Protobuf; | using Google.Protobuf; | ||||
| using NumSharp; | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| @@ -424,5 +425,8 @@ namespace Tensorflow | |||||
| var _op = tf.OpDefLib._apply_op_helper("AssignAdd", name: name, args: new { variable, value, use_locking }); | var _op = tf.OpDefLib._apply_op_helper("AssignAdd", name: name, args: new { variable, value, use_locking }); | ||||
| return _op; | return _op; | ||||
| } | } | ||||
/// <summary>
/// Not supported here: this implementation always throws.
/// </summary>
/// <exception cref="RuntimeError">Always thrown.</exception>
public NDArray numpy()
{
    throw new RuntimeError("Graph mode can't use numpy().");
}
| } | } | ||||
| } | } | ||||