| @@ -13,6 +13,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||
| { | |||
| DataHandlerArgs args; | |||
| IDataAdapter _adapter; | |||
| public IDataAdapter DataAdapter => _adapter; | |||
| IDatasetV2 _dataset; | |||
| int _inferred_steps; | |||
| int _current_step; | |||
| @@ -20,5 +20,6 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||
| bool CanHandle(Tensor x, Tensor y = null); | |||
| IDatasetV2 GetDataset(); | |||
| int GetSize(); | |||
| (Tensor, Tensor) Expand1d(Tensor x, Tensor y); | |||
| } | |||
| } | |||
| @@ -89,5 +89,12 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||
| public int GetSize() | |||
| => _size; | |||
| /// <summary> | |||
| /// Gives a rank-1 label tensor a trailing axis by calling expand_dims with axis -1. | |||
| /// The feature tensor is passed through untouched. | |||
| /// </summary> | |||
| /// <param name="x">Feature tensor; returned as-is.</param> | |||
| /// <param name="y">Label tensor; may be null when no labels are supplied.</param> | |||
| /// <returns>The pair (x, y), with y expanded when its static rank was 1.</returns> | |||
| public (Tensor, Tensor) Expand1d(Tensor x, Tensor y) | |||
| { | |||
|     // Labels are optional, so a missing y must not be dereferenced. | |||
|     if (y != null && y.TensorShape.ndim == 1) | |||
|         y = array_ops.expand_dims(y, axis: -1); | |||
|     return (x, y); | |||
| } | |||
| } | |||
| } | |||
| @@ -1,9 +1,11 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Losses; | |||
| using Tensorflow.Keras.Metrics; | |||
| using Tensorflow.Keras.Utils; | |||
| namespace Tensorflow.Keras.Engine | |||
| { | |||
| @@ -12,6 +14,10 @@ namespace Tensorflow.Keras.Engine | |||
| ILossFunc _user_losses; | |||
| ILossFunc _losses; | |||
| Mean _loss_metric; | |||
| bool _built; | |||
| Tensor[] _per_output_metrics; | |||
| List<Tensor> loss_values; | |||
| List<Tensor> loss_metric_values; | |||
| public LossesContainer(ILossFunc losses, string[] output_names = null) | |||
| : base(output_names) | |||
| @@ -19,6 +25,8 @@ namespace Tensorflow.Keras.Engine | |||
| _user_losses = losses; | |||
| _losses = losses; | |||
| _loss_metric = new Mean(name: "loss"); | |||
| loss_values = new List<Tensor>(); | |||
| loss_metric_values = new List<Tensor>(); | |||
| _built = false; | |||
| } | |||
| @@ -27,14 +35,44 @@ namespace Tensorflow.Keras.Engine | |||
| /// </summary> | |||
| /// <param name="y_true"></param> | |||
| /// <param name="y_pred"></param> | |||
| public void Apply(Tensor y_true, Tensor y_pred) | |||
| public Tensor Call(Tensor y_true, Tensor y_pred) | |||
| { | |||
|     // Evaluates the compiled loss for one batch, feeds it into the running | |||
|     // "loss" metric weighted by the batch dimension, and returns the total loss. | |||
|     if (!_built) | |||
|         Build(y_pred); | |||
|     var loss_value = _losses.Call(y_true, y_pred); | |||
|     var loss_metric_value = loss_value; | |||
|     var batch_dim = array_ops.shape(y_true)[0]; | |||
|     // TODO: scale the loss for distribution when the reduction is | |||
|     // SUM_OVER_BATCH_SIZE or AUTO (losses_utils.scale_loss_for_distribution). | |||
|     // The accumulators are local to this call. Keeping them on the instance | |||
|     // would make add_n below sum the losses of every earlier batch as well. | |||
|     var batch_loss_values = new List<Tensor> { loss_value }; | |||
|     var batch_metric_values = new List<Tensor> { loss_metric_value }; | |||
|     if (batch_loss_values.Count > 0) | |||
|     { | |||
|         var total_loss_metric_value = math_ops.add_n(batch_metric_values.ToArray()); | |||
|         _loss_metric.update_state(total_loss_metric_value, batch_dim); | |||
|         // TODO: losses_utils.cast_losses_to_common_dtype(batch_loss_values) | |||
|         var total_loss = math_ops.add_n(batch_loss_values.ToArray()); | |||
|         return total_loss; | |||
|     } | |||
|     else | |||
|     { | |||
|         // Ok for a model to have no compiled loss. | |||
|         return array_ops.zeros(new TensorShape()); | |||
|     } | |||
| } | |||
| public void Build() | |||
| public void Build(Tensor y_pred) | |||
| { | |||
|     // Sets up the per-output metrics and flags the container as built. | |||
|     // y_pred is accepted but not read here. | |||
|     _create_metrics(); | |||
|     _built = true; | |||
| } | |||
| void _create_metrics() | |||
| { | |||
|     // Placeholder: no metrics are created yet. | |||
|     // TODO: _per_output_metrics = _output_names.Select(x => null); | |||
| } | |||
| } | |||
| } | |||
| @@ -16,5 +16,18 @@ namespace Tensorflow.Keras.Engine | |||
| _metrics = metrics; | |||
| _built = false; | |||
| } | |||
| public void update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) | |||
| { | |||
|     // Only the lazy build happens here for now; the arguments are not read. | |||
|     if (!_built) | |||
|     { | |||
|         Build(); | |||
|     } | |||
|     _built = true; | |||
| } | |||
| void Build() | |||
| { | |||
|     // Intentionally empty for now. | |||
| } | |||
| } | |||
| } | |||
| @@ -10,6 +10,8 @@ namespace Tensorflow.Keras.Engine | |||
| { | |||
| public partial class Model | |||
| { | |||
| LossesContainer compiled_loss; | |||
| MetricsContainer compiled_metrics; | |||
| public void compile(string optimizerName, ILossFunc lossName) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| @@ -18,8 +20,8 @@ namespace Tensorflow.Keras.Engine | |||
| public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics) | |||
| { | |||
| this.optimizer = optimizer; | |||
| var compiled_loss = new LossesContainer(loss, output_names: output_names); | |||
| var compiled_metrics = new MetricsContainer(metrics, output_names: output_names); | |||
| compiled_loss = new LossesContainer(loss, output_names: output_names); | |||
| compiled_metrics = new MetricsContainer(metrics, output_names: output_names); | |||
| int experimental_steps_per_execution = 1; | |||
| _configure_steps_per_execution(experimental_steps_per_execution); | |||
| @@ -37,7 +37,7 @@ namespace Tensorflow.Keras.Engine | |||
| var val_x = x[new Slice(train_count)]; | |||
| var val_y = y[new Slice(train_count)]; | |||
| var data_handler = new DataHandler(new DataHandlerArgs | |||
| data_handler = new DataHandler(new DataHandlerArgs | |||
| { | |||
| X = train_x, | |||
| Y = train_y, | |||
| @@ -1,7 +1,10 @@ | |||
| using NumSharp; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Gradients; | |||
| using Tensorflow.Keras.Optimizers; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras.Engine | |||
| @@ -11,7 +14,8 @@ namespace Tensorflow.Keras.Engine | |||
| Tensor step_function(OwnedIterator iterator) | |||
| { | |||
|     // Pulls the next element from the iterator and runs one training step on it. | |||
|     var data = iterator.next(); | |||
|     // train_step is called exactly once per element. A second call whose result | |||
|     // is discarded would run the optimizer update twice for the same batch. | |||
|     var outputs = train_step(data[0], data[1]); | |||
|     tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); | |||
|     // Returning the step outputs is not implemented yet. | |||
|     throw new NotImplementedException(""); | |||
| } | |||
| @@ -20,11 +24,33 @@ namespace Tensorflow.Keras.Engine | |||
| /// </summary> | |||
| /// <param name="data"></param> | |||
| /// <returns></returns> | |||
| Tensor train_step(Tensor x, Tensor y) | |||
| IEnumerable<(string, Tensor)> train_step(Tensor x, Tensor y) | |||
| { | |||
|     // Runs one optimisation step: forward pass under a gradient tape, compiled | |||
|     // loss, parameter update, metric update. Returns a single ("loss", loss) pair. | |||
|     (x, y) = data_handler.DataAdapter.Expand1d(x, y); | |||
|     using var tape = tf.GradientTape(); | |||
|     var y_pred = Apply(x, is_training: true); | |||
|     // No unconditional throw here: it would make everything below unreachable. | |||
|     var loss = compiled_loss.Call(y, y_pred); | |||
|     // For custom training steps, users can just write: | |||
|     // trainable_variables = self.trainable_variables | |||
|     // gradients = tape.gradient(loss, trainable_variables) | |||
|     // self.optimizer.apply_gradients(zip(gradients, trainable_variables)) | |||
|     // The _minimize call does a few extra steps unnecessary in most cases, | |||
|     // such as loss scaling and gradient clipping. | |||
|     _minimize(tape, optimizer, loss, trainable_variables); | |||
|     compiled_metrics.update_state(y, y_pred); | |||
|     return new[] { ("loss", loss) }; | |||
| } | |||
| void _minimize(GradientTape tape, OptimizerV2 optimizer, Tensor loss, List<IVariableV1> trainable_variables) | |||
| { | |||
|     // Gradient of the loss with respect to each trainable variable. | |||
|     var grads = tape.gradient(loss, trainable_variables); | |||
|     // Aggregation and clipping go through the optimizer's own hooks. | |||
|     grads = optimizer._aggregate_gradients(zip(grads, trainable_variables)); | |||
|     grads = optimizer._clip_gradients(grads); | |||
|     // apply_gradients is given ResourceVariable instances; aggregation has | |||
|     // already happened above, so it is switched off for this call. | |||
|     var resource_variables = trainable_variables.Select(variable => variable as ResourceVariable); | |||
|     optimizer.apply_gradients( | |||
|         zip(grads, resource_variables), | |||
|         experimental_aggregate_gradients: false); | |||
| } | |||
| } | |||
| } | |||
| @@ -6,6 +6,7 @@ using Tensorflow.Keras.Losses; | |||
| using Tensorflow.Keras.Optimizers; | |||
| using NumSharp; | |||
| using System.Collections.Generic; | |||
| using System.Data.Common; | |||
| namespace Tensorflow.Keras.Engine | |||
| { | |||
| @@ -23,7 +24,7 @@ namespace Tensorflow.Keras.Engine | |||
| #pragma warning restore CS0414 // The field 'Model._is_compiled' is assigned but its value is never used | |||
| #pragma warning restore CS0108 // Member hides inherited member; missing new keyword | |||
| ILossFunc loss; | |||
| IOptimizer optimizer; | |||
| OptimizerV2 optimizer; | |||
| IVariableV1 _steps_per_execution; | |||
| protected bool _is_graph_network; | |||
| protected Tensors inputs; | |||
| @@ -34,6 +35,7 @@ namespace Tensorflow.Keras.Engine | |||
| IVariableV1 _predict_counter; | |||
| bool _base_model_initialized; | |||
| bool stop_training; | |||
| DataHandler data_handler; | |||
| public Model(ModelArgs args) | |||
| : base(args) | |||
| @@ -6,5 +6,7 @@ namespace Tensorflow.Keras.Losses | |||
| { | |||
| /// <summary> | |||
| /// Contract for a loss: a reduction name plus a call that maps | |||
| /// (y_true, y_pred) to a loss tensor. | |||
| /// </summary> | |||
| public interface ILossFunc | |||
| { | |||
|     /// <summary>Computes the loss tensor for the given targets and predictions.</summary> | |||
|     Tensor Call(Tensor y_true, Tensor y_pred); | |||
|     /// <summary>Name of the reduction associated with this loss.</summary> | |||
|     string Reduction { get; } | |||
| } | |||
| } | |||
| @@ -1,6 +1,8 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.Utils; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras.Losses | |||
| { | |||
| @@ -14,6 +16,8 @@ namespace Tensorflow.Keras.Losses | |||
| bool _allow_sum_over_batch_size; | |||
| string _name_scope; | |||
| public string Reduction => reduction; | |||
| public Loss(string reduction = ReductionV2.AUTO, string name = null) | |||
| { | |||
| this.reduction = reduction; | |||
| @@ -21,6 +25,17 @@ namespace Tensorflow.Keras.Losses | |||
| _allow_sum_over_batch_size = false; | |||
| } | |||
| /// <summary> | |||
| /// Per-element loss computation. The base implementation always throws; | |||
| /// concrete losses override it. | |||
| /// </summary> | |||
| public virtual Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1) | |||
|     => throw new NotImplementedException(""); | |||
| /// <summary> | |||
| /// Computes the element-wise losses through <see cref="Apply"/> and reduces | |||
| /// them with the reduction this loss was constructed with. | |||
| /// </summary> | |||
| public Tensor Call(Tensor y_true, Tensor y_pred) | |||
| { | |||
|     var losses = Apply(y_true, y_pred); | |||
|     // Honour the configured reduction instead of hard-coding one. AUTO (and an | |||
|     // unset reduction) resolve to SUM_OVER_BATCH_SIZE, as before. | |||
|     var effective_reduction = (reduction == null || reduction == ReductionV2.AUTO) | |||
|         ? ReductionV2.SUM_OVER_BATCH_SIZE | |||
|         : reduction; | |||
|     return losses_utils.compute_weighted_loss(losses, reduction: effective_reduction); | |||
| } | |||
| void _set_name_scope() | |||
| { | |||
| _name_scope = name; | |||
| @@ -6,15 +6,11 @@ namespace Tensorflow.Keras.Losses | |||
| { | |||
| /// <summary> | |||
| /// Loss subclass that forwards its reduction and name to <see cref="Loss"/>. | |||
| /// The loss computation itself is supplied by overriding Apply in a subclass. | |||
| /// </summary> | |||
| public class LossFunctionWrapper : Loss | |||
| { | |||
|     // Single constructor: the earlier Action-based overload lines are dropped so | |||
|     // that only one constructor header and one body remain. | |||
|     public LossFunctionWrapper(string reduction = ReductionV2.AUTO, | |||
|         string name = null) | |||
|         : base(reduction: reduction, | |||
|               name: name) | |||
|     { | |||
|     } | |||
| } | |||
| } | |||
| @@ -6,6 +6,9 @@ namespace Tensorflow.Keras.Losses | |||
| { | |||
| /// <summary>String names of the loss reduction modes.</summary> | |||
| public class ReductionV2 | |||
| { | |||
|     /// <summary>Reduction name "auto".</summary> | |||
|     public const string AUTO = "auto"; | |||
|     /// <summary>Reduction name "none".</summary> | |||
|     public const string NONE = "none"; | |||
|     /// <summary>Reduction name "sum_over_batch_size".</summary> | |||
|     public const string SUM_OVER_BATCH_SIZE = "sum_over_batch_size"; | |||
|     /// <summary>Reduction name "weighted_mean".</summary> | |||
|     public const string WEIGHTED_MEAN = "weighted_mean"; | |||
| } | |||
| } | |||
| @@ -1,6 +1,8 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Data; | |||
| using System.Text; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras.Losses | |||
| { | |||
| @@ -9,16 +11,27 @@ namespace Tensorflow.Keras.Losses | |||
| // Forwards reduction and name to the base class. Only one base(...) call is | |||
| // kept; the earlier function-argument form is dropped. | |||
| // NOTE(review): from_logits is accepted but not stored or forwarded here. | |||
| public SparseCategoricalCrossentropy(bool from_logits = false, | |||
|     string reduction = ReductionV2.AUTO, | |||
|     string name = "sparse_categorical_crossentropy") : | |||
|     base(reduction: reduction, | |||
|         name: name) | |||
| { | |||
| } | |||
| static void sparse_categorical_crossentropy() | |||
| public override Tensor Apply(Tensor target, Tensor output, bool from_logits = false, int axis = -1) | |||
| { | |||
|     // Labels are cast to int64 and, when needed, both tensors are flattened before | |||
|     // the sparse softmax cross-entropy op. from_logits and axis are not read here. | |||
|     target = tf.cast(target, dtype: TF_DataType.TF_INT64); | |||
|     var logits_shape = array_ops.shape_v2(output); | |||
|     // Try to adjust the shape so that rank of labels = rank of logits - 1. | |||
|     var logits_rank = output.TensorShape.ndim; | |||
|     var labels_rank = target.TensorShape.ndim; | |||
|     if (labels_rank != logits_rank - 1) | |||
|     { | |||
|         // Labels become a vector; logits become (-1, last dimension). | |||
|         target = array_ops.reshape(target, new int[] { -1 }); | |||
|         output = array_ops.reshape(output, new int[] { -1, logits_shape[-1].numpy() }); | |||
|     } | |||
|     return tf.nn.sparse_softmax_cross_entropy_with_logits(target, output); | |||
| } | |||
| } | |||
| } | |||
| @@ -2,6 +2,8 @@ | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Losses; | |||
| using Tensorflow.Keras.Utils; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras.Metrics | |||
| @@ -13,9 +15,14 @@ namespace Tensorflow.Keras.Metrics | |||
| { | |||
| IVariableV1 total; | |||
| IVariableV1 count; | |||
| string _reduction; | |||
| TF_DataType _dtype; | |||
| public Reduce(string reduction, string name, TF_DataType dtype = TF_DataType.DtInvalid) | |||
| : base(name: name, dtype: dtype) | |||
| { | |||
| _reduction = reduction; | |||
| _dtype = dtype; | |||
| total = add_weight("total", initializer: tf.zeros_initializer); | |||
| if (reduction == Reduction.WEIGHTED_MEAN || | |||
| @@ -24,5 +31,36 @@ namespace Tensorflow.Keras.Metrics | |||
| count = add_weight("count", initializer: tf.zeros_initializer); | |||
| } | |||
| } | |||
| public Tensor update_state(Tensor values, Tensor sample_weight = null) | |||
| { | |||
|     // Adds the (optionally weighted) sum of values to total and, for the | |||
|     // weighted-mean reduction, adds the matching denominator to count. | |||
|     if (sample_weight != null) | |||
|     { | |||
|         (values, sample_weight) = losses_utils.squeeze_or_expand_dimensions( | |||
|             values, sample_weight: sample_weight); | |||
|         sample_weight = math_ops.cast(sample_weight, dtype: values.dtype); | |||
|         values = math_ops.multiply(values, sample_weight); | |||
|     } | |||
|     Tensor update_total_op = null; | |||
|     var value_sum = math_ops.reduce_sum(values); | |||
|     tf_with(ops.control_dependencies(new[] { value_sum }), ctl => | |||
|     { | |||
|         // Assign the captured local. Declaring a new variable here would leave | |||
|         // the outer update_total_op null for the control dependency below. | |||
|         update_total_op = total.assign_add(value_sum); | |||
|     }); | |||
|     // count is only created for reductions that need a denominator. Without it, | |||
|     // the total update is the whole operation. | |||
|     if (count == null) | |||
|         return update_total_op; | |||
|     Tensor num_values = null; | |||
|     if (_reduction == ReductionV2.WEIGHTED_MEAN) | |||
|     { | |||
|         if (sample_weight == null) | |||
|             num_values = math_ops.cast(array_ops.size(values), _dtype); | |||
|         else | |||
|             num_values = math_ops.reduce_sum(sample_weight); | |||
|     } | |||
|     // Fail explicitly rather than passing a null increment to assign_add. | |||
|     if (num_values == null) | |||
|         throw new System.NotImplementedException("Reduce.update_state: unsupported reduction " + _reduction); | |||
|     return tf_with(ops.control_dependencies(new[] { update_total_op }), ctl | |||
|         => count.assign_add(num_values)); | |||
| } | |||
| } | |||
| } | |||
| @@ -111,11 +111,16 @@ namespace Tensorflow.Keras.Optimizers | |||
| }); | |||
| } | |||
| Tensor[] _aggregate_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars) | |||
| // Returns the gradient of every (gradient, variable) pair, in order. | |||
| // No aggregation across pairs is performed. | |||
| public Tensor[] _aggregate_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars) | |||
|     => grads_and_vars.Select(pair => pair.Item1).ToArray(); | |||
| // Clipping hook; currently returns the gradients unchanged. | |||
| public Tensor[] _clip_gradients(Tensor[] grads) | |||
|     => grads; | |||
| protected IVariableV1 get_slot(IVariableV1 var, string slot_name) | |||
| { | |||
| var slot_dict = _slots[var.UniqueId]; | |||
| @@ -0,0 +1,83 @@ | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Linq; | |||
| using Tensorflow.Keras.Losses; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras.Utils | |||
| { | |||
| /// <summary> | |||
| /// Helpers for weighting and reducing loss tensors. | |||
| /// </summary> | |||
| public class losses_utils | |||
| { | |||
|     /// <summary> | |||
|     /// Multiplies the losses by the sample weight and applies the requested reduction. | |||
|     /// </summary> | |||
|     /// <param name="losses">Unreduced loss tensor.</param> | |||
|     /// <param name="sample_weight">Optional weight; a constant 1 is used when null.</param> | |||
|     /// <param name="reduction">Reduction name passed on to reduce_weighted_loss.</param> | |||
|     /// <param name="name">Currently unused.</param> | |||
|     public static Tensor compute_weighted_loss(Tensor losses, Tensor sample_weight = null, string reduction = null, string name = null) | |||
|     { | |||
|         // The default weight takes the dtype of the losses so that multiply | |||
|         // also works when the losses are not float32. | |||
|         if (sample_weight == null) | |||
|             sample_weight = math_ops.cast(tf.constant(1.0f), losses.dtype); | |||
|         var weighted_losses = scale_losses_by_sample_weight(losses, sample_weight); | |||
|         // Apply reduction function to the individual weighted losses. | |||
|         var loss = reduce_weighted_loss(weighted_losses, reduction); | |||
|         // TODO: convert the result back to the input type: | |||
|         // loss = math_ops.cast(loss, losses.dtype); | |||
|         return loss; | |||
|     } | |||
|     /// <summary>Element-wise product of the losses and the sample weight.</summary> | |||
|     public static Tensor scale_losses_by_sample_weight(Tensor losses, Tensor sample_weight) | |||
|     { | |||
|         // TODO: cast both operands to float32 and align their dimensions: | |||
|         // losses = math_ops.cast(losses, dtypes.float32); | |||
|         // sample_weight = math_ops.cast(sample_weight, dtypes.float32); | |||
|         // (losses, sample_weight) = squeeze_or_expand_dimensions(losses, sample_weight); | |||
|         return math_ops.multiply(losses, sample_weight); | |||
|     } | |||
|     /// <summary> | |||
|     /// Returns both tensors unchanged when the weight is rank 0; any other | |||
|     /// weight rank is not implemented and throws. | |||
|     /// </summary> | |||
|     public static (Tensor, Tensor) squeeze_or_expand_dimensions(Tensor y_pred, Tensor sample_weight) | |||
|     { | |||
|         var weights_shape = sample_weight.TensorShape; | |||
|         var weights_rank = weights_shape.ndim; | |||
|         if (weights_rank == 0) | |||
|             return (y_pred, sample_weight); | |||
|         throw new NotImplementedException(""); | |||
|     } | |||
|     /// <summary> | |||
|     /// NONE returns the input untouched. Every other reduction sums the losses, | |||
|     /// and SUM_OVER_BATCH_SIZE additionally divides by the element count. | |||
|     /// </summary> | |||
|     public static Tensor reduce_weighted_loss(Tensor weighted_losses, string reduction) | |||
|     { | |||
|         if (reduction == ReductionV2.NONE) | |||
|             return weighted_losses; | |||
|         else | |||
|         { | |||
|             var loss = math_ops.reduce_sum(weighted_losses); | |||
|             if (reduction == ReductionV2.SUM_OVER_BATCH_SIZE) | |||
|                 loss = _safe_mean(loss, _num_elements(weighted_losses)); | |||
|             return loss; | |||
|         } | |||
|     } | |||
|     // Sum of the losses divided by num_present, using div_no_nan for the division. | |||
|     static Tensor _safe_mean(Tensor losses, Tensor num_present) | |||
|     { | |||
|         var total_loss = math_ops.reduce_sum(losses); | |||
|         return math_ops.div_no_nan(total_loss, num_present, name: "value"); | |||
|     } | |||
|     // Element count of losses, cast to the dtype of losses. | |||
|     static Tensor _num_elements(Tensor losses) | |||
|     { | |||
|         return tf_with(ops.name_scope("num_elements"), scope => | |||
|         { | |||
|             return math_ops.cast(array_ops.size(losses, name: scope), dtype: losses.dtype); | |||
|         }); | |||
|     } | |||
| } | |||
| } | |||
| @@ -24,6 +24,9 @@ namespace Tensorflow | |||
| public Tensor compute_weighted_loss(Tensor losses, Tensor weights = null, string scope = null, | |||
| string loss_collection = "losses", string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) | |||
| { | |||
| if (weights == null) | |||
| weights = tf.constant(1.0f); | |||
| return tf_with(ops.name_scope(scope, default_name: "weighted_loss", (losses, weights)), delegate | |||
| { | |||
| // Save the `reduction` argument for loss normalization when distributing | |||
| @@ -521,6 +521,9 @@ namespace Tensorflow | |||
| // Forwards to shape_internal with optimize switched on. | |||
| public static Tensor shape(Tensor input, string name = null, TF_DataType out_type = TF_DataType.TF_INT32) | |||
| { | |||
|     return shape_internal(input, name, optimize: true, out_type: out_type); | |||
| } | |||
| // Same call as shape: forwards to shape_internal with optimize switched on. | |||
| public static Tensor shape_v2(Tensor input, string name = null, TF_DataType out_type = TF_DataType.TF_INT32) | |||
| { | |||
|     return shape_internal(input, name, optimize: true, out_type: out_type); | |||
| } | |||
| // Forwards to size_internal, passing the caller's optimize flag and output type. | |||
| public static Tensor size(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32) | |||
| { | |||
|     return size_internal(input, name, optimize: optimize, out_type: out_type); | |||
| } | |||
| @@ -118,10 +118,13 @@ namespace Tensorflow | |||
| /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) | |||
| /// </remarks> | |||
| public static Tensor div_no_nan(Tensor x, Tensor y, string name = null) | |||
|     => tf.Context.RunInAutoMode( | |||
|         // First delegate: builds the "DivNoNan" op through the op-def library | |||
|         // (presumably the graph-mode path; confirm against RunInAutoMode). | |||
|         () => tf.OpDefLib._apply_op_helper("DivNoNan", name: name, new { x, y }).output, | |||
|         // Second delegate: runs "DivNoNan" through the fast-path executor | |||
|         // (presumably the eager path) and takes the first result. | |||
|         () => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
|             "DivNoNan", name, | |||
|             null, | |||
|             x, y).FirstOrDefault(), | |||
|         x, y); | |||
| /// <summary> | |||
| /// Computes the mean of elements across dimensions of a tensor. | |||