| @@ -14,6 +14,7 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using Tensorflow.NumPy; | |||||
| using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -42,7 +43,6 @@ namespace Tensorflow | |||||
| public Tensor multiply(Tensor x, Tensor y, string name = null) | public Tensor multiply(Tensor x, Tensor y, string name = null) | ||||
| => math_ops.multiply(x, y, name: name); | => math_ops.multiply(x, y, name: name); | ||||
| public Tensor divide_no_nan(Tensor a, Tensor b, string name = null) | public Tensor divide_no_nan(Tensor a, Tensor b, string name = null) | ||||
| => math_ops.div_no_nan(a, b); | => math_ops.div_no_nan(a, b); | ||||
| @@ -452,7 +452,18 @@ namespace Tensorflow | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public Tensor multiply<Tx, Ty>(Tx x, Ty y, string name = null) | public Tensor multiply<Tx, Ty>(Tx x, Ty y, string name = null) | ||||
| => gen_math_ops.mul(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name); | => gen_math_ops.mul(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name); | ||||
| /// <summary> | |||||
| /// return scalar product | |||||
| /// </summary> | |||||
| /// <typeparam name="Tx"></typeparam> | |||||
| /// <typeparam name="Ty"></typeparam> | |||||
| /// <param name="x"></param> | |||||
| /// <param name="y"></param> | |||||
| /// <param name="axes"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public Tensor dot_prod<Tx, Ty>(Tx x, Ty y, NDArray axes, string name = null) | |||||
| => math_ops.tensordot(convert_to_tensor(x), convert_to_tensor(y), axes, name: name); | |||||
| public Tensor negative(Tensor x, string name = null) | public Tensor negative(Tensor x, string name = null) | ||||
| => gen_math_ops.neg(x, name); | => gen_math_ops.neg(x, name); | ||||
| @@ -486,7 +486,28 @@ namespace Tensorflow | |||||
| throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
| } | } | ||||
| } | } | ||||
| public static NDArray GetFlattenArray(NDArray x) | |||||
| { | |||||
| switch (x.GetDataType()) | |||||
| { | |||||
| case TF_DataType.TF_FLOAT: | |||||
| x = x.ToArray<float>(); | |||||
| break; | |||||
| case TF_DataType.TF_DOUBLE: | |||||
| x = x.ToArray<double>(); | |||||
| break; | |||||
| case TF_DataType.TF_INT16: | |||||
| case TF_DataType.TF_INT32: | |||||
| x = x.ToArray<int>(); | |||||
| break; | |||||
| case TF_DataType.TF_INT64: | |||||
| x = x.ToArray<long>(); | |||||
| break; | |||||
| default: | |||||
| break; | |||||
| } | |||||
| return x; | |||||
| } | |||||
| public static TF_DataType GetDataType(this object data) | public static TF_DataType GetDataType(this object data) | ||||
| { | { | ||||
| var type = data.GetType(); | var type = data.GetType(); | ||||
| @@ -60,7 +60,7 @@ public interface IModel : ILayer | |||||
| bool skip_mismatch = false, | bool skip_mismatch = false, | ||||
| object options = null); | object options = null); | ||||
| Dictionary<string, float> evaluate(NDArray x, NDArray y, | |||||
| Dictionary<string, float> evaluate(Tensor x, Tensor y, | |||||
| int batch_size = -1, | int batch_size = -1, | ||||
| int verbose = 1, | int verbose = 1, | ||||
| int steps = -1, | int steps = -1, | ||||
| @@ -49,9 +49,30 @@ namespace Tensorflow.NumPy | |||||
| [AutoNumPy] | [AutoNumPy] | ||||
| public static NDArray prod<T>(params T[] array) where T : unmanaged | public static NDArray prod<T>(params T[] array) where T : unmanaged | ||||
| => new NDArray(tf.reduce_prod(new NDArray(array))); | => new NDArray(tf.reduce_prod(new NDArray(array))); | ||||
| [AutoNumPy] | |||||
| public static NDArray dot(NDArray x1, NDArray x2, NDArray? axes = null, string? name = null) | |||||
| { | |||||
| //if axes mentioned | |||||
| if (axes != null) | |||||
| { | |||||
| return new NDArray(tf.dot_prod(x1, x2, axes, name)); | |||||
| } | |||||
| if (x1.shape.ndim > 1) | |||||
| { | |||||
| x1 = GetFlattenArray(x1); | |||||
| } | |||||
| if (x2.shape.ndim > 1) | |||||
| { | |||||
| x2 = GetFlattenArray(x2); | |||||
| } | |||||
| //if axes not mentioned, default 0,0 | |||||
| return new NDArray(tf.dot_prod(x1, x2, axes: new int[] { 0, 0 }, name)); | |||||
| } | |||||
| [AutoNumPy] | [AutoNumPy] | ||||
| public static NDArray power(NDArray x, NDArray y) => new NDArray(tf.pow(x, y)); | public static NDArray power(NDArray x, NDArray y) => new NDArray(tf.pow(x, y)); | ||||
| [AutoNumPy] | |||||
| public static NDArray square(NDArray x) => new NDArray(tf.square(x)); | |||||
| [AutoNumPy] | [AutoNumPy] | ||||
| public static NDArray sin(NDArray x) => new NDArray(math_ops.sin(x)); | public static NDArray sin(NDArray x) => new NDArray(math_ops.sin(x)); | ||||
| @@ -226,62 +226,62 @@ namespace Tensorflow | |||||
| } | } | ||||
| #region Explicit Conversions | #region Explicit Conversions | ||||
| public unsafe static explicit operator bool(Tensors tensor) | |||||
| public static explicit operator bool(Tensors tensor) | |||||
| { | { | ||||
| return (bool)tensor.Single; | return (bool)tensor.Single; | ||||
| } | } | ||||
| public unsafe static explicit operator sbyte(Tensors tensor) | |||||
| public static explicit operator sbyte(Tensors tensor) | |||||
| { | { | ||||
| return (sbyte)tensor.Single; | return (sbyte)tensor.Single; | ||||
| } | } | ||||
| public unsafe static explicit operator byte(Tensors tensor) | |||||
| public static explicit operator byte(Tensors tensor) | |||||
| { | { | ||||
| return (byte)tensor.Single; | return (byte)tensor.Single; | ||||
| } | } | ||||
| public unsafe static explicit operator ushort(Tensors tensor) | |||||
| public static explicit operator ushort(Tensors tensor) | |||||
| { | { | ||||
| return (ushort)tensor.Single; | return (ushort)tensor.Single; | ||||
| } | } | ||||
| public unsafe static explicit operator short(Tensors tensor) | |||||
| public static explicit operator short(Tensors tensor) | |||||
| { | { | ||||
| return (short)tensor.Single; | return (short)tensor.Single; | ||||
| } | } | ||||
| public unsafe static explicit operator int(Tensors tensor) | |||||
| public static explicit operator int(Tensors tensor) | |||||
| { | { | ||||
| return (int)tensor.Single; | return (int)tensor.Single; | ||||
| } | } | ||||
| public unsafe static explicit operator uint(Tensors tensor) | |||||
| public static explicit operator uint(Tensors tensor) | |||||
| { | { | ||||
| return (uint)tensor.Single; | return (uint)tensor.Single; | ||||
| } | } | ||||
| public unsafe static explicit operator long(Tensors tensor) | |||||
| public static explicit operator long(Tensors tensor) | |||||
| { | { | ||||
| return (long)tensor.Single; | return (long)tensor.Single; | ||||
| } | } | ||||
| public unsafe static explicit operator ulong(Tensors tensor) | |||||
| public static explicit operator ulong(Tensors tensor) | |||||
| { | { | ||||
| return (ulong)tensor.Single; | return (ulong)tensor.Single; | ||||
| } | } | ||||
| public unsafe static explicit operator float(Tensors tensor) | |||||
| public static explicit operator float(Tensors tensor) | |||||
| { | { | ||||
| return (byte)tensor.Single; | return (byte)tensor.Single; | ||||
| } | } | ||||
| public unsafe static explicit operator double(Tensors tensor) | |||||
| public static explicit operator double(Tensors tensor) | |||||
| { | { | ||||
| return (double)tensor.Single; | return (double)tensor.Single; | ||||
| } | } | ||||
| public unsafe static explicit operator string(Tensors tensor) | |||||
| public static explicit operator string(Tensors tensor) | |||||
| { | { | ||||
| return (string)tensor.Single; | return (string)tensor.Single; | ||||
| } | } | ||||
| @@ -1,14 +1,14 @@ | |||||
| using Tensorflow.NumPy; | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow; | |||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Callbacks; | |||||
| using Tensorflow.Keras.Engine.DataAdapters; | using Tensorflow.Keras.Engine.DataAdapters; | ||||
| using static Tensorflow.Binding; | |||||
| using Tensorflow.Keras.Layers; | using Tensorflow.Keras.Layers; | ||||
| using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
| using Tensorflow; | |||||
| using Tensorflow.Keras.Callbacks; | |||||
| using Tensorflow.NumPy; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
| { | { | ||||
| @@ -27,7 +27,7 @@ namespace Tensorflow.Keras.Engine | |||||
| /// <param name="use_multiprocessing"></param> | /// <param name="use_multiprocessing"></param> | ||||
| /// <param name="return_dict"></param> | /// <param name="return_dict"></param> | ||||
| /// <param name="is_val"></param> | /// <param name="is_val"></param> | ||||
| public Dictionary<string, float> evaluate(NDArray x, NDArray y, | |||||
| public Dictionary<string, float> evaluate(Tensor x, Tensor y, | |||||
| int batch_size = -1, | int batch_size = -1, | ||||
| int verbose = 1, | int verbose = 1, | ||||
| int steps = -1, | int steps = -1, | ||||
| @@ -64,34 +64,11 @@ namespace Tensorflow.Keras.Engine | |||||
| Verbose = verbose, | Verbose = verbose, | ||||
| Steps = data_handler.Inferredsteps | Steps = data_handler.Inferredsteps | ||||
| }); | }); | ||||
| callbacks.on_test_begin(); | |||||
| //Dictionary<string, float>? logs = null; | |||||
| var logs = new Dictionary<string, float>(); | |||||
| foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | |||||
| { | |||||
| reset_metrics(); | |||||
| // data_handler.catch_stop_iteration(); | |||||
| foreach (var step in data_handler.steps()) | |||||
| { | |||||
| callbacks.on_test_batch_begin(step); | |||||
| logs = test_function(data_handler, iterator); | |||||
| var end_step = step + data_handler.StepIncrement; | |||||
| if (is_val == false) | |||||
| callbacks.on_test_batch_end(end_step, logs); | |||||
| } | |||||
| } | |||||
| var results = new Dictionary<string, float>(); | |||||
| foreach (var log in logs) | |||||
| { | |||||
| results[log.Key] = log.Value; | |||||
| } | |||||
| return results; | |||||
| return evaluate(data_handler, callbacks, is_val, test_function); | |||||
| } | } | ||||
| public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, NDArray y, int verbose = 1, bool is_val = false) | |||||
| public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, Tensor y, int verbose = 1, bool is_val = false) | |||||
| { | { | ||||
| var data_handler = new DataHandler(new DataHandlerArgs | var data_handler = new DataHandler(new DataHandlerArgs | ||||
| { | { | ||||
| @@ -107,34 +84,10 @@ namespace Tensorflow.Keras.Engine | |||||
| Verbose = verbose, | Verbose = verbose, | ||||
| Steps = data_handler.Inferredsteps | Steps = data_handler.Inferredsteps | ||||
| }); | }); | ||||
| callbacks.on_test_begin(); | |||||
| Dictionary<string, float> logs = null; | |||||
| foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | |||||
| { | |||||
| reset_metrics(); | |||||
| callbacks.on_epoch_begin(epoch); | |||||
| // data_handler.catch_stop_iteration(); | |||||
| foreach (var step in data_handler.steps()) | |||||
| { | |||||
| callbacks.on_test_batch_begin(step); | |||||
| logs = test_step_multi_inputs_function(data_handler, iterator); | |||||
| var end_step = step + data_handler.StepIncrement; | |||||
| if (is_val == false) | |||||
| callbacks.on_test_batch_end(end_step, logs); | |||||
| } | |||||
| } | |||||
| var results = new Dictionary<string, float>(); | |||||
| foreach (var log in logs) | |||||
| { | |||||
| results[log.Key] = log.Value; | |||||
| } | |||||
| return results; | |||||
| return evaluate(data_handler, callbacks, is_val, test_step_multi_inputs_function); | |||||
| } | } | ||||
| public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is_val = false) | public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is_val = false) | ||||
| { | { | ||||
| var data_handler = new DataHandler(new DataHandlerArgs | var data_handler = new DataHandler(new DataHandlerArgs | ||||
| @@ -150,9 +103,24 @@ namespace Tensorflow.Keras.Engine | |||||
| Verbose = verbose, | Verbose = verbose, | ||||
| Steps = data_handler.Inferredsteps | Steps = data_handler.Inferredsteps | ||||
| }); | }); | ||||
| return evaluate(data_handler, callbacks, is_val, test_function); | |||||
| } | |||||
| /// <summary> | |||||
| /// Internal bare implementation of evaluate function. | |||||
| /// </summary> | |||||
| /// <param name="data_handler">Interations handling objects</param> | |||||
| /// <param name="callbacks"></param> | |||||
| /// <param name="test_func">The function to be called on each batch of data.</param> | |||||
| /// <param name="is_val">Whether it is validation or test.</param> | |||||
| /// <returns></returns> | |||||
| Dictionary<string, float> evaluate(DataHandler data_handler, CallbackList callbacks, bool is_val, Func<DataHandler, Tensor[], Dictionary<string, float>> test_func) | |||||
| { | |||||
| callbacks.on_test_begin(); | callbacks.on_test_begin(); | ||||
| Dictionary<string, float> logs = null; | |||||
| var results = new Dictionary<string, float>(); | |||||
| var logs = results; | |||||
| foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | ||||
| { | { | ||||
| reset_metrics(); | reset_metrics(); | ||||
| @@ -162,45 +130,47 @@ namespace Tensorflow.Keras.Engine | |||||
| foreach (var step in data_handler.steps()) | foreach (var step in data_handler.steps()) | ||||
| { | { | ||||
| callbacks.on_test_batch_begin(step); | callbacks.on_test_batch_begin(step); | ||||
| logs = test_function(data_handler, iterator); | |||||
| logs = test_func(data_handler, iterator.next()); | |||||
| tf_with(ops.control_dependencies(Array.Empty<object>()), ctl => _train_counter.assign_add(1)); | |||||
| var end_step = step + data_handler.StepIncrement; | var end_step = step + data_handler.StepIncrement; | ||||
| if (is_val == false) | |||||
| if (!is_val) | |||||
| callbacks.on_test_batch_end(end_step, logs); | callbacks.on_test_batch_end(end_step, logs); | ||||
| } | } | ||||
| if (!is_val) | |||||
| callbacks.on_epoch_end(epoch, logs); | |||||
| } | } | ||||
| var results = new Dictionary<string, float>(); | |||||
| foreach (var log in logs) | foreach (var log in logs) | ||||
| { | { | ||||
| results[log.Key] = log.Value; | results[log.Key] = log.Value; | ||||
| } | } | ||||
| return results; | return results; | ||||
| } | } | ||||
| Dictionary<string, float> test_function(DataHandler data_handler, OwnedIterator iterator) | |||||
| Dictionary<string, float> test_function(DataHandler data_handler, Tensor[] data) | |||||
| { | { | ||||
| var data = iterator.next(); | |||||
| var outputs = test_step(data_handler, data[0], data[1]); | |||||
| tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1)); | |||||
| var (x, y) = data_handler.DataAdapter.Expand1d(data[0], data[1]); | |||||
| var y_pred = Apply(x, training: false); | |||||
| var loss = compiled_loss.Call(y, y_pred); | |||||
| compiled_metrics.update_state(y, y_pred); | |||||
| var outputs = metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Name, x => (float)x.Item2); | |||||
| return outputs; | return outputs; | ||||
| } | } | ||||
| Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handler, OwnedIterator iterator) | |||||
| Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handler, Tensor[] data) | |||||
| { | { | ||||
| var data = iterator.next(); | |||||
| var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount; | var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount; | ||||
| var outputs = train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())); | var outputs = train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())); | ||||
| tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); | tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); | ||||
| return outputs; | return outputs; | ||||
| } | } | ||||
| Dictionary<string, float> test_step(DataHandler data_handler, Tensor x, Tensor y) | |||||
| { | |||||
| (x, y) = data_handler.DataAdapter.Expand1d(x, y); | |||||
| var y_pred = Apply(x, training: false); | |||||
| var loss = compiled_loss.Call(y, y_pred); | |||||
| compiled_metrics.update_state(y, y_pred); | |||||
| return metrics.Select(x => (x.Name, x.result())).ToDictionary(x=>x.Item1, x=>(float)x.Item2); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -266,7 +266,7 @@ namespace Tensorflow.Keras.Engine | |||||
| { | { | ||||
| // Because evaluate calls call_test_batch_end, this interferes with our output on the screen | // Because evaluate calls call_test_batch_end, this interferes with our output on the screen | ||||
| // so we need to pass a is_val parameter to stop on_test_batch_end | // so we need to pass a is_val parameter to stop on_test_batch_end | ||||
| var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2, is_val:true); | |||||
| var val_logs = evaluate((Tensor)validation_data.Value.Item1, validation_data.Value.Item2, is_val:true); | |||||
| foreach (var log in val_logs) | foreach (var log in val_logs) | ||||
| { | { | ||||
| logs["val_" + log.Key] = log.Value; | logs["val_" + log.Key] = log.Value; | ||||
| @@ -65,7 +65,34 @@ namespace TensorFlowNET.UnitTest.NumPy | |||||
| var y = np.power(x, 3); | var y = np.power(x, 3); | ||||
| Assert.AreEqual(y, new[] { 0, 1, 8, 27, 64, 125 }); | Assert.AreEqual(y, new[] { 0, 1, 8, 27, 64, 125 }); | ||||
| } | } | ||||
| [TestMethod] | |||||
| [TestMethod] | |||||
| public void square() | |||||
| { | |||||
| var x = np.arange(6); | |||||
| var y = np.square(x); | |||||
| Assert.AreEqual(y, new[] { 0, 1, 4, 9, 16, 25 }); | |||||
| } | |||||
| [TestMethod] | |||||
| public void dotproduct() | |||||
| { | |||||
| var x1 = new NDArray(new[] { 1, 2, 3 }); | |||||
| var x2 = new NDArray(new[] { 4, 5, 6 }); | |||||
| double result1 = np.dot(x1, x2); | |||||
| NDArray y1 = new float[,] { | |||||
| { 1.0f, 2.0f, 3.0f }, | |||||
| { 4.0f, 5.1f,6.0f }, | |||||
| { 4.0f, 5.1f,6.0f } | |||||
| }; | |||||
| NDArray y2 = new float[,] { | |||||
| { 3.0f, 2.0f, 1.0f }, | |||||
| { 6.0f, 5.1f, 4.0f }, | |||||
| { 6.0f, 5.1f, 4.0f } | |||||
| }; | |||||
| double result2 = np.dot(y1, y2); | |||||
| Assert.AreEqual(result1, 32); | |||||
| Assert.AreEqual(Math.Round(result2, 2), 158.02); | |||||
| } | |||||
| [TestMethod] | |||||
| public void maximum() | public void maximum() | ||||
| { | { | ||||
| var x1 = new NDArray(new[,] { { 1, 2, 3 }, { 4, 5.1, 6 } }); | var x1 = new NDArray(new[,] { { 1, 2, 3 }, { 4, 5.1, 6 } }); | ||||