diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index 042ab952..ce6dc4d6 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -21,6 +21,9 @@ namespace Tensorflow public MathApi math { get; } = new MathApi(); public class MathApi { + public Tensor argmax(Tensor input, Axis axis = null, string name = null, int? dimension = null, TF_DataType output_type = TF_DataType.TF_INT64) + => gen_math_ops.arg_max(input, axis, name: name, output_type: output_type); + public Tensor log(Tensor x, string name = null) => gen_math_ops.log(x, name); @@ -539,15 +542,12 @@ namespace Tensorflow public Tensor round(Tensor x, string name = null) => gen_math_ops.round(x, name: name); - public Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) + public Tensor cast(Tensor x, TF_DataType dtype, string name = null) => math_ops.cast(x, dtype, name); public Tensor cumsum(Tensor x, int axis = 0, bool exclusive = false, bool reverse = false, string name = null) => math_ops.cumsum(x, axis: axis, exclusive: exclusive, reverse: reverse, name: name); - public Tensor argmax(Tensor input, int axis = -1, string name = null, int? dimension = null, TF_DataType output_type = TF_DataType.TF_INT64) - => gen_math_ops.arg_max(input, axis, name: name, output_type: output_type); - public Tensor square(Tensor x, string name = null) => gen_math_ops.square(x, name: name); public Tensor squared_difference(Tensor x, Tensor y, string name = null) diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index c79b6f3a..8a744543 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -549,6 +549,8 @@ namespace Tensorflow return tensors.dtype; case IEnumerable tensors: return tensors.First().dtype; + case RefVariable variable: + return variable.dtype; case ResourceVariable variable: return variable.dtype; default: diff --git a/src/TensorFlowNET.Core/Contexts/ExecuteOpArgs.cs b/src/TensorFlowNET.Core/Contexts/ExecuteOpArgs.cs index ecdcff8e..8710ea5d 100644 --- a/src/TensorFlowNET.Core/Contexts/ExecuteOpArgs.cs +++ b/src/TensorFlowNET.Core/Contexts/ExecuteOpArgs.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Text; using static Tensorflow.Binding; @@ -11,11 +12,13 @@ namespace Tensorflow public object[] OpInputArgs { get; set; } public Dictionary OpAttrs { get; set; } + [DebuggerStepThrough] public ExecuteOpArgs(params object[] inputArgs) { OpInputArgs = inputArgs; } + [DebuggerStepThrough] public ExecuteOpArgs SetAttributes(object attrs) { OpAttrs = ConvertToDict(attrs); diff --git a/src/TensorFlowNET.Core/NumPy/Axis.cs b/src/TensorFlowNET.Core/NumPy/Axis.cs index 3c43686d..05644640 100644 --- a/src/TensorFlowNET.Core/NumPy/Axis.cs +++ b/src/TensorFlowNET.Core/NumPy/Axis.cs @@ -56,7 +56,7 @@ namespace Tensorflow => constant_op.constant(axis); public override string ToString() - => $"({string.Join(", ", axis)})"; + => IsScalar ? $"{axis[0]}" : $"({string.Join(", ", axis)})"; } } diff --git a/src/TensorFlowNET.Core/NumPy/Implementation/RandomizedImpl.cs b/src/TensorFlowNET.Core/NumPy/Implementation/RandomizedImpl.cs new file mode 100644 index 00000000..244fc61b --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/Implementation/RandomizedImpl.cs @@ -0,0 +1,29 @@ +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Text; + +namespace Tensorflow.NumPy +{ + public class RandomizedImpl + { + [AutoNumPy] + public NDArray permutation(int x) => new NDArray(random_ops.random_shuffle(math_ops.range(0, x))); + + [AutoNumPy] + public NDArray permutation(NDArray x) => new NDArray(random_ops.random_shuffle(x)); + + [AutoNumPy] + public void shuffle(NDArray x) + { + var y = random_ops.random_shuffle(x); + Marshal.Copy(y.BufferToArray(), 0, x.TensorDataPointer, (int)x.bytesize); + } + + public NDArray rand(params int[] shape) + => throw new NotImplementedException(""); + + public NDArray randint(long x) + => throw new NotImplementedException(""); + } +} diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs index 5e36c48c..160e1d6e 100644 --- a/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs +++ b/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs @@ -45,6 +45,8 @@ namespace Tensorflow.NumPy { if(mask.dtype == TF_DataType.TF_INT32) return GetData(mask.ToArray()); + else if (mask.dtype == TF_DataType.TF_INT64) + return GetData(mask.ToArray().Select(x => Convert.ToInt32(x)).ToArray()); throw new NotImplementedException(""); } diff --git a/src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs b/src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs index c657294e..b9ad9812 100644 --- a/src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs +++ b/src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs @@ -17,7 +17,13 @@ namespace Tensorflow.NumPy => new NDArray(math_ops.argmax(a, axis ?? -1)); [AutoNumPy] - public static NDArray unique(NDArray a) - => throw new NotImplementedException(""); + public static (NDArray, NDArray) unique(NDArray a) + { + var(u, indice) = array_ops.unique(a); + return (new NDArray(u), new NDArray(indice)); + } + + [AutoNumPy] + public static void shuffle(NDArray x) => np.random.shuffle(x); } } diff --git a/src/TensorFlowNET.Core/NumPy/NumPy.Statistics.cs b/src/TensorFlowNET.Core/NumPy/NumPy.Statistics.cs index aa7f8d67..806d38b2 100644 --- a/src/TensorFlowNET.Core/NumPy/NumPy.Statistics.cs +++ b/src/TensorFlowNET.Core/NumPy/NumPy.Statistics.cs @@ -13,6 +13,6 @@ namespace Tensorflow.NumPy public static NDArray amin(NDArray x, int axis = 0) => new NDArray(tf.arg_min(x, axis)); [AutoNumPy] - public static NDArray amax(NDArray x, int axis = 0) => new NDArray(tf.arg_max(x, axis)); + public static NDArray amax(NDArray x, int axis = 0) => new NDArray(tf.math.argmax(x, axis)); } } diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs b/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs index 87658a32..9b539a07 100644 --- a/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs +++ b/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs @@ -25,6 +25,9 @@ namespace Tensorflow.NumPy public NDArray(byte[] bytes, Shape shape, TF_DataType dtype) : base(bytes, shape, dtype) { NewEagerTensorHandle(); } + public NDArray(long[] value, Shape? shape = null) + : base(value, shape) { NewEagerTensorHandle(); } + public NDArray(IntPtr address, Shape shape, TF_DataType dtype) : base(address, shape, dtype) { NewEagerTensorHandle(); } diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.cs b/src/TensorFlowNET.Core/Numpy/NDArray.cs index 3f90db00..1adf7c0f 100644 --- a/src/TensorFlowNET.Core/Numpy/NDArray.cs +++ b/src/TensorFlowNET.Core/Numpy/NDArray.cs @@ -42,11 +42,9 @@ namespace Tensorflow.NumPy public NDArray reshape(Shape newshape) => new NDArray(tf.reshape(this, newshape)); public NDArray astype(TF_DataType dtype) => new NDArray(math_ops.cast(this, dtype)); public NDArray ravel() => throw new NotImplementedException(""); - public void shuffle(NDArray nd) => throw new NotImplementedException(""); + public void shuffle(NDArray nd) => np.random.shuffle(nd); public Array ToMuliDimArray() => throw new NotImplementedException(""); public byte[] ToByteArray() => BufferToArray(); - public static string[] AsStringArray(NDArray arr) => throw new NotImplementedException(""); - public override string ToString() => NDArrayRender.ToString(this); } } diff --git a/src/TensorFlowNET.Core/Numpy/Numpy.cs b/src/TensorFlowNET.Core/Numpy/Numpy.cs index 76a00418..89077796 100644 --- a/src/TensorFlowNET.Core/Numpy/Numpy.cs +++ b/src/TensorFlowNET.Core/Numpy/Numpy.cs @@ -75,28 +75,7 @@ namespace Tensorflow.NumPy public static bool allclose(NDArray a, NDArray b, double rtol = 1.0E-5, double atol = 1.0E-8, bool equal_nan = false) => throw new NotImplementedException(""); - public static class random - { - public static NDArray permutation(int x) - { - throw new NotImplementedException(""); - } - - public static void shuffle(NDArray nd) - { - - } - - public static NDArray rand(params int[] shape) - => throw new NotImplementedException(""); - - public static NDArray randint(long x) - => throw new NotImplementedException(""); - - public static NDArray RandomState(int x) - => throw new NotImplementedException(""); - } - + public static RandomizedImpl random = new RandomizedImpl(); public static LinearAlgebraImpl linalg = new LinearAlgebraImpl(); } } diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 46d62d82..3c994a6e 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -193,6 +193,9 @@ namespace Tensorflow case double v: feeds[i++] = new KeyValuePair(key._as_tf_output(), new Tensor(v)); break; + case string v: + feeds[i++] = new KeyValuePair(key._as_tf_output(), new Tensor(v)); + break; case Array v: feeds[i++] = new KeyValuePair(key._as_tf_output(), new Tensor(v, v.GetShape())); break; diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs index b55563f2..0dccb955 100644 --- a/src/TensorFlowNET.Core/Tensors/constant_op.cs +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -117,7 +117,7 @@ namespace Tensorflow case Shape val: return new EagerTensor(val.dims, new Shape(val.ndim)); case Axis val: - return new EagerTensor(val.axis, new Shape(val.size)); + return new EagerTensor(val.axis, val.IsScalar ? Shape.Scalar : new Shape(val.size)); case string val: return new EagerTensor(new[] { val }, Shape.Scalar); case string[] val: diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 4fa4d773..243b73d3 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -141,7 +141,24 @@ namespace Tensorflow byte[] bytes = nd.ToByteArray(); tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes); } - else if (!values.GetType().IsArray) + else if (dtype == TF_DataType.TF_STRING && !(values is NDArray)) + { + if (values is string str) + tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFromUtf8(str)); + else if (values is string[] str_values) + tensor_proto.StringVal.AddRange(str_values.Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x))); + else if (values is byte[] byte_values) + tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(byte_values); + } + else if (values is Array array) + { + // array + var len = dtype.get_datatype_size() * (int)shape.size; + byte[] bytes = new byte[len]; + System.Buffer.BlockCopy(array, 0, bytes, 0, len); + tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes); + } + else { switch (values) { @@ -166,32 +183,10 @@ namespace Tensorflow case double val: tensor_proto.DoubleVal.AddRange(new[] { val }); break; - case string val: - tensor_proto.StringVal.AddRange(val.Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x.ToString()))); - break; default: throw new Exception("make_tensor_proto Not Implemented"); } } - else if (dtype == TF_DataType.TF_STRING && !(values is NDArray)) - { - if (values is string str) - { - tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFromUtf8(str)); - } - else if (values is string[] str_values) - tensor_proto.StringVal.AddRange(str_values.Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x))); - else if (values is byte[] byte_values) - tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(byte_values); - } - else if (values is Array array) - { - // array - var len = dtype.get_datatype_size() * (int)shape.size; - byte[] bytes = new byte[len]; - System.Buffer.BlockCopy(array, 0, bytes, 0, len); - tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes); - } return tensor_proto; } diff --git a/src/TensorFlowNET.Core/Training/Saving/Saver.cs b/src/TensorFlowNET.Core/Training/Saving/Saver.cs index c326267f..6138dba4 100644 --- a/src/TensorFlowNET.Core/Training/Saving/Saver.cs +++ b/src/TensorFlowNET.Core/Training/Saving/Saver.cs @@ -193,7 +193,7 @@ namespace Tensorflow if (write_state) { - var path = NDArray.AsStringArray(model_checkpoint_path[0])[0]; + var path = model_checkpoint_path[0].StringData()[0]; _RecordLastCheckpoint(path); checkpoint_management.update_checkpoint_state_internal( save_dir: save_path_parent, @@ -211,7 +211,7 @@ namespace Tensorflow export_meta_graph(meta_graph_filename, strip_default_attrs: strip_default_attrs, save_debug_info: save_debug_info); } - return _is_empty ? string.Empty : NDArray.AsStringArray(model_checkpoint_path[0])[0]; + return _is_empty ? string.Empty : model_checkpoint_path[0].StringData()[0]; } public (Saver, object) import_meta_graph(string meta_graph_or_file, diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 2e0c04e0..f499574f 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -165,8 +165,8 @@ namespace Tensorflow if (dtype == TF_DataType.TF_STRING) return ret; - if (dtype != ret.dtype) - ret = gen_math_ops.cast(ret, dtype.as_base_dtype(), name: name); + if (dtype.as_base_dtype() != ret.dtype.as_base_dtype()) + ret = gen_math_ops.cast(ret, dtype, name: name); return ret; } diff --git a/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs b/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs index e406176d..cfcbc4bc 100644 --- a/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs +++ b/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.IO; using System.Linq; +using static Tensorflow.Binding; namespace Tensorflow.Keras.Preprocessings { @@ -50,8 +51,8 @@ namespace Tensorflow.Keras.Preprocessings if (!seed.HasValue) seed = np.random.randint((long)1e6); var random_index = np.arange(label_list.Count); - var rng = np.random.RandomState(seed.Value); - rng.shuffle(random_index); + tf.set_random_seed(seed.Value); + np.random.shuffle(random_index); var index = random_index.ToArray(); for (int i = 0; i < label_list.Count; i++) diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs index f5b52dfb..fa19987b 100644 --- a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs +++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs @@ -115,8 +115,8 @@ namespace Tensorflow.Keras var start_positions = np.arange(0, num_seqs, sequence_stride); if (shuffle) { - var rng = np.random.RandomState(seed); - rng.shuffle(start_positions); + tf.set_random_seed(seed); + np.random.shuffle(start_positions); } var sequence_length_tensor = constant_op.constant(sequence_length, dtype: index_dtype); diff --git a/test/TensorFlowNET.UnitTest/NumPy/Randomize.Test.cs b/test/TensorFlowNET.UnitTest/NumPy/Randomize.Test.cs new file mode 100644 index 00000000..38a4fbbe --- /dev/null +++ b/test/TensorFlowNET.UnitTest/NumPy/Randomize.Test.cs @@ -0,0 +1,27 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow; +using Tensorflow.NumPy; + +namespace TensorFlowNET.UnitTest.NumPy +{ + /// + /// https://numpy.org/doc/1.20/reference/random/index.html + /// + [TestClass] + public class RandomizeTest : EagerModeTestBase + { + [TestMethod] + public void permutation() + { + var x = np.random.permutation(10); + Assert.AreEqual(x.shape, 10); + var y = np.random.permutation(x); + Assert.AreEqual(x.shape, 10); + Assert.AreNotEqual(x.ToArray(), y.ToArray()); + } + } +}