Browse Source

Fix tf.reshape.

tags/v0.20
Oceania2018 5 years ago
parent
commit
c1dcb8ca12
3 changed files with 15 additions and 26 deletions
  1. +4
    -4
      src/TensorFlowNET.Core/APIs/tf.reshape.cs
  2. +0
    -22
      test/TensorFlowNET.UnitTest/Binding/EagerTensorV2Test.cs
  3. +11
    -0
      test/TensorFlowNET.UnitTest/ConstantTest.cs

+ 4
- 4
src/TensorFlowNET.Core/APIs/tf.reshape.cs View File

@@ -18,12 +18,12 @@ namespace Tensorflow
{ {
public partial class tensorflow public partial class tensorflow
{ {
public Tensor reshape<T1, T2>(T1 tensor,
T2 shape,
string name = null) => gen_array_ops.reshape(tensor, shape, name);
public Tensor reshape<T>(T tensor,
TensorShape shape,
string name = null) => gen_array_ops.reshape(tensor, shape, name);


public Tensor reshape(Tensor tensor, public Tensor reshape(Tensor tensor,
int[] shape,
Tensor shape,
string name = null) => gen_array_ops.reshape(tensor, shape, name); string name = null) => gen_array_ops.reshape(tensor, shape, name);
} }
} }

+ 0
- 22
test/TensorFlowNET.UnitTest/Binding/EagerTensorV2Test.cs View File

@@ -1,22 +0,0 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;

namespace TensorFlowNET.UnitTest.Binding
{
[TestClass]
public class EagerTensorV2Test
{
[TestMethod]
public void Creation()
{
var tensor = new EagerTensorV2(new float[,]
{
{ 3.0f, 1.0f },
{ 1.0f, 2.0f }
});
}
}
}

+ 11
- 0
test/TensorFlowNET.UnitTest/ConstantTest.cs View File

@@ -177,5 +177,16 @@ namespace TensorFlowNET.UnitTest.Basics
Assert.AreEqual(str.Length, Marshal.ReadByte(dst)); Assert.AreEqual(str.Length, Marshal.ReadByte(dst));
//c_api.TF_StringDecode(dst, (ulong)str.Length, IntPtr.Zero, ref dst_len, status); //c_api.TF_StringDecode(dst, (ulong)str.Length, IntPtr.Zero, ref dst_len, status);
} }

[TestMethod]
public void Reshape()
{
var ones = tf.ones((3, 2), tf.float32, "ones");
var reshaped = tf.reshape(ones, (2, 3));
Assert.AreEqual(reshaped.dtype, tf.float32);
Assert.AreEqual(reshaped.shape[0], 2);
Assert.AreEqual(reshaped.shape[1], 3);
Assert.IsTrue(new float[] { 1, 1, 1, 1, 1, 1 }.SequenceEqual(ones.numpy().ToArray<float>()));
}
} }
} }

Loading…
Cancel
Save