From b5f357a47fb935b496b264921f700fc0989f7e92 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 11 Jul 2021 22:20:36 -0500 Subject: [PATCH] fix set_shape. --- src/TensorFlowNET.Core/Tensors/Tensor.cs | 3 +-- .../GradientTest/GradientTest.cs | 5 ----- test/TensorFlowNET.Graph.UnitTest/GraphModeTestBase.cs | 9 +++++++++ test/TensorFlowNET.Graph.UnitTest/NameScopeTest.cs | 4 ++-- 4 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 6eebb523..5166cf81 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -155,8 +155,7 @@ namespace Tensorflow /// public virtual void set_shape(TensorShape shape) { - // this.shape = shape.rank >= 0 ? shape.dims : null; - throw new NotImplementedException(""); + this.shape = shape.rank >= 0 ? shape : null; } /// diff --git a/test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs b/test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs index 246488a9..fb561e07 100644 --- a/test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs +++ b/test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs @@ -14,8 +14,6 @@ namespace TensorFlowNET.UnitTest.Gradient [TestMethod] public void BroadcastToGrad() { - var graph = tf.Graph().as_default(); - var x = tf.constant(2, dtype: dtypes.float32); var y = tf.broadcast_to(x, (2, 4, 3)); var grad = tf.gradients(y, x); @@ -30,8 +28,6 @@ namespace TensorFlowNET.UnitTest.Gradient [TestMethod] public void CumsumGrad() { - var graph = tf.Graph().as_default(); - var x = tf.constant(2, dtype: dtypes.float32); var y = tf.broadcast_to(x, (2, 4, 3)); var z = tf.cumsum(y, axis: 1); @@ -47,7 +43,6 @@ namespace TensorFlowNET.UnitTest.Gradient [TestMethod, Ignore] public void testGradients() { - var g = tf.Graph().as_default(); var inp = tf.constant(1.0, shape: new[] { 32, 100 }, name: "in"); var w = tf.constant(1.0, shape: new[] { 100, 10 }, name: "w"); var b = tf.Variable(1.0, shape: new[] { 10 }, name: "b"); diff --git a/test/TensorFlowNET.Graph.UnitTest/GraphModeTestBase.cs b/test/TensorFlowNET.Graph.UnitTest/GraphModeTestBase.cs index e5c46e29..a8bb079e 100644 --- a/test/TensorFlowNET.Graph.UnitTest/GraphModeTestBase.cs +++ b/test/TensorFlowNET.Graph.UnitTest/GraphModeTestBase.cs @@ -1,14 +1,23 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow; using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest { public class GraphModeTestBase : PythonTest { + protected Graph graph; [TestInitialize] public void TestInit() { tf.compat.v1.disable_eager_execution(); + graph = tf.Graph().as_default(); + } + + [TestCleanup] + public void TestClean() + { + graph.Exit(); } } } diff --git a/test/TensorFlowNET.Graph.UnitTest/NameScopeTest.cs b/test/TensorFlowNET.Graph.UnitTest/NameScopeTest.cs index 40763ece..253a3259 100644 --- a/test/TensorFlowNET.Graph.UnitTest/NameScopeTest.cs +++ b/test/TensorFlowNET.Graph.UnitTest/NameScopeTest.cs @@ -38,7 +38,7 @@ namespace TensorFlowNET.UnitTest.Basics Assert.AreEqual("scope1/Const_1:0", const3.name); }); - g.Dispose(); + g.Exit(); Assert.AreEqual("", g._name_stack); } @@ -70,7 +70,7 @@ namespace TensorFlowNET.UnitTest.Basics Assert.AreEqual("scope1/Const_1:0", const3.name); }; - g.Dispose(); + g.Exit(); Assert.AreEqual("", g._name_stack); }