Browse Source

BaseSession: revamped fetchValue (perf-op)

tags/v0.12
Eli Belash 6 years ago
parent
commit
e56164de00
2 changed files with 252 additions and 149 deletions
  1. +157
    -98
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  2. +95
    -51
      test/TensorFlowNET.UnitTest/SessionTest.cs

+ 157
- 98
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -21,6 +21,9 @@ using System.Collections.Generic;
using System.Linq;
using System.Numerics;
using System.Text;
using Google.Protobuf;
using NumSharp.Backends;
using Tensorflow.Util;

namespace Tensorflow
{
@@ -246,111 +249,167 @@ namespace Tensorflow
return result;
}

private unsafe NDArray fetchValue(IntPtr output)
private static unsafe NDArray fetchValue(IntPtr output)
{
var tensor = new Tensor(output);
NDArray nd = null;
Type type = tensor.dtype.as_numpy_dtype();
var ndims = tensor.shape;
var offset = (byte*) c_api.TF_TensorData(output);

if(ndims.Length == 0)
NDArray ret;
using (var tensor = new Tensor(output))
{
switch (tensor.dtype)
var ndims = tensor.shape;
var srcAddress = c_api.TF_TensorData(output).ToInt64();

if (ndims.Length == 0)
{
case TF_DataType.TF_BOOL:
nd = NDArray.Scalar(*(bool*)offset);
break;
case TF_DataType.TF_STRING:
var bytes = tensor.BufferToArray();
// wired, don't know why we have to start from offset 9.
// length in the begin
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
nd = NDArray.FromString(str);
break;
case TF_DataType.TF_UINT8:
nd = NDArray.Scalar(*(byte*)offset);
break;
case TF_DataType.TF_INT16:
nd = NDArray.Scalar(*(short*)offset);
break;
case TF_DataType.TF_INT32:
nd = NDArray.Scalar(*(int*)offset);
break;
case TF_DataType.TF_INT64:
nd = NDArray.Scalar(*(long*)offset);
break;
case TF_DataType.TF_FLOAT:
nd = NDArray.Scalar(*(float*)offset);
break;
case TF_DataType.TF_DOUBLE:
nd = NDArray.Scalar(*(double*)offset);
break;
default:
throw new NotImplementedException("can't fetch output");
}
}
else
{
switch (tensor.dtype)
switch (tensor.dtype)
{
case TF_DataType.TF_BOOL:
ret = NDArray.Scalar(*(bool*) srcAddress);
break;
case TF_DataType.TF_STRING:
using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize)))
ret = NDArray.FromString(reader.ReadString());
break;
case TF_DataType.TF_UINT8:
ret = NDArray.Scalar(*(byte*) srcAddress);
break;
case TF_DataType.TF_INT16:
ret = NDArray.Scalar(*(short*) srcAddress);
break;
case TF_DataType.TF_INT32:
ret = NDArray.Scalar(*(int*) srcAddress);
break;
case TF_DataType.TF_INT64:
ret = NDArray.Scalar(*(long*) srcAddress);
break;
case TF_DataType.TF_UINT16:
ret = NDArray.Scalar(*(ushort*) srcAddress);
break;
case TF_DataType.TF_UINT32:
ret = NDArray.Scalar(*(uint*) srcAddress);
break;
case TF_DataType.TF_UINT64:
ret = NDArray.Scalar(*(ulong*) srcAddress);
break;
case TF_DataType.TF_FLOAT:
ret = NDArray.Scalar(*(float*) srcAddress);
break;
case TF_DataType.TF_DOUBLE:
ret = NDArray.Scalar(*(double*) srcAddress);
break;
default:
throw new NotImplementedException("can't fetch output");
}
} else
{
case TF_DataType.TF_BOOL:
var bools = new bool[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
bools[i] = *(bool*)(offset + (int)(tensor.itemsize * i));
nd = np.array(bools).reshape(ndims);
break;
case TF_DataType.TF_STRING:
var bytes = tensor.BufferToArray();
// wired, don't know why we have to start from offset 9.
// length in the begin
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
nd = np.array(str);
break;
case TF_DataType.TF_UINT8:
var _bytes = new byte[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
_bytes[i] = *(byte*)(offset + (int)(tensor.itemsize * i));
nd = np.array(_bytes).reshape(ndims);
break;
case TF_DataType.TF_INT16:
var shorts = new short[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
shorts[i] = *(short*)(offset + (int)(tensor.itemsize * i));
nd = np.array(shorts).reshape(ndims);
break;
case TF_DataType.TF_INT32:
var ints = new int[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
ints[i] = *(int*)(offset + (int)(tensor.itemsize * i));
nd = np.array(ints).reshape(ndims);
break;
case TF_DataType.TF_INT64:
var longs = new long[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
longs[i] = *(long*)(offset + (int)(tensor.itemsize * i));
nd = np.array(longs).reshape(ndims);
break;
case TF_DataType.TF_FLOAT:
var floats = new float[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
floats[i] = *(float*)(offset + (int)(tensor.itemsize * i));
nd = np.array(floats).reshape(ndims);
break;
case TF_DataType.TF_DOUBLE:
var doubles = new double[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
doubles[i] = *(double*)(offset + (int)(tensor.itemsize * i));
nd = np.array(doubles).reshape(ndims);
break;
default:
throw new NotImplementedException("can't fetch output");
//var size = (long) tensor.size;
//var itemsize = (long) tensor.itemsize;
var bytesize = (long) tensor.bytesize;
var src = (void*) srcAddress;

#if _REGEN
#region Compute
switch (tensor.dtype)
{
%foreach except(supported_dtypes, "Char"),except(supported_dtypes_lowercase, "char"),except(supported_dtypes_TF_DataType,"TF_STRING")%
case TF_DataType.#3:
{
ret = new NDArray(NPTypeCode.#1, ndims, false);
System.Buffer.MemoryCopy(src, #(#3=="TF_STRING"|"(byte*)ret.Unsafe.Address + 8"|"ret.Unsafe.Address"), bytesize, bytesize);
break;
}
%
case TF_DataType.TF_STRING:
{
//TODO:! This is not the way to handle string[], it should be done with TF_DecodeString
using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize)))
ret = NDArray.FromString(reader.ReadString());
break;
}
default:
throw new NotSupportedException();
}
#endregion
#else

#region Compute
switch (tensor.dtype)
{
case TF_DataType.TF_BOOL:
{
ret = new NDArray(NPTypeCode.Boolean, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}
case TF_DataType.TF_UINT8:
{
ret = new NDArray(NPTypeCode.Byte, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}
case TF_DataType.TF_INT16:
{
ret = new NDArray(NPTypeCode.Int16, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}
case TF_DataType.TF_UINT16:
{
ret = new NDArray(NPTypeCode.UInt16, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}
case TF_DataType.TF_INT32:
{
ret = new NDArray(NPTypeCode.Int32, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}
case TF_DataType.TF_UINT32:
{
ret = new NDArray(NPTypeCode.UInt32, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}
case TF_DataType.TF_INT64:
{
ret = new NDArray(NPTypeCode.Int64, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}
case TF_DataType.TF_UINT64:
{
ret = new NDArray(NPTypeCode.UInt64, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}
case TF_DataType.TF_DOUBLE:
{
ret = new NDArray(NPTypeCode.Double, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}
case TF_DataType.TF_FLOAT:
{
ret = new NDArray(NPTypeCode.Single, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}
case TF_DataType.TF_STRING:
{
throw new NotImplementedException();
//TODO:! This is not the way to handle string[], it should be done with TF_DecodeString
using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize)))
ret = NDArray.FromString(reader.ReadString());
break;
}
default:
throw new NotSupportedException();
}
#endregion
#endif
}
}
tensor.Dispose();

return nd;
return ret;
}

/// <summary>


+ 95
- 51
test/TensorFlowNET.UnitTest/SessionTest.cs View File

@@ -2,6 +2,11 @@
using NumSharp;
using System;
using System.Collections.Generic;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text;
using FluentAssertions;
using Google.Protobuf;
using Tensorflow;
using static Tensorflow.Binding;

@@ -17,74 +22,113 @@ namespace TensorFlowNET.UnitTest
[TestMethod]
public void Session()
{
var s = new Status();
var graph = new Graph();
lock (this)
{
var s = new Status();
var graph = new Graph();

// Make a placeholder operation.
var feed = c_test_util.Placeholder(graph, s);
// Make a placeholder operation.
var feed = c_test_util.Placeholder(graph, s);

// Make a constant operation with the scalar "2".
var two = c_test_util.ScalarConst(2, graph, s);
// Make a constant operation with the scalar "2".
var two = c_test_util.ScalarConst(2, graph, s);

// Add operation.
var add = c_test_util.Add(feed, two, graph, s);
// Add operation.
var add = c_test_util.Add(feed, two, graph, s);

var csession = new CSession(graph, s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);
var csession = new CSession(graph, s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);

// Run the graph.
var inputs = new Dictionary<Operation, Tensor>();
inputs.Add(feed, new Tensor(3));
csession.SetInputs(inputs);
// Run the graph.
var inputs = new Dictionary<Operation, Tensor>();
inputs.Add(feed, new Tensor(3));
csession.SetInputs(inputs);

var outputs = new TF_Output[] { new TF_Output(add, 0) };
csession.SetOutputs(outputs);
var outputs = new TF_Output[] {new TF_Output(add, 0)};
csession.SetOutputs(outputs);

csession.Run(s);
Tensor outTensor = csession.output_tensor(0);
EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
EXPECT_EQ(0, outTensor.NDims);
ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize);
var output_contents = outTensor.ToArray<int>();
EXPECT_EQ(3 + 2, output_contents[0]);
csession.Run(s);
Tensor outTensor = csession.output_tensor(0);
EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
EXPECT_EQ(0, outTensor.NDims);
ASSERT_EQ((ulong) sizeof(uint), outTensor.bytesize);
var output_contents = outTensor.ToArray<int>();
EXPECT_EQ(3 + 2, output_contents[0]);

// Add another operation to the graph.
var neg = c_test_util.Neg(add, graph, s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);
// Add another operation to the graph.
var neg = c_test_util.Neg(add, graph, s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);

// Run up to the new operation.
inputs = new Dictionary<Operation, Tensor>();
inputs.Add(feed, new Tensor(7));
csession.SetInputs(inputs);
outputs = new TF_Output[] { new TF_Output(neg, 0) };
csession.SetOutputs(outputs);
csession.Run(s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);
// Run up to the new operation.
inputs = new Dictionary<Operation, Tensor>();
inputs.Add(feed, new Tensor(7));
csession.SetInputs(inputs);
outputs = new TF_Output[] {new TF_Output(neg, 0)};
csession.SetOutputs(outputs);
csession.Run(s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);

outTensor = csession.output_tensor(0);
ASSERT_TRUE(outTensor != IntPtr.Zero);
EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
EXPECT_EQ(0, outTensor.NDims); // scalar
ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize);
output_contents = outTensor.ToArray<int>();
EXPECT_EQ(-(7 + 2), output_contents[0]);
outTensor = csession.output_tensor(0);
ASSERT_TRUE(outTensor != IntPtr.Zero);
EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
EXPECT_EQ(0, outTensor.NDims); // scalar
ASSERT_EQ((ulong) sizeof(uint), outTensor.bytesize);
output_contents = outTensor.ToArray<int>();
EXPECT_EQ(-(7 + 2), output_contents[0]);

// Clean up
csession.CloseAndDelete(s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);
// Clean up
csession.CloseAndDelete(s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);
}
}

[TestMethod]
public void EvalTensor()
{
var a = constant_op.constant(np.array(3.0).reshape(1, 1));
var b = constant_op.constant(np.array(2.0).reshape(1, 1));
var c = math_ops.matmul(a, b, name: "matmul");
using (var sess = tf.Session())
lock (this)
{
var a = constant_op.constant(np.array(3.0).reshape(1, 1));
var b = constant_op.constant(np.array(2.0).reshape(1, 1));
var c = math_ops.matmul(a, b, name: "matmul");
using (var sess = tf.Session())
{
var result = c.eval(sess);
Assert.AreEqual(6, result.Data<double>()[0]);
}
}
}

[TestMethod]
public void Eval_SmallString_Scalar()
{
lock (this)
{
var a = constant_op.constant("123 heythere 123 ", TF_DataType.TF_STRING);
var c = tf.strings.substr(a, 4, 8);
using (var sess = tf.Session())
{
var result = (string) c.eval(sess);
Console.WriteLine(result);
result.Should().Be("heythere");
}
}
}

[TestMethod]
public void Eval_LargeString_Scalar()
{
lock (this)
{
var result = c.eval(sess);
Assert.AreEqual(6, result.Data<double>()[0]);
const int size = 30_000;
var a = constant_op.constant(new string('a', size), TF_DataType.TF_STRING);
var c = tf.strings.substr(a, 0, size - 5000);
using (var sess = tf.Session())
{
var result = (string) c.eval(sess);
Console.WriteLine((string) result);
result.Should().HaveLength(size - 5000).And.ContainAll("a");
}
}
}
}
}
}

Loading…
Cancel
Save