Browse Source

fix set_shape.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
b5f357a47f
4 changed files with 12 additions and 9 deletions
  1. +1
    -2
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  2. +0
    -5
      test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs
  3. +9
    -0
      test/TensorFlowNET.Graph.UnitTest/GraphModeTestBase.cs
  4. +2
    -2
      test/TensorFlowNET.Graph.UnitTest/NameScopeTest.cs

+ 1
- 2
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -155,8 +155,7 @@ namespace Tensorflow
/// </summary> /// </summary>
public virtual void set_shape(TensorShape shape) public virtual void set_shape(TensorShape shape)
{ {
// this.shape = shape.rank >= 0 ? shape.dims : null;
throw new NotImplementedException("");
this.shape = shape.rank >= 0 ? shape : null;
} }


/// <summary> /// <summary>


+ 0
- 5
test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs View File

@@ -14,8 +14,6 @@ namespace TensorFlowNET.UnitTest.Gradient
[TestMethod] [TestMethod]
public void BroadcastToGrad() public void BroadcastToGrad()
{ {
var graph = tf.Graph().as_default();

var x = tf.constant(2, dtype: dtypes.float32); var x = tf.constant(2, dtype: dtypes.float32);
var y = tf.broadcast_to(x, (2, 4, 3)); var y = tf.broadcast_to(x, (2, 4, 3));
var grad = tf.gradients(y, x); var grad = tf.gradients(y, x);
@@ -30,8 +28,6 @@ namespace TensorFlowNET.UnitTest.Gradient
[TestMethod] [TestMethod]
public void CumsumGrad() public void CumsumGrad()
{ {
var graph = tf.Graph().as_default();

var x = tf.constant(2, dtype: dtypes.float32); var x = tf.constant(2, dtype: dtypes.float32);
var y = tf.broadcast_to(x, (2, 4, 3)); var y = tf.broadcast_to(x, (2, 4, 3));
var z = tf.cumsum(y, axis: 1); var z = tf.cumsum(y, axis: 1);
@@ -47,7 +43,6 @@ namespace TensorFlowNET.UnitTest.Gradient
[TestMethod, Ignore] [TestMethod, Ignore]
public void testGradients() public void testGradients()
{ {
var g = tf.Graph().as_default();
var inp = tf.constant(1.0, shape: new[] { 32, 100 }, name: "in"); var inp = tf.constant(1.0, shape: new[] { 32, 100 }, name: "in");
var w = tf.constant(1.0, shape: new[] { 100, 10 }, name: "w"); var w = tf.constant(1.0, shape: new[] { 100, 10 }, name: "w");
var b = tf.Variable(1.0, shape: new[] { 10 }, name: "b"); var b = tf.Variable(1.0, shape: new[] { 10 }, name: "b");


+ 9
- 0
test/TensorFlowNET.Graph.UnitTest/GraphModeTestBase.cs View File

@@ -1,14 +1,23 @@
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace TensorFlowNET.UnitTest namespace TensorFlowNET.UnitTest
{ {
public class GraphModeTestBase : PythonTest public class GraphModeTestBase : PythonTest
{ {
protected Graph graph;
[TestInitialize] [TestInitialize]
public void TestInit() public void TestInit()
{ {
tf.compat.v1.disable_eager_execution(); tf.compat.v1.disable_eager_execution();
graph = tf.Graph().as_default();
}

[TestCleanup]
public void TestClean()
{
graph.Exit();
} }
} }
} }

+ 2
- 2
test/TensorFlowNET.Graph.UnitTest/NameScopeTest.cs View File

@@ -38,7 +38,7 @@ namespace TensorFlowNET.UnitTest.Basics
Assert.AreEqual("scope1/Const_1:0", const3.name); Assert.AreEqual("scope1/Const_1:0", const3.name);
}); });


g.Dispose();
g.Exit();


Assert.AreEqual("", g._name_stack); Assert.AreEqual("", g._name_stack);
} }
@@ -70,7 +70,7 @@ namespace TensorFlowNET.UnitTest.Basics
Assert.AreEqual("scope1/Const_1:0", const3.name); Assert.AreEqual("scope1/Const_1:0", const3.name);
}; };


g.Dispose();
g.Exit();


Assert.AreEqual("", g._name_stack); Assert.AreEqual("", g._name_stack);
} }


Loading…
Cancel
Save