
Merge pull request #375 from SciSharp/autocast-tensor

Added transparent dtype conversion to feed_dict
Tag: v0.12
Author: Haiping (GitHub), 6 years ago
Commit: 2c1edcbbce
5 changed files with 1593 additions and 36 deletions:

1. src/TensorFlowNET.Core/Sessions/BaseSession.cs (+43, -36)
2. src/TensorFlowNET.Core/Tensors/TensorConverter.cs (+285, -0)
3. src/TensorFlowNET.Core/Tensors/dtypes.cs (+2, -0)
4. test/TensorFlowNET.UnitTest/SessionTest.cs (+57, -0)
5. test/TensorFlowNET.UnitTest/Utilities/FluentExtension.cs (+1206, -0)
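
In effect, Session.run now casts NDArray, scalar and array feed values to the dtype of the placeholder they feed. A minimal sketch of the new behaviour, modelled on the Autocast tests added in test/TensorFlowNET.UnitTest/SessionTest.cs below (not part of the diff):

// np.array(1, 2, 3, 4, 5, 6) produces an int32 NDArray, but the placeholder is
// float64; the feed is now converted to the placeholder's dtype before TF_SessionRun.
var sess = tf.Session().as_default();
var input = tf.placeholder(tf.float64, shape: new TensorShape(6));
var op = tf.reshape(input, new int[] {2, 3});
var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6)));
print(ret.dtype); // float64, shaped (2, 3)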

src/TensorFlowNET.Core/Sessions/BaseSession.cs (+43, -36)

@@ -107,7 +107,7 @@ namespace Tensorflow
foreach (var subfeed in feed_dict)
{
var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false);
- //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used
+ //var target_dtype = subfeed_t.dtype.as_numpy_typecode(); // subfeed_dtype was never used
feed_dict_tensor[subfeed_t] = subfeed.Value;
//feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value);
}
@@ -150,58 +150,64 @@ namespace Tensorflow
int i = 0;
foreach (var x in feed_dict)
{
- if (x.Key is Tensor tensor)
+ if (x.Key is Tensor key)
{
switch (x.Value)
{
case Tensor v:
- feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v);
+ if (v.dtype != key.dtype)
+ throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {v.dtype}");
+ feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), v);
break;
case NDArray v:
- feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype));
+ feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype));
break;
case IntPtr v:
- feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
+ var tensor = new Tensor(v);
+ if (tensor.dtype != key.dtype)
+ throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {tensor.dtype}");
+
+ feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), tensor);
break;
#if _REGEN
// @formatter:off — disable formatter after this line
%types = ["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"]
%foreach types%
case #1 v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case #1[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
%
%types = ["bool", "sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"]
%foreach types%
case #1 v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case #1[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
%
// @formatter:on — enable formatter after this line
#else
// @formatter:off — disable formatter after this line
- case sbyte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
- case sbyte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
- case byte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
- case byte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
- case short v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
- case short[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
- case ushort v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
- case ushort[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
- case int v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
- case int[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
- case uint v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
- case uint[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
- case long v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
- case long[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
- case ulong v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
- case ulong[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
- case float v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
- case float[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
- case double v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
- case double[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
- case Complex v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
- case Complex[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
+ case bool v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
+ case bool[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
+ case sbyte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
+ case sbyte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
+ case byte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
+ case byte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
+ case short v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
+ case short[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
+ case ushort v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
+ case ushort[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
+ case int v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
+ case int[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
+ case uint v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
+ case uint[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
+ case long v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
+ case long[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
+ case ulong v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
+ case ulong[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
+ case float v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
+ case float[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
+ case double v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
+ case double[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
+ case Complex v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
+ case Complex[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
// @formatter:on — enable formatter after this line
#endif
- case bool v:
- feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte) (v ? 1 : 0), TF_DataType.TF_BOOL));
- break;
+
case string v:
- feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
+ feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype));
break;
default:
throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? "<null>"}");
@@ -214,6 +220,7 @@ namespace Tensorflow
return _call_tf_sessionrun(feeds, fetches, target_list);
}


private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list)
{
// Ensure any changes to the graph are reflected in the runtime.
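
Note the asymmetry introduced above: NDArray, scalar and array feeds are converted to the placeholder's dtype via TensorConverter, while pre-built Tensor and IntPtr feeds are not cast and now raise ValueError on a dtype mismatch. A rough sketch of the rejected case (not part of the diff; the feed tuple syntax mirrors the tests further down):

var input = tf.placeholder(tf.float64, shape: new TensorShape(2));
// An already-constructed float32 Tensor fed into a float64 placeholder is not
// auto-cast; BaseSession now throws ValueError for it:
// sess.run(op, feed_dict: (input, new Tensor(new float[] { 1f, 2f }))); // -> ValueError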


src/TensorFlowNET.Core/Tensors/TensorConverter.cs (+285, -0)

@@ -0,0 +1,285 @@
using System;
using System.Threading.Tasks;
using NumSharp;
using NumSharp.Backends;
using NumSharp.Utilities;

namespace Tensorflow
{
/// <summary>
/// Provides various methods for conversion between types and <see cref="Tensor"/>.
/// </summary>
public static class TensorConverter
{
/// <summary>
/// Convert given <see cref="NDArray"/> to <see cref="Tensor"/>.
/// </summary>
/// <param name="nd">The <see cref="NDArray"/> to convert.</param>
/// <param name="astype">Convert <paramref name="nd"/> to the given <paramref name="astype"/> before inserting it into a <see cref="Tensor"/>.</param>
/// <exception cref="NotSupportedException"></exception>
public static Tensor ToTensor(NDArray nd, TF_DataType? astype = null)
{
return new Tensor(astype == null ? nd : nd.astype(astype.Value.as_numpy_typecode(), false));
}
/// <summary>
/// Convert given <see cref="NDArray"/> to <see cref="Tensor"/>.
/// </summary>
/// <param name="nd">The ndarray to convert.</param>
/// <param name="astype">Convert <see cref="Array"/> to given <paramref name="astype"/> before inserting it into a <see cref="Tensor"/>.</param>
/// <exception cref="NotSupportedException"></exception>
public static Tensor ToTensor(NDArray nd, NPTypeCode? astype = null)
{
return new Tensor(astype == null ? nd : nd.astype(astype.Value, false));
}
/// <summary>
/// Convert given <see cref="Array"/> to <see cref="Tensor"/>.
/// </summary>
/// <param name="array">The array to convert, can be regular, jagged or multi-dim array.</param>
/// <param name="astype">Convert <see cref="Array"/> to given <paramref name="astype"/> before inserting it into a <see cref="Tensor"/>.</param>
/// <exception cref="NotSupportedException"></exception>
public static Tensor ToTensor(Array array, TF_DataType? astype = null)
{
if (array == null) throw new ArgumentNullException(nameof(array));
var arrtype = array.ResolveElementType();

var astype_type = astype?.as_numpy_dtype() ?? arrtype;
if (astype_type == arrtype)
{
//no conversion required
if (astype == TF_DataType.TF_STRING)
{
throw new NotSupportedException(); //TODO! when string is fully implemented.
}

if (astype == TF_DataType.TF_INT8)
{
if (array.Rank != 1 || array.GetType().GetElementType()?.IsArray == true) //is multidim or jagged
array = Arrays.Flatten(array);

return new Tensor((sbyte[]) array);
}

//is multidim or jagged, if so - use NDArray's constructor as it records the shape.
if (array.Rank != 1 || array.GetType().GetElementType().IsArray)
return new Tensor(new NDArray(array));

#if _REGEN
#region Compute
switch (arrtype)
{
%foreach supported_dtypes,supported_dtypes_lowercase%
case NPTypeCode.#1: return new Tensor((#2[])array);
%
default:
throw new NotSupportedException();
}
#endregion
#else

#region Compute

switch (arrtype.GetTypeCode())
{
case NPTypeCode.Boolean: return new Tensor((bool[]) array);
case NPTypeCode.Byte: return new Tensor((byte[]) array);
case NPTypeCode.Int16: return new Tensor((short[]) array);
case NPTypeCode.UInt16: return new Tensor((ushort[]) array);
case NPTypeCode.Int32: return new Tensor((int[]) array);
case NPTypeCode.UInt32: return new Tensor((uint[]) array);
case NPTypeCode.Int64: return new Tensor((long[]) array);
case NPTypeCode.UInt64: return new Tensor((ulong[]) array);
case NPTypeCode.Char: return new Tensor((char[]) array);
case NPTypeCode.Double: return new Tensor((double[]) array);
case NPTypeCode.Single: return new Tensor((float[]) array);
default:
throw new NotSupportedException();
}

#endregion

#endif
} else
{
//conversion is required.
//by this point astype is not null.

//flatten if required
if (array.Rank != 1 || array.GetType().GetElementType()?.IsArray == true) //is multidim or jagged
array = Arrays.Flatten(array);

try
{
return ToTensor(
ArrayConvert.To(array, astype.Value.as_numpy_typecode()),
null
);
} catch (NotSupportedException)
{
//handle dtypes not supported by ArrayConvert
var ret = Array.CreateInstance(astype_type, array.LongLength);
Parallel.For(0, ret.LongLength, i => ret.SetValue(Convert.ChangeType(array.GetValue(i), astype_type), i));
return ToTensor(ret, null);
}
}
}

/// <summary>
/// Convert given scalar <paramref name="constant"/> to <see cref="Tensor"/>.
/// </summary>
/// <param name="constant">The constant scalar to convert.</param>
/// <param name="astype">Convert <paramref name="constant"/> to given <paramref name="astype"/> before inserting it into a <see cref="Tensor"/>.</param>
/// <exception cref="NotSupportedException"></exception>
public static Tensor ToTensor<T>(T constant, TF_DataType? astype = null) where T : unmanaged
{
//was conversion requested?
if (astype == null)
{
//No conversion required
var constantType = typeof(T).as_dtype();
if (constantType == TF_DataType.TF_INT8)
return new Tensor((sbyte) (object) constant);

if (constantType == TF_DataType.TF_STRING)
return new Tensor((string) (object) constant);

#if _REGEN
#region Compute
switch (InfoOf<T>.NPTypeCode)
{
%foreach supported_dtypes,supported_dtypes_lowercase%
case NPTypeCode.#1: return new Tensor((#2)(object)constant);
%
default:
throw new NotSupportedException();
}
#endregion
#else

#region Compute

switch (InfoOf<T>.NPTypeCode)
{
case NPTypeCode.Boolean: return new Tensor((bool) (object) constant);
case NPTypeCode.Byte: return new Tensor((byte) (object) constant);
case NPTypeCode.Int16: return new Tensor((short) (object) constant);
case NPTypeCode.UInt16: return new Tensor((ushort) (object) constant);
case NPTypeCode.Int32: return new Tensor((int) (object) constant);
case NPTypeCode.UInt32: return new Tensor((uint) (object) constant);
case NPTypeCode.Int64: return new Tensor((long) (object) constant);
case NPTypeCode.UInt64: return new Tensor((ulong) (object) constant);
case NPTypeCode.Char: return new Tensor(Converts.ToByte(constant));
case NPTypeCode.Double: return new Tensor((double) (object) constant);
case NPTypeCode.Single: return new Tensor((float) (object) constant);
default:
throw new NotSupportedException();
}

#endregion
#endif
}

//conversion required

if (astype == TF_DataType.TF_INT8)
return new Tensor(Converts.ToSByte(constant));

if (astype == TF_DataType.TF_STRING)
return new Tensor(Converts.ToString(constant));

var astype_np = astype?.as_numpy_typecode();

#if _REGEN
#region Compute
switch (astype_np)
{
%foreach supported_dtypes,supported_dtypes_lowercase%
case NPTypeCode.#1: return new Tensor(Converts.To#1(constant));
%
default:
throw new NotSupportedException();
}
#endregion
#else

#region Compute
switch (astype_np)
{
case NPTypeCode.Boolean: return new Tensor(Converts.ToBoolean(constant));
case NPTypeCode.Byte: return new Tensor(Converts.ToByte(constant));
case NPTypeCode.Int16: return new Tensor(Converts.ToInt16(constant));
case NPTypeCode.UInt16: return new Tensor(Converts.ToUInt16(constant));
case NPTypeCode.Int32: return new Tensor(Converts.ToInt32(constant));
case NPTypeCode.UInt32: return new Tensor(Converts.ToUInt32(constant));
case NPTypeCode.Int64: return new Tensor(Converts.ToInt64(constant));
case NPTypeCode.UInt64: return new Tensor(Converts.ToUInt64(constant));
case NPTypeCode.Char: return new Tensor(Converts.ToByte(constant));
case NPTypeCode.Double: return new Tensor(Converts.ToDouble(constant));
case NPTypeCode.Single: return new Tensor(Converts.ToSingle(constant));
default:
throw new NotSupportedException();
}
#endregion
#endif
}

/// <summary>
/// Convert given <see cref="string"/> <paramref name="constant"/> to <see cref="Tensor"/>.
/// </summary>
/// <param name="constant">The constant string to convert.</param>
/// <param name="astype">Convert <paramref name="constant"/> to given <paramref name="astype"/> before inserting it into a <see cref="Tensor"/>.</param>
/// <exception cref="NotSupportedException"></exception>
public static Tensor ToTensor(string constant, TF_DataType? astype = null)
{
switch (astype)
{
//was conversion requested?
case null:
case TF_DataType.TF_STRING:
return new Tensor(constant);
//conversion required
case TF_DataType.TF_INT8:
return new Tensor(Converts.ToSByte(constant));
default:
{
var astype_np = astype?.as_numpy_typecode();

#if _REGEN
#region Compute
switch (astype_np)
{
%foreach supported_dtypes,supported_dtypes_lowercase%
case NPTypeCode.#1: return new Tensor(Converts.To#1(constant));
%
default:
throw new NotSupportedException();
}
#endregion
#else

#region Compute
switch (astype_np)
{
case NPTypeCode.Boolean: return new Tensor(Converts.ToBoolean(constant));
case NPTypeCode.Byte: return new Tensor(Converts.ToByte(constant));
case NPTypeCode.Int16: return new Tensor(Converts.ToInt16(constant));
case NPTypeCode.UInt16: return new Tensor(Converts.ToUInt16(constant));
case NPTypeCode.Int32: return new Tensor(Converts.ToInt32(constant));
case NPTypeCode.UInt32: return new Tensor(Converts.ToUInt32(constant));
case NPTypeCode.Int64: return new Tensor(Converts.ToInt64(constant));
case NPTypeCode.UInt64: return new Tensor(Converts.ToUInt64(constant));
case NPTypeCode.Char: return new Tensor(Converts.ToByte(constant));
case NPTypeCode.Double: return new Tensor(Converts.ToDouble(constant));
case NPTypeCode.Single: return new Tensor(Converts.ToSingle(constant));
default:
throw new NotSupportedException();
}
#endregion
#endif
}
}
}

}
}
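
Illustrative calls to the converter above (a sketch, not part of the diff; the values and target dtypes are arbitrary):

// Each overload optionally casts its input to the requested TF_DataType first.
var t1 = TensorConverter.ToTensor(np.array(1, 2, 3), TF_DataType.TF_DOUBLE); // NDArray -> float64 tensor
var t2 = TensorConverter.ToTensor(new[] { 1, 2, 3 }, TF_DataType.TF_FLOAT);  // int[]   -> float32 tensor
var t3 = TensorConverter.ToTensor(42, TF_DataType.TF_INT64);                 // scalar  -> int64 tensor
var t4 = TensorConverter.ToTensor("7", TF_DataType.TF_INT32);                // string  -> int32 tensor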

src/TensorFlowNET.Core/Tensors/dtypes.cs (+2, -0)

@@ -45,6 +45,8 @@ namespace Tensorflow
return typeof(bool);
case TF_DataType.TF_UINT8:
return typeof(byte);
+ case TF_DataType.TF_INT8:
+ return typeof(sbyte);
case TF_DataType.TF_INT64:
return typeof(long);
case TF_DataType.TF_UINT64:
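
With the two added lines, the dtype-to-CLR-type mapping now covers TF_INT8 (sketch, not part of the diff; this assumes the switch shown is the body of as_numpy_dtype, which is how TensorConverter calls it above):

// TF_INT8 now resolves to the managed sbyte type instead of falling through.
Type clr = TF_DataType.TF_INT8.as_numpy_dtype(); // typeof(sbyte)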


test/TensorFlowNET.UnitTest/SessionTest.cs (+57, -0)

@@ -7,6 +7,7 @@ using System.Runtime.CompilerServices;
using System.Text;
using FluentAssertions;
using Google.Protobuf;
+ using NumSharp.Backends;
using Tensorflow;
using Tensorflow.Util;
using static Tensorflow.Binding;
@@ -131,5 +132,61 @@ namespace TensorFlowNET.UnitTest
}
}
}

[TestMethod]
public void Autocast_Case1()
{
var sess = tf.Session().as_default();
var input = tf.placeholder(tf.float64, shape: new TensorShape(6));
var op = tf.reshape(input, new int[] {2, 3});
sess.run(tf.global_variables_initializer());
var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6)));

ret.Should().BeOfType<double>().And.BeShaped(2, 3).And.BeOfValues(1, 2, 3, 4, 5, 6);
print(ret.dtype);
print(ret);
}

[TestMethod]
public void Autocast_Case2()
{
var sess = tf.Session().as_default();
var input = tf.placeholder(tf.float64, shape: new TensorShape(6));
var op = tf.reshape(input, new int[] {2, 3});
sess.run(tf.global_variables_initializer());
var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(NPTypeCode.Single) + 0.1f));

ret.Should().BeOfType<double>().And.BeShaped(2, 3).And.BeOfValuesApproximately(0.001d, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1);
print(ret.dtype);
print(ret);
}

[TestMethod]
public void Autocast_Case3()
{
var sess = tf.Session().as_default();
var input = tf.placeholder(tf.int16, shape: new TensorShape(6));
var op = tf.reshape(input, new int[] {2, 3});
sess.run(tf.global_variables_initializer());
var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(NPTypeCode.Single) + 0.1f));

ret.Should().BeOfType<short>().And.BeShaped(2, 3).And.BeOfValues(1, 2, 3, 4, 5, 6);
print(ret.dtype);
print(ret);
}

[TestMethod]
public void Autocast_Case4()
{
var sess = tf.Session().as_default();
var input = tf.placeholder(tf.@byte, shape: new TensorShape(6));
var op = tf.reshape(input, new int[] {2, 3});
sess.run(tf.global_variables_initializer());
var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(NPTypeCode.Single) + 0.1f));

ret.Should().BeOfType<byte>().And.BeShaped(2, 3).And.BeOfValues(1, 2, 3, 4, 5, 6);
print(ret.dtype);
print(ret);
}
}
}

test/TensorFlowNET.UnitTest/Utilities/FluentExtension.cs (+1206, -0)
File diff suppressed because it is too large.

