From b1bd05c1a15b0dcaaebbfce1db973da2aae70e37 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 31 Aug 2019 11:27:34 -0500 Subject: [PATCH] Expose tf.shape() API. --- src/TensorFlowNET.Core/APIs/tf.array.cs | 10 ++++++++++ src/TensorFlowNET.Core/Operations/array_ops.py.cs | 2 +- src/TensorFlowNET.Core/TensorFlowNET.Core.csproj | 2 +- .../ImageProcessing/YOLO/common.cs | 7 +++++-- test/TensorFlowNET.UnitTest/ImageTest.cs | 5 +++-- 5 files changed, 20 insertions(+), 6 deletions(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs index 793318cb..f8493c40 100644 --- a/src/TensorFlowNET.Core/APIs/tf.array.cs +++ b/src/TensorFlowNET.Core/APIs/tf.array.cs @@ -124,5 +124,15 @@ namespace Tensorflow /// A `Tensor`. Has the same type as `input`. public Tensor placeholder_with_default(T input, int[] shape, string name = null) => gen_array_ops.placeholder_with_default(input, shape, name: name); + + /// + /// Returns the shape of a tensor. + /// + /// + /// + /// + /// + public Tensor shape(Tensor input, string name = null, TF_DataType out_type = TF_DataType.TF_INT32) + => array_ops.shape_internal(input, name, optimize: true, out_type: out_type); } } diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs index c85ac245..cf38a8c3 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs @@ -338,7 +338,7 @@ namespace Tensorflow public static Tensor size(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32) => size_internal(input, name, optimize: optimize, out_type: out_type); - private static Tensor shape_internal(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32) + public static Tensor shape_internal(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32) { return tf_with(ops.name_scope(name, "Shape", new { input }), scope => { diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 4eb1b4b4..c3576350 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -26,7 +26,7 @@ Docs: https://tensorflownet.readthedocs.io 5. Overload session.run(), make syntax simpler. 6. Add Local Response Normalization. 7. Add tf.image related APIs. -8. Add tf.random_normal, tf.constant, tf.pad. +8. Add tf.random_normal, tf.constant, tf.pad, tf.shape. 9. MultiThread is safe. 7.3 0.11.2.0 diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs index a3ffb147..375d68a0 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs @@ -65,11 +65,14 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO Tensor output = null; if (method == "resize") { - + tf_with(tf.variable_scope(name), delegate + { + var input_shape = tf.shape(input_data); + }); } else if(method == "deconv") { - + throw new NotImplementedException("upsample.deconv"); } return output; diff --git a/test/TensorFlowNET.UnitTest/ImageTest.cs b/test/TensorFlowNET.UnitTest/ImageTest.cs index e4f8a835..7f3f4e3a 100644 --- a/test/TensorFlowNET.UnitTest/ImageTest.cs +++ b/test/TensorFlowNET.UnitTest/ImageTest.cs @@ -14,10 +14,11 @@ namespace TensorFlowNET.UnitTest [TestClass] public class ImageTest { - string imgPath = "../../../../../data/shasta-daisy.jpg"; + string imgPath = "shasta-daisy.jpg"; Tensor contents; - public ImageTest() + [TestInitialize] + public void Initialize() { imgPath = Path.GetFullPath(imgPath); contents = tf.read_file(imgPath);