Browse Source

fix set_shape.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
b5f357a47f
4 changed files with 12 additions and 9 deletions
  1. +1
    -2
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  2. +0
    -5
      test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs
  3. +9
    -0
      test/TensorFlowNET.Graph.UnitTest/GraphModeTestBase.cs
  4. +2
    -2
      test/TensorFlowNET.Graph.UnitTest/NameScopeTest.cs

+ 1
- 2
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -155,8 +155,7 @@ namespace Tensorflow
/// </summary>
public virtual void set_shape(TensorShape shape)
{
// this.shape = shape.rank >= 0 ? shape.dims : null;
throw new NotImplementedException("");
this.shape = shape.rank >= 0 ? shape : null;
}

/// <summary>


+ 0
- 5
test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs View File

@@ -14,8 +14,6 @@ namespace TensorFlowNET.UnitTest.Gradient
[TestMethod]
public void BroadcastToGrad()
{
var graph = tf.Graph().as_default();

var x = tf.constant(2, dtype: dtypes.float32);
var y = tf.broadcast_to(x, (2, 4, 3));
var grad = tf.gradients(y, x);
@@ -30,8 +28,6 @@ namespace TensorFlowNET.UnitTest.Gradient
[TestMethod]
public void CumsumGrad()
{
var graph = tf.Graph().as_default();

var x = tf.constant(2, dtype: dtypes.float32);
var y = tf.broadcast_to(x, (2, 4, 3));
var z = tf.cumsum(y, axis: 1);
@@ -47,7 +43,6 @@ namespace TensorFlowNET.UnitTest.Gradient
[TestMethod, Ignore]
public void testGradients()
{
var g = tf.Graph().as_default();
var inp = tf.constant(1.0, shape: new[] { 32, 100 }, name: "in");
var w = tf.constant(1.0, shape: new[] { 100, 10 }, name: "w");
var b = tf.Variable(1.0, shape: new[] { 10 }, name: "b");


+ 9
- 0
test/TensorFlowNET.Graph.UnitTest/GraphModeTestBase.cs View File

@@ -1,14 +1,23 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest
{
public class GraphModeTestBase : PythonTest
{
protected Graph graph;
[TestInitialize]
public void TestInit()
{
tf.compat.v1.disable_eager_execution();
graph = tf.Graph().as_default();
}

[TestCleanup]
public void TestClean()
{
graph.Exit();
}
}
}

+ 2
- 2
test/TensorFlowNET.Graph.UnitTest/NameScopeTest.cs View File

@@ -38,7 +38,7 @@ namespace TensorFlowNET.UnitTest.Basics
Assert.AreEqual("scope1/Const_1:0", const3.name);
});

g.Dispose();
g.Exit();

Assert.AreEqual("", g._name_stack);
}
@@ -70,7 +70,7 @@ namespace TensorFlowNET.UnitTest.Basics
Assert.AreEqual("scope1/Const_1:0", const3.name);
};

g.Dispose();
g.Exit();

Assert.AreEqual("", g._name_stack);
}


Loading…
Cancel
Save