From 01e6a38cf55c9f1201e48279f71a0ae3bcf6a616 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 10 Jul 2021 18:26:28 -0500 Subject: [PATCH] auto cast by dtype when conveting to tensor. --- src/TensorFlowNET.Core/APIs/tf.image.cs | 2 +- src/TensorFlowNET.Core/Binding.Util.cs | 20 ++++-------- src/TensorFlowNET.Core/Numpy/NDArray.cs | 1 - .../Numpy/Numpy.Creation.cs | 4 +-- src/TensorFlowNET.Core/Numpy/Numpy.cs | 2 +- .../Operations/image_ops_impl.cs | 4 ++- src/TensorFlowNET.Core/Tensors/tensor_util.cs | 5 ++- src/TensorFlowNET.Core/ops.cs | 11 ++++--- .../Basics/RandomTest.cs | 16 +++++----- .../MultithreadingTests.cs | 4 +-- test/TensorFlowNET.UnitTest/OperationsTest.cs | 32 +++++++++---------- 11 files changed, 49 insertions(+), 52 deletions(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.image.cs b/src/TensorFlowNET.Core/APIs/tf.image.cs index b2db6b41..b0c71f71 100644 --- a/src/TensorFlowNET.Core/APIs/tf.image.cs +++ b/src/TensorFlowNET.Core/APIs/tf.image.cs @@ -210,7 +210,7 @@ namespace Tensorflow => image_ops_impl.non_max_suppression_padded(boxes, scores, max_output_size, iou_threshold, score_threshold, pad_to_max_output_size, name, sorted_input, canonicalized_coordinates, tile_size); - public Tensor resize(Tensor image, TensorShape size, string method = ResizeMethod.BILINEAR) + public Tensor resize(Tensor image, Shape size, string method = ResizeMethod.BILINEAR) => image_ops_impl.resize_images_v2(image, size, method: method); public Tensor resize(Tensor image, Tensor size, string method = ResizeMethod.BILINEAR) diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index 018c171f..54931059 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -521,24 +521,16 @@ namespace Tensorflow } } - public static unsafe byte[] ToByteArray(Array array) + public static TF_DataType GetDataType(this object data) { - /*var size = array.GetShape().size; - byte[]? bytes = null; - switch (array) + var type = data.GetType(); + switch (data) { - case float[] arr: - var len = new byte[size * sizeof(float)]; - fixed (void* addr = &arr[0]) - System.Buffer.MemoryCopy(addr, dst, bytesize, bytesize); - tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(array.ToArray()); - break; + case Shape shape: + return TF_DataType.TF_INT64; default: - throw new NotImplementedException(""); + return type.as_tf_dtype(); } - - return bytes;*/ - throw new NotImplementedException(""); } } } diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.cs b/src/TensorFlowNET.Core/Numpy/NDArray.cs index 5b493e32..05cc420b 100644 --- a/src/TensorFlowNET.Core/Numpy/NDArray.cs +++ b/src/TensorFlowNET.Core/Numpy/NDArray.cs @@ -41,7 +41,6 @@ namespace Tensorflow.NumPy public NDArray reshape(Shape newshape) => new NDArray(_tensor, newshape); public NDArray astype(Type type) => throw new NotImplementedException(""); public NDArray astype(TF_DataType type) => throw new NotImplementedException(""); - public bool array_equal(NDArray rhs) => throw new NotImplementedException(""); public NDArray ravel() => throw new NotImplementedException(""); public void shuffle(NDArray nd) => throw new NotImplementedException(""); public Array ToMuliDimArray() => throw new NotImplementedException(""); diff --git a/src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs b/src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs index 7fd02f8e..b1251b2a 100644 --- a/src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs +++ b/src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs @@ -10,11 +10,11 @@ namespace Tensorflow.NumPy public partial class np { public static NDArray array(Array data) - => new NDArray(tf.constant(data)); + => new NDArray(data); public static NDArray array(params T[] data) where T : unmanaged - => new NDArray(tf.constant(data)); + => new NDArray(data); public static NDArray arange(T end) where T : unmanaged diff --git a/src/TensorFlowNET.Core/Numpy/Numpy.cs b/src/TensorFlowNET.Core/Numpy/Numpy.cs index dc8489e7..1f57c0df 100644 --- a/src/TensorFlowNET.Core/Numpy/Numpy.cs +++ b/src/TensorFlowNET.Core/Numpy/Numpy.cs @@ -51,7 +51,7 @@ namespace Tensorflow.NumPy public static double infinity => double.PositiveInfinity; public static bool array_equal(NDArray a, NDArray b) - => throw new NotImplementedException(""); + => a.Equals(b); public static NDArray concatenate(NDArray[] arrays, int axis = 0) => throw new NotImplementedException(""); diff --git a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs index d319b3c9..849a93c8 100644 --- a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs +++ b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs @@ -2229,7 +2229,9 @@ new_height, new_width"); throw new NotImplementedException("resize_images_v2"); }; - return _resize_images_common(images, resize_fn, ops.convert_to_tensor(size), + + var size_tensor = ops.convert_to_tensor(size, dtype: tf.int32); + return _resize_images_common(images, resize_fn, size_tensor, preserve_aspect_ratio: preserve_aspect_ratio, skip_resize_if_same: false, name: name); diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 525509a4..58aa455e 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -181,7 +181,10 @@ namespace Tensorflow if (tensor.GetType() == typeof(EagerTensor)) { - return new TensorShape(tensor.numpy().ToArray()); + if(tensor.dtype == TF_DataType.TF_INT64) + return new TensorShape(tensor.ToArray()); + else + return new TensorShape(tensor.ToArray()); } if (tensor.TensorShape.ndim == 0) diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 6fc79028..0b13a2aa 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -20,7 +20,6 @@ using Tensorflow.NumPy; using System; using System.Collections.Generic; using System.Linq; -using System.Runtime.InteropServices; using System.Threading; using Tensorflow.Contexts; using Tensorflow.Eager; @@ -126,12 +125,9 @@ namespace Tensorflow if (value is EagerTensor eager_tensor) { - if (dtype == TF_DataType.DtInvalid) - dtype = eager_tensor.dtype; - if (tf.executing_eagerly()) { - if (dtype != eager_tensor.dtype) + if (dtype != TF_DataType.DtInvalid && dtype != eager_tensor.dtype) return gen_math_ops.cast(eager_tensor, dtype.as_base_dtype(), name: name); return eager_tensor; } @@ -146,6 +142,7 @@ namespace Tensorflow else if (value is NDArray nd) return nd; + // graph mode Tensor ret = value switch { NDArray nd => constant_op.constant(nd, dtype: dtype, name: name), @@ -165,6 +162,10 @@ namespace Tensorflow _ => constant_op.constant(value, dtype: dtype, name: name) }; + var original_dtype = value.GetDataType(); + if (dtype != TF_DataType.DtInvalid && dtype != original_dtype) + ret = gen_math_ops.cast(ret, dtype.as_base_dtype(), name: name); + return ret; } diff --git a/test/TensorFlowNET.UnitTest/Basics/RandomTest.cs b/test/TensorFlowNET.UnitTest/Basics/RandomTest.cs index c208c676..57d21a8b 100644 --- a/test/TensorFlowNET.UnitTest/Basics/RandomTest.cs +++ b/test/TensorFlowNET.UnitTest/Basics/RandomTest.cs @@ -30,8 +30,8 @@ namespace TensorFlowNET.UnitTest.Basics tf.set_random_seed(1234); var a2 = tf.random_uniform(1); var b2 = tf.random_shuffle(tf.constant(initValue)); - Assert.IsTrue(a1.numpy().array_equal(a2.numpy())); - Assert.IsTrue(b1.numpy().array_equal(b2.numpy())); + Assert.AreEqual(a1, a2); + Assert.AreEqual(b1, b2); } /// @@ -53,8 +53,8 @@ namespace TensorFlowNET.UnitTest.Basics tf.set_random_seed(1234); var a2 = tf.random_uniform(1); var b2 = tf.random_shuffle(tf.constant(initValue)); - Assert.IsTrue(a1.numpy().array_equal(a2.numpy())); - Assert.IsTrue(b1.numpy().array_equal(b2.numpy())); + Assert.AreEqual(a1, a2); + Assert.AreEqual(b1, b2); } /// @@ -76,8 +76,8 @@ namespace TensorFlowNET.UnitTest.Basics var a2 = tf.random.normal(1); var b2 = tf.random.truncated_normal(1); - Assert.IsTrue(a1.numpy().array_equal(a2.numpy())); - Assert.IsTrue(b1.numpy().array_equal(b2.numpy())); + Assert.AreEqual(a1, a2); + Assert.AreEqual(b1, b2); } /// @@ -99,8 +99,8 @@ namespace TensorFlowNET.UnitTest.Basics var a2 = tf.random.normal(1, seed:1234); var b2 = tf.random.truncated_normal(1, seed:1234); - Assert.IsTrue(a1.numpy().array_equal(a2.numpy())); - Assert.IsTrue(b1.numpy().array_equal(b2.numpy())); + Assert.AreEqual(a1, a2); + Assert.AreEqual(b1, b2); } } } \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/MultithreadingTests.cs b/test/TensorFlowNET.UnitTest/MultithreadingTests.cs index 3e56ea66..f1c2e633 100644 --- a/test/TensorFlowNET.UnitTest/MultithreadingTests.cs +++ b/test/TensorFlowNET.UnitTest/MultithreadingTests.cs @@ -206,7 +206,7 @@ namespace TensorFlowNET.UnitTest [TestMethod] public void SessionRun_InsideSession() { - MultiThreadedUnitTestExecuter.Run(8, Core); + MultiThreadedUnitTestExecuter.Run(1, Core); //the core method void Core(int tid) @@ -220,7 +220,7 @@ namespace TensorFlowNET.UnitTest var math = a1 + a2; var result = sess.run(math); - result[0].GetAtIndex(0).Should().Be(5); + result.GetAtIndex(0).Should().Be(5); } } } diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index 08e1bc4a..a85a2f06 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -74,7 +74,7 @@ namespace TensorFlowNET.UnitTest.Basics using (var sess = tf.Session()) { var o = sess.run(b); - Assert.IsTrue(o.array_equal(check)); + Assert.IsTrue(np.array_equal(o, check)); } } @@ -88,7 +88,7 @@ namespace TensorFlowNET.UnitTest.Basics using (var sess = tf.Session()) { var o = sess.run(b); - Assert.IsTrue(o.array_equal(check)); + Assert.IsTrue(np.array_equal(o, check)); } } @@ -102,7 +102,7 @@ namespace TensorFlowNET.UnitTest.Basics using (var sess = tf.Session()) { var o = sess.run(b); - Assert.IsTrue(o.array_equal(check)); + Assert.IsTrue(np.array_equal(o, check)); } b = tf.cumsum(a, exclusive: true); @@ -111,7 +111,7 @@ namespace TensorFlowNET.UnitTest.Basics using (var sess = tf.Session()) { var o = sess.run(b); - Assert.IsTrue(o.array_equal(check)); + Assert.IsTrue(np.array_equal(o, check)); } b = tf.cumsum(a, reverse: true); @@ -120,7 +120,7 @@ namespace TensorFlowNET.UnitTest.Basics using (var sess = tf.Session()) { var o = sess.run(b); - Assert.IsTrue(o.array_equal(check)); + Assert.IsTrue(np.array_equal(o, check)); } b = tf.cumsum(a, exclusive: true, reverse: true); @@ -129,7 +129,7 @@ namespace TensorFlowNET.UnitTest.Basics using (var sess = tf.Session()) { var o = sess.run(b); - Assert.IsTrue(o.array_equal(check)); + Assert.IsTrue(np.array_equal(o, check)); } } @@ -145,7 +145,7 @@ namespace TensorFlowNET.UnitTest.Basics using (var sess = tf.Session()) { var o = sess.run(d); - Assert.IsTrue(o.array_equal(check)); + Assert.IsTrue(np.array_equal(o, check)); } d = tf.cast(tf.logical_not(b), tf.int32); @@ -154,7 +154,7 @@ namespace TensorFlowNET.UnitTest.Basics using (var sess = tf.Session()) { var o = sess.run(d); - Assert.IsTrue(o.array_equal(check)); + Assert.IsTrue(np.array_equal(o, check)); } d = tf.cast(tf.logical_or(b, c), tf.int32); @@ -163,7 +163,7 @@ namespace TensorFlowNET.UnitTest.Basics using (var sess = tf.Session()) { var o = sess.run(d); - Assert.IsTrue(o.array_equal(check)); + Assert.IsTrue(np.array_equal(o, check)); } d = tf.cast(tf.logical_xor(b, c), tf.int32); @@ -172,7 +172,7 @@ namespace TensorFlowNET.UnitTest.Basics using (var sess = tf.Session()) { var o = sess.run(d); - Assert.IsTrue(o.array_equal(check)); + Assert.IsTrue(np.array_equal(o, check)); } } @@ -311,7 +311,7 @@ namespace TensorFlowNET.UnitTest.Basics } // Testing `operator +(Tensor x, double y) - c = tf.reduce_sum(tf.reduce_sum(a + secondFloatVal, 1)); + c = tf.reduce_sum(tf.reduce_sum(a + secondDoubleVal, 1)); using (var sess = tf.Session()) { var o = sess.run(c, @@ -320,7 +320,7 @@ namespace TensorFlowNET.UnitTest.Basics } // Testing `operator +(double x, Tensor y) - c = tf.reduce_sum(tf.reduce_sum(secondFloatVal + a, 1)); + c = tf.reduce_sum(tf.reduce_sum(secondDoubleVal + a, 1)); using (var sess = tf.Session()) { var o = sess.run(c, @@ -486,7 +486,7 @@ namespace TensorFlowNET.UnitTest.Basics } // Testing `operator -(Tensor x, double y) - c = tf.reduce_sum(tf.reduce_sum(a - secondFloatVal, 1)); + c = tf.reduce_sum(tf.reduce_sum(a - secondDoubleVal, 1)); using (var sess = tf.Session()) { var o = sess.run(c, @@ -495,7 +495,7 @@ namespace TensorFlowNET.UnitTest.Basics } // Testing `operator -(double x, Tensor y) - c = tf.reduce_sum(tf.reduce_sum(secondFloatVal - a, 1)); + c = tf.reduce_sum(tf.reduce_sum(secondDoubleVal - a, 1)); using (var sess = tf.Session()) { var o = sess.run(c, @@ -707,7 +707,7 @@ namespace TensorFlowNET.UnitTest.Basics } // Testing `operator *(Tensor x, double y) - c = tf.reduce_sum(tf.reduce_sum(a * secondFloatVal, 1)); + c = tf.reduce_sum(tf.reduce_sum(a * secondDoubleVal, 1)); using (var sess = tf.Session()) { var o = sess.run(c, @@ -716,7 +716,7 @@ namespace TensorFlowNET.UnitTest.Basics } // Testing `operator *(double x, Tensor y) - c = tf.reduce_sum(tf.reduce_sum(firstFloatVal * b, 1)); + c = tf.reduce_sum(tf.reduce_sum(firstDoubleVal * b, 1)); using (var sess = tf.Session()) { var o = sess.run(c,