@@ -126,16 +126,18 @@ namespace Tensorflow
         public Tensor[] fused_batch_norm(Tensor x,
             IVariableV1 scale,
             IVariableV1 offset,
-            Tensor mean = null,
-            Tensor variance = null,
+            IVariableV1 mean = null,
+            IVariableV1 variance = null,
             float epsilon = 0.001f,
             string data_format = "NHWC",
             bool is_training = true,
-            string name = null) => nn_impl.fused_batch_norm(x, scale, offset, mean, variance,
+            string name = null,
+            float exponential_avg_factor = 1.0f) => nn_impl.fused_batch_norm(x, scale, offset, mean, variance,
                     epsilon: epsilon,
                     data_format: data_format,
                     is_training: is_training,
-                    name: name);
+                    name: name,
+                    exponential_avg_factor: exponential_avg_factor);

         public Tensor max_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string name = null)
             => nn_ops.max_pool(value, ksize, strides, padding, data_format: data_format, name: name);
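Note on this hunk: the running statistics now come in as `IVariableV1` instead of raw tensors, and the new `exponential_avg_factor` parameter (default `1.0f`, the legacy behavior) controls how the returned running mean/variance blend in the batch statistics. A minimal usage sketch — the variable and ndarray helpers here are assumptions, only the `fused_batch_norm` signature comes from this diff:

```csharp
// Hedged sketch; tf.Variable/np helpers are assumed to behave as in other TF.NET samples.
var x = tf.constant(np.ones(2, 4, 4, 3));                       // NHWC input
IVariableV1 scale = tf.Variable(np.ones(3));                    // gamma
IVariableV1 offset = tf.Variable(np.zeros(3));                  // beta
IVariableV1 mean = tf.Variable(np.zeros(3), trainable: false);
IVariableV1 variance = tf.Variable(np.ones(3), trainable: false);

// results[0] = normalized y, results[1]/results[2] = updated running mean/variance
var results = tf.nn.fused_batch_norm(x, scale, offset, mean, variance,
    is_training: true, exponential_avg_factor: 0.99f);
```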
@@ -180,7 +180,7 @@ namespace Tensorflow
             }
         }

-        // [DebuggerStepThrough]
+        [DebuggerStepThrough]
         public static void tf_with<T>(T py, Action<T> action) where T : ITensorFlowObject
         {
             try
@@ -91,7 +91,7 @@ namespace Tensorflow.Contexts
         }

         [DebuggerStepThrough]
-        public Tensor RunInAutoMode(Func<Tensor> graphAction, Func<Tensor> eagerAction, params Tensor[] tensors)
+        public T RunInAutoMode<T>(Func<T> graphAction, Func<T> eagerAction, params Tensor[] tensors)
         {
             var shouldRunInEager = executing_eagerly()
                 && tensors.Count(x => x.IsEagerTensor) == tensors.Length;
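Making `RunInAutoMode` generic lets op wrappers return whatever the op produces (`Tensor`, `Tensor[]`, `Operation`) instead of being pinned to a single `Tensor`; the dispatch rule is unchanged: take the eager branch only when eager execution is on and every supplied tensor is already an eager tensor. The `log_softmax` hunk further down uses it like this:

```csharp
// Taken from the LogSoftmax hunk below, shown as plain code for readability.
public static Tensor log_softmax(Tensor logits, string name = null)
    => tf.Context.RunInAutoMode(
        () => tf.OpDefLib._apply_op_helper("LogSoftmax", name: name,
                  args: new { logits }).output,                        // graph branch
        () => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
                  "LogSoftmax", name, null, logits).FirstOrDefault(),  // eager branch
        logits);                                                       // mode-deciding inputs
```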
@@ -0,0 +1,15 @@
+using System;
+using Tensorflow.Gradients;
+using static Tensorflow.Binding;
+using static Tensorflow.tensorflow;
+
+namespace Tensorflow.Eager
+{
+    public partial class EagerRunner
+    {
+        public bool MustRecordGradient()
+        {
+            return HasGradientTape();
+        }
+    }
+}
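`MustRecordGradient` simply proxies the runner's `HasGradientTape()`, giving op wrappers a cheap check so they only do gradient bookkeeping while a `GradientTape` is active. A hedged sketch of the intended call site — the surrounding wrapper code is illustrative, and the `RecordGradient` name is an inference from the interface tail visible in the next hunk:

```csharp
// Hedged sketch: skip tape bookkeeping when no GradientTape is recording.
if (tf.Runner.MustRecordGradient())
{
    // RecordGradient's name/signature is assumed from the declaration tail below.
    tf.Runner.RecordGradient("Softmax", op.inputs, attrs, op.outputs);
}
```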
@@ -38,5 +38,7 @@ namespace Tensorflow.Eager
             Tensor[] inputs,
             object[] attrs,
             Tensor[] results);
+
+        bool MustRecordGradient();
     }
 }
@@ -52,7 +52,7 @@ namespace Tensorflow.Framework
         {
             var pred_value = tensor_util.constant_value(pred);
             if (pred_value is null)
-                return null;
+                return pred.eval(new Session(pred.graph));

             return pred_value;
         }
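Context for this hunk: `tensor_util.constant_value` returns `null` when the predicate can't be folded at graph-construction time (for instance, when it depends on a placeholder). Instead of giving up, the code now evaluates the predicate in a throwaway session so a concrete branch can still be chosen. A hedged illustration; the placeholder call follows the usual TF.NET graph API:

```csharp
// Hedged illustration of the fallback path.
var pred = tf.placeholder(tf.@bool);
var maybe = tensor_util.constant_value(pred);      // null: value unknown statically
var concrete = pred.eval(new Session(pred.graph)); // new fallback: run it to find out
```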
@@ -269,21 +269,29 @@ namespace Tensorflow.Operations
         }

         public static Tensor[] fused_batch_norm_grad_v3(FusedBatchNormParams @params)
-        {
-            var op = tf.OpDefLib._apply_op_helper("FusedBatchNormGradV3", name: @params.Name, args: new
-            {
-                y_backprop = @params.YBackprop,
-                x = @params.X,
-                scale = @params.Scale,
-                reserve_space_1 = @params.ReserveSpace1,
-                reserve_space_2 = @params.ReserveSpace2,
-                reserve_space_3 = @params.ReserveSpace3,
-                epsilon = @params.Epsilon,
-                data_format = @params.DataFormat,
-                is_training = @params.IsTraining
-            });
-            return op.outputs;
-        }
+            => tf.Context.RunInAutoMode(()
+                => tf.OpDefLib._apply_op_helper("FusedBatchNormGradV3", name: @params.Name,
+                    args: new
+                    {
+                        y_backprop = @params.YBackprop,
+                        x = @params.X,
+                        scale = @params.Scale,
+                        reserve_space_1 = @params.ReserveSpace1,
+                        reserve_space_2 = @params.ReserveSpace2,
+                        reserve_space_3 = @params.ReserveSpace3,
+                        epsilon = @params.Epsilon,
+                        data_format = @params.DataFormat,
+                        is_training = @params.IsTraining
+                    }).outputs, ()
+                => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
+                    "FusedBatchNormGradV3", @params.Name,
+                    null,
+                    @params.YBackprop, @params.X, @params.Scale,
+                    @params.ReserveSpace1, @params.ReserveSpace2, @params.ReserveSpace3,
+                    "epsilon", @params.Epsilon,
+                    "data_format", @params.DataFormat,
+                    "is_training", @params.IsTraining),
+                @params.YBackprop);

         public static Tensor[] fused_batch_norm(Tensor x,
             Tensor scale,
@@ -313,9 +321,10 @@ namespace Tensorflow.Operations
         public static Tensor[] fused_batch_norm_v3(Tensor x,
             Tensor scale,
             Tensor offset,
-            Tensor mean,
-            Tensor variance,
+            IVariableV1 mean,
+            IVariableV1 variance,
             float epsilon = 0.0001f,
+            float exponential_avg_factor = 1.0f,
             string data_format = "NHWC",
             bool is_training = true,
             string name = null)
@@ -328,9 +337,10 @@ namespace Tensorflow.Operations
                 x,
                 scale,
                 offset,
-                mean,
-                variance,
+                mean.AsTensor(),
+                variance.AsTensor(),
                 "epsilon", epsilon,
+                "exponential_avg_factor", exponential_avg_factor,
                 "data_format", data_format,
                 "is_training", is_training);
@@ -378,14 +388,14 @@ namespace Tensorflow.Operations
         }

         public static Tensor log_softmax(Tensor logits, string name = null)
-        {
-            var _op = tf.OpDefLib._apply_op_helper("LogSoftmax", name: name, args: new
-            {
-                logits
-            });
-
-            return _op.output;
-        }
+            => tf.Context.RunInAutoMode(()
+                => tf.OpDefLib._apply_op_helper("LogSoftmax", name: name,
+                    args: new { logits }).output, ()
+                => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
+                    "LogSoftmax", name,
+                    null,
+                    logits).FirstOrDefault(),
+                logits);

         /// <summary>
         /// Says whether the targets are in the top `K` predictions.
@@ -560,6 +570,16 @@ namespace Tensorflow.Operations
         /// <returns></returns>
         public static (Tensor, Tensor) softmax_cross_entropy_with_logits(Tensor features, Tensor labels, string name = null)
         {
+            if (tf.executing_eagerly())
+            {
+                var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
+                    "SoftmaxCrossEntropyWithLogits", name,
+                    null,
+                    features, labels);
+                return (results[0], results[1]);
+            }
+
             var _op = tf.OpDefLib._apply_op_helper("SoftmaxCrossEntropyWithLogits", name: name, args: new
             {
                 features,
@@ -68,7 +68,7 @@ namespace Tensorflow
                 null,
                 resource, value);

-            return null;
+            return results.Length == 0 ? null : results[0];
         }

         var _op = tf.OpDefLib._apply_op_helper("AssignVariableOp", name, new { resource, value });
@@ -99,20 +99,21 @@ namespace Tensorflow
         public static Tensor[] fused_batch_norm(Tensor x,
             IVariableV1 scale,
             IVariableV1 offset,
-            Tensor mean,
-            Tensor variance,
+            IVariableV1 mean,
+            IVariableV1 variance,
             float epsilon = 0.001f,
             string data_format = "NHWC",
             bool is_training = true,
-            string name = null)
+            string name = null,
+            float exponential_avg_factor = 1.0f)
         {
             x = ops.convert_to_tensor(x, name: "input");
             var scale_tensor = ops.convert_to_tensor(scale, name: "scale");
             var offset_tensor = ops.convert_to_tensor(offset, name: "offset");
-            if (mean == null)
+            /*if (mean == null)
                 mean = constant_op.constant(new float[0]);
             if (variance == null)
-                variance = constant_op.constant(new float[0]);
+                variance = constant_op.constant(new float[0]);*/

             var min_epsilon = 1.001e-5f;
             epsilon = epsilon > min_epsilon ? epsilon : min_epsilon;
@@ -122,15 +123,16 @@ namespace Tensorflow
                 mean,
                 variance,
                 epsilon,
-                data_format,
-                is_training,
-                name);
+                exponential_avg_factor: exponential_avg_factor,
+                data_format: data_format,
+                is_training: is_training,
+                name: name);

             var y = results[0];
-            var batch_mean = results[1];
-            var batch_var = results[2];
+            var running_mean = results[1];
+            var running_var = results[2];

-            return new[] { y, batch_mean, batch_var };
+            return new[] { y, running_mean, running_var };
         }

         /// <summary>
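The rename from `batch_mean`/`batch_var` to `running_mean`/`running_var` matches what `FusedBatchNormV3` actually returns once `exponential_avg_factor` is in play: the outputs are moving averages, not the raw batch statistics. Per the TensorFlow op documentation the update is, sketched as scalar C# for clarity:

```csharp
// Reference semantics of the running-statistics update (per the TensorFlow
// FusedBatchNormV3 docs); a scalar stand-in for the per-channel vectors.
static float UpdateRunningStat(float runningStat, float batchStat, float factor)
    => (1f - factor) * runningStat + factor * batchStat;

// With the default factor of 1.0f the output is just batchStat,
// which is why the old names were batch_mean/batch_var.
```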
@@ -255,7 +255,7 @@ namespace Tensorflow
             // The output cost shape should be the input minus axis.
             var output_shape = array_ops.slice(input_shape,
                 new int[] { 0 },
-                new Tensor[] { constant_op.constant(0) },
+                new Tensor[] { math_ops.subtract(input_rank, 1) });

             cost = array_ops.reshape(cost, output_shape);
@@ -274,36 +274,38 @@ namespace Tensorflow
             var rank = array_ops.rank(logits);
             var last_dim_size = array_ops.slice(array_ops.shape(logits),
                 new[] { math_ops.subtract(rank, 1) },
-                new[] { 1 });
+                new[] { constant_op.constant(1) });
             var ops = array_ops.concat(new[] { new[] { -1 }, (object)last_dim_size }, 0);
             var output = array_ops.reshape(logits, ops);

             // Set output shape if known.
             // if not context.executing_eagerly():
-            var shape = logits.TensorShape;
-            if (shape != null && shape.ndim > 0)
+            if (!tf.Context.executing_eagerly())
             {
-                var product = 1;
-                var product_valid = true;
-                foreach (var d in shape.dims.Take(shape.ndim - 1))
+                var shape = logits.TensorShape;
+                if (shape != null && shape.ndim > 0)
                 {
-                    if (d == -1)
+                    var product = 1;
+                    var product_valid = true;
+                    foreach (var d in shape.dims.Take(shape.ndim - 1))
                     {
-                        product_valid = false;
-                        break;
+                        if (d == -1)
+                        {
+                            product_valid = false;
+                            break;
+                        }
+                        else
+                        {
+                            product *= d;
+                        }
                     }
-                    else
+                    if (product_valid)
                     {
-                        product *= d;
+                        var output_shape = new[] { product };
+                        throw new NotImplementedException("_flatten_outer_dims product_valid");
                     }
                 }
-                if (product_valid)
-                {
-                    var output_shape = new[] { product };
-                    throw new NotImplementedException("_flatten_outer_dims product_valid");
-                }
             }

             return output;
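Why the new `executing_eagerly` guard matters: in eager mode every dimension is concrete, so the shape product is always valid and the `NotImplementedException` branch would fire on every call. The static-shape pass is only meaningful for graph tensors, which can still carry unknown dimensions:

```csharp
// Hedged illustration: only graph-mode tensors have partially known shapes.
var logits = tf.placeholder(tf.float32, new TensorShape(-1, 10));
// dims = { -1, 10 } => product_valid stays false, so no exception is thrown.
```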
@@ -5,7 +5,7 @@
     <AssemblyName>TensorFlow.NET</AssemblyName>
     <RootNamespace>Tensorflow</RootNamespace>
     <TargetTensorFlow>2.2.0</TargetTensorFlow>
-    <Version>0.30.0</Version>
+    <Version>0.31.0</Version>
     <LangVersion>8.0</LangVersion>
     <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
     <Company>SciSharp STACK</Company>
| @@ -19,7 +19,7 @@ | |||
| <Description>Google's TensorFlow full binding in .NET Standard. | |||
| Building, training and infering deep learning models. | |||
| https://tensorflownet.readthedocs.io</Description> | |||
| <AssemblyVersion>0.30.0.0</AssemblyVersion> | |||
| <AssemblyVersion>0.31.0.0</AssemblyVersion> | |||
| <PackageReleaseNotes>tf.net 0.20.x and above are based on tensorflow native 2.x. | |||
| * Eager Mode is added finally. | |||
| @@ -30,7 +30,7 @@ https://tensorflownet.readthedocs.io</Description> | |||
| TensorFlow .NET v0.30 is focused on making more Keras API work including: | |||
| * tf.keras.datasets | |||
| * Building keras model in subclass, functional and sequential api</PackageReleaseNotes> | |||
| <FileVersion>0.30.0.0</FileVersion> | |||
| <FileVersion>0.31.0.0</FileVersion> | |||
| <PackageLicenseFile>LICENSE</PackageLicenseFile> | |||
| <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | |||
| <SignAssembly>true</SignAssembly> | |||
@@ -23,6 +23,7 @@ namespace Tensorflow
         public Graph graph => items.First().graph;
         public bool IsEagerTensor => items.First().IsEagerTensor;
         public bool IsList { get; set; }
+        public int Length => items.Length;

         public Tensor this[int index]
         {
@@ -51,12 +51,13 @@ namespace Tensorflow
         private static NDArray _ConstantValue(Tensor tensor, bool partial)
         {
-            if (tensor.op.type == "Const")
+            switch (tensor.op.type)
             {
-                return MakeNdarray(tensor.op.get_attr("value") as TensorProto);
+                case "Const":
+                    return MakeNdarray(tensor.op.get_attr("value") as TensorProto);
+                default:
+                    return null;
             }
-
-            return null;
         }

         public static NDArray MakeNdarray(TensorProto tensor)
| @@ -83,8 +83,11 @@ namespace Tensorflow | |||
| var assign_op = gen_resource_variable_ops.assign_variable_op( | |||
| handle, value_tensor, name: name); | |||
| if (read_value) | |||
| { | |||
| return gen_resource_variable_ops.read_variable_op(handle, dtype); | |||
| // return _lazy_read(assign_op, value_tensor); | |||
| // var variable = _lazy_read(assign_op, value_tensor); | |||
| // return variable; | |||
| } | |||
| return assign_op; | |||
| } | |||
@@ -111,7 +114,7 @@ namespace Tensorflow
             return result;
         }

-        BaseResourceVariable _lazy_read(Operation op, Tensor value)
+        IVariableV1 _lazy_read(Operation op, Tensor value)
         {
             variable_accessed(this);
             return new _UnreadVariable(handle, _dtype, _shape, _in_graph_mode, _unique_id);
@@ -6,7 +6,7 @@ namespace Tensorflow
     /// Represents a future for a read of a variable.
     /// Pretends to be the tensor if anyone looks.
     /// </summary>
-    public class _UnreadVariable : BaseResourceVariable
+    public class _UnreadVariable : BaseResourceVariable, IVariableV1
     {
         public override string Name => _in_graph_mode ? _parent_op.name : "UnreadVariable";
@@ -85,10 +85,12 @@ namespace Tensorflow
         public static Tensor assign_sub(IVariableV1 @ref,
             Tensor value,
             bool use_locking = false,
-            string name = null) => gen_state_ops.assign_sub(@ref,
-                value,
-                use_locking: use_locking,
-                name: name);
+            string name = null) => @ref.dtype.is_ref_dtype() ?
+                gen_state_ops.assign_sub(@ref,
+                    value,
+                    use_locking: use_locking,
+                    name: name) :
+                @ref.assign(value, name: name) as Tensor;

         //"""Update 'ref' by adding 'value' to it.
         //
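`assign_sub` now dispatches on the variable's dtype: legacy ref-typed graph variables keep going through the `AssignSub` op, while resource variables are routed to the variable's own `assign` method. A hedged sketch of a call that exercises the resource-variable branch; the variable construction is illustrative:

```csharp
// Hedged sketch of the dispatch added above.
IVariableV1 v = tf.Variable(10.0f);                  // resource variable (non-ref dtype)
var t = state_ops.assign_sub(v, tf.constant(3.0f));  // falls through to v.assign(...) here
```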
@@ -335,6 +335,7 @@ namespace Tensorflow.Keras.Engine
             var layer_inputs = node.MapArguments(tensor_dict);

+            // Console.WriteLine($"{node.Layer}: {node.Layer.Name}");
             var outputs = node.Layer.Apply(layer_inputs, is_training: training);

             // Update tensor_dict for next input
@@ -207,11 +207,11 @@ namespace Tensorflow.Keras.Engine
             }));
         }

-        protected virtual void add_update(Tensor[] updates, bool inputs = false)
+        /*protected virtual void add_update(Tensor[] updates, bool inputs = false)
         {
             var updates_op = updates.Select(x => x.op).ToArray();
             this.updates.AddRange(updates_op);
-        }
+        }*/

         // Determine layer name (non-unique).
         protected virtual void _init_set_name(string name, bool zero_based = true)
@@ -60,7 +60,19 @@ namespace Tensorflow.Keras.Engine
             Func<Tensor, Tensor, Tensor> metric_obj = null;
             if (metric == "accuracy" || metric == "acc")
             {
-                metric_obj = keras.metrics.sparse_categorical_accuracy;
+                var y_t_rank = y_t.rank;
+                var y_p_rank = y_p.rank;
+                var y_t_last_dim = y_t.shape[^1];
+                var y_p_last_dim = y_p.shape[^1];
+
+                bool is_binary = y_p_last_dim == 1;
+                bool is_sparse_categorical = (y_t_rank < y_p_rank || y_t_last_dim == 1) && y_p_last_dim > 1;
+
+                if (is_sparse_categorical)
+                    metric_obj = keras.metrics.sparse_categorical_accuracy;
+                else
+                    metric_obj = keras.metrics.categorical_accuracy;
+
                 return new MeanMetricWrapper(metric_obj, metric);
             }
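`"accuracy"`/`"acc"` is now resolved by comparing label and prediction shapes, mirroring Keras: integer labels (lower rank than the predictions, or a trailing dimension of 1, against multi-class predictions) select `sparse_categorical_accuracy`; one-hot labels select `categorical_accuracy`. A hedged example of each case:

```csharp
// Hedged examples of how the rank/last-dim test resolves "accuracy".
var y_pred = tf.constant(new float[,] { { 0.1f, 0.8f, 0.1f }, { 0.7f, 0.2f, 0.1f } });

var y_true_sparse = tf.constant(new float[,] { { 1 }, { 0 } });   // trailing dim 1
// => is_sparse_categorical == true  => sparse_categorical_accuracy

var y_true_onehot = tf.constant(new float[,] { { 0, 1, 0 }, { 1, 0, 0 } });
// => is_sparse_categorical == false => categorical_accuracy
```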
@@ -53,7 +53,6 @@ namespace Tensorflow.Keras.Engine
             stop_training = false;
             _train_counter.assign(0);
-            bool first_step = true;
             Console.WriteLine($"Training...");

             foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
             {
@@ -65,11 +64,6 @@
                 {
                     // callbacks.on_train_batch_begin(step)
                     results = step_function(iterator);
-                    if (first_step)
-                    {
-                        Console.WriteLine($"epoch: {epoch}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
-                        first_step = false;
-                    }
                 }
                 Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
             }
@@ -0,0 +1,30 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace Tensorflow.Keras.Losses
+{
+    public class CategoricalCrossentropy : LossFunctionWrapper, ILossFunc
+    {
+        float label_smoothing;
+
+        public CategoricalCrossentropy(bool from_logits = false,
+            float label_smoothing = 0,
+            string reduction = ReductionV2.AUTO,
+            string name = "categorical_crossentropy") :
+            base(reduction: reduction,
+                name: name,
+                from_logits: from_logits)
+        {
+            this.label_smoothing = label_smoothing;
+        }
+
+        public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1)
+        {
+            // Try to adjust the shape so that rank of labels = rank of logits - 1.
+            return keras.backend.categorical_crossentropy(y_true, y_pred, from_logits: from_logits);
+        }
+    }
+}
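Usage sketch for the new loss class; the values are illustrative, and `Call` comes from the `Loss` base class shown in the next hunk:

```csharp
// Hedged usage sketch for CategoricalCrossentropy with one-hot labels.
var y_true = tf.constant(new float[,] { { 0, 1, 0 }, { 0, 0, 1 } });
var y_pred = tf.constant(new float[,] { { 0.05f, 0.95f, 0.00f }, { 0.10f, 0.80f, 0.10f } });

var cce = keras.losses.CategoricalCrossentropy();
var loss = cce.Call(y_true, y_pred);   // mean cross-entropy over the batch
```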
@@ -11,14 +11,18 @@ namespace Tensorflow.Keras.Losses
         protected string reduction;
         protected string name;
         bool _allow_sum_over_batch_size;
+        protected bool from_logits = false;
         string _name_scope;

         public string Reduction => reduction;
-        public Loss(string reduction = ReductionV2.AUTO, string name = null)
+        public Loss(string reduction = ReductionV2.AUTO,
+            string name = null,
+            bool from_logits = false)
         {
             this.reduction = reduction;
             this.name = name;
+            this.from_logits = from_logits;
             _allow_sum_over_batch_size = false;
         }
@@ -29,7 +33,7 @@ namespace Tensorflow.Keras.Losses
         public Tensor Call(Tensor y_true, Tensor y_pred)
         {
-            var losses = Apply(y_true, y_pred);
+            var losses = Apply(y_true, y_pred, from_logits: from_logits);
             return losses_utils.compute_weighted_loss(losses, reduction: ReductionV2.SUM_OVER_BATCH_SIZE);
         }
@@ -1,11 +1,15 @@
-namespace Tensorflow.Keras.Losses
+using Tensorflow.Keras.Utils;
+
+namespace Tensorflow.Keras.Losses
 {
     public class LossFunctionWrapper : Loss
     {
         public LossFunctionWrapper(string reduction = ReductionV2.AUTO,
-            string name = null)
+            string name = null,
+            bool from_logits = false)
             : base(reduction: reduction,
-                  name: name)
+                  name: name,
+                  from_logits: from_logits)
         {
         }
     }
@@ -4,5 +4,8 @@
     {
         public ILossFunc SparseCategoricalCrossentropy(bool from_logits = false)
             => new SparseCategoricalCrossentropy(from_logits: from_logits);
+
+        public ILossFunc CategoricalCrossentropy(bool from_logits = false)
+            => new CategoricalCrossentropy(from_logits: from_logits);
     }
 }
@@ -2,6 +2,12 @@
 {
     public class MetricsApi
     {
+        public Tensor categorical_accuracy(Tensor y_true, Tensor y_pred)
+        {
+            var eql = math_ops.equal(math_ops.argmax(y_true, -1), math_ops.argmax(y_pred, -1));
+            return math_ops.cast(eql, TF_DataType.TF_FLOAT);
+        }
+
         /// <summary>
         /// Calculates how often predictions matches integer labels.
         /// </summary>
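The new `categorical_accuracy` returns a per-sample 0/1 float tensor (argmax of the one-hot labels compared with argmax of the predictions); `MeanMetricWrapper` then averages it into the reported metric. A hedged illustration:

```csharp
// Hedged illustration: per-sample argmax match.
var y_true = tf.constant(new float[,] { { 0f, 1f }, { 1f, 0f } });
var y_pred = tf.constant(new float[,] { { 0.3f, 0.7f }, { 0.2f, 0.8f } });
var acc = keras.metrics.categorical_accuracy(y_true, y_pred);   // [1f, 0f]
```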