- Fixed all test cases to use using(Buffer) - Fixed all test cases to explicitly specify sessiontags/v0.12
| @@ -15,58 +15,116 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System; | using System; | ||||
| using System.Runtime.CompilerServices; | |||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using NumSharp.Backends.Unmanaged; | |||||
| using static Tensorflow.c_api; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| /// <summary> | |||||
| /// Represents a TF_Buffer that can be passed to Tensorflow. | |||||
| /// </summary> | |||||
| public class Buffer : DisposableObject | public class Buffer : DisposableObject | ||||
| { | { | ||||
| private TF_Buffer buffer => Marshal.PtrToStructure<TF_Buffer>(_handle); | |||||
| private unsafe TF_Buffer buffer | |||||
| { | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| get => *bufferptr; | |||||
| } | |||||
| private unsafe TF_Buffer* bufferptr | |||||
| { | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| get => (TF_Buffer*) _handle; | |||||
| } | |||||
| public byte[] Data | |||||
| /// <summary> | |||||
| /// The memory block representing this buffer. | |||||
| /// </summary> | |||||
| /// <remarks>The deallocator is set to null.</remarks> | |||||
| public UnmanagedMemoryBlock<byte> MemoryBlock | |||||
| { | { | ||||
| get | |||||
| get | |||||
| { | { | ||||
| var data = new byte[buffer.length]; | |||||
| if (data.Length > 0) | |||||
| Marshal.Copy(buffer.data, data, 0, data.Length); | |||||
| return data; | |||||
| unsafe | |||||
| { | |||||
| EnsureNotDisposed(); | |||||
| var buff = (TF_Buffer*) _handle; | |||||
| return new UnmanagedMemoryBlock<byte>((byte*) buff->data.ToPointer(), (long) buff->length); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| public int Length => (int)buffer.length; | |||||
| public Buffer() | |||||
| /// <summary> | |||||
| /// The bytes length of this buffer. | |||||
| /// </summary> | |||||
| public ulong Length | |||||
| { | { | ||||
| _handle = c_api.TF_NewBuffer(); | |||||
| get | |||||
| { | |||||
| EnsureNotDisposed(); | |||||
| return buffer.length; | |||||
| } | |||||
| } | } | ||||
| public Buffer(IntPtr handle) | |||||
| public Buffer() => _handle = TF_NewBuffer(); | |||||
| internal Buffer(IntPtr handle) | |||||
| { | { | ||||
| if (handle == IntPtr.Zero) | |||||
| throw new ArgumentException("Handle (IntPtr) can't be zero.", nameof(handle)); | |||||
| _handle = handle; | _handle = handle; | ||||
| } | } | ||||
| public Buffer(byte[] data) | |||||
| { | |||||
| var dst = Marshal.AllocHGlobal(data.Length); | |||||
| Marshal.Copy(data, 0, dst, data.Length); | |||||
| public Buffer(byte[] data) : this(_toBuffer(data)) | |||||
| { } | |||||
| _handle = c_api.TF_NewBufferFromString(dst, (ulong)data.Length); | |||||
| private static IntPtr _toBuffer(byte[] data) | |||||
| { | |||||
| if (data == null) | |||||
| throw new ArgumentNullException(nameof(data)); | |||||
| Marshal.FreeHGlobal(dst); | |||||
| unsafe | |||||
| { | |||||
| fixed (byte* src = data) | |||||
| return TF_NewBufferFromString(new IntPtr(src), (ulong) data.LongLength); | |||||
| } | |||||
| } | } | ||||
| public static implicit operator IntPtr(Buffer buffer) | public static implicit operator IntPtr(Buffer buffer) | ||||
| { | { | ||||
| buffer.EnsureNotDisposed(); | |||||
| return buffer._handle; | return buffer._handle; | ||||
| } | } | ||||
| public static implicit operator byte[](Buffer buffer) | |||||
| public static explicit operator byte[](Buffer buffer) => buffer.ToArray(); //has to be explicit, developer will assume it doesn't cost. | |||||
| /// <summary> | |||||
| /// Copies this buffer's contents onto a <see cref="byte"/> array. | |||||
| /// </summary> | |||||
| public byte[] ToArray() | |||||
| { | { | ||||
| return buffer.Data; | |||||
| EnsureNotDisposed(); | |||||
| unsafe | |||||
| { | |||||
| var len = buffer.length; | |||||
| if (len == 0) | |||||
| return Array.Empty<byte>(); | |||||
| byte[] data = new byte[len]; | |||||
| fixed (byte* dst = data) | |||||
| System.Buffer.MemoryCopy((void*) bufferptr->data, dst, len, len); | |||||
| return data; | |||||
| } | |||||
| } | } | ||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | protected override void DisposeUnmanagedResources(IntPtr handle) | ||||
| => c_api.TF_DeleteBuffer(handle); | |||||
| { | |||||
| TF_DeleteBuffer(handle); | |||||
| } | |||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -15,6 +15,8 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.IO; | |||||
| using Tensorflow.Util; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -27,12 +29,12 @@ namespace Tensorflow | |||||
| if(_registered_ops == null) | if(_registered_ops == null) | ||||
| { | { | ||||
| _registered_ops = new Dictionary<string, OpDef>(); | _registered_ops = new Dictionary<string, OpDef>(); | ||||
| var handle = c_api.TF_GetAllOpList(); | |||||
| var buffer = new Buffer(handle); | |||||
| var op_list = OpList.Parser.ParseFrom(buffer); | |||||
| foreach (var op_def in op_list.Op) | |||||
| _registered_ops[op_def.Name] = op_def; | |||||
| using (var buffer = new Buffer(c_api.TF_GetAllOpList())) | |||||
| { | |||||
| var op_list = OpList.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
| foreach (var op_def in op_list.Op) | |||||
| _registered_ops[op_def.Name] = op_def; | |||||
| } | |||||
| } | } | ||||
| return _registered_ops; | return _registered_ops; | ||||
| @@ -15,6 +15,7 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Diagnostics.CodeAnalysis; | |||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
| @@ -66,8 +67,9 @@ namespace Tensorflow | |||||
| /// within the context should have control dependencies on | /// within the context should have control dependencies on | ||||
| /// `control_inputs`. | /// `control_inputs`. | ||||
| /// </summary> | /// </summary> | ||||
| [SuppressMessage("ReSharper", "CoVariantArrayConversion")] | |||||
| public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) | public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) | ||||
| => control_dependencies(control_inputs == null ? null : control_inputs.OfType<object>().ToArray()); | |||||
| => control_dependencies((object[])control_inputs); | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns a context manager that specifies control dependencies. | /// Returns a context manager that specifies control dependencies. | ||||
| @@ -14,6 +14,9 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.IO; | |||||
| using Tensorflow.Util; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public partial class Graph | public partial class Graph | ||||
| @@ -23,21 +26,19 @@ namespace Tensorflow | |||||
| var buffer = new Buffer(); | var buffer = new Buffer(); | ||||
| c_api.TF_GraphToGraphDef(_handle, buffer, s); | c_api.TF_GraphToGraphDef(_handle, buffer, s); | ||||
| s.Check(true); | s.Check(true); | ||||
| // var def = GraphDef.Parser.ParseFrom(buffer); | |||||
| // buffer.Dispose(); | |||||
| return buffer; | return buffer; | ||||
| } | } | ||||
| private GraphDef _as_graph_def(bool add_shapes = false) | private GraphDef _as_graph_def(bool add_shapes = false) | ||||
| { | { | ||||
| var status = new Status(); | |||||
| var buffer = ToGraphDef(status); | |||||
| status.Check(true); | |||||
| status.Dispose(); | |||||
| var def = GraphDef.Parser.ParseFrom(buffer); | |||||
| buffer.Dispose(); | |||||
| GraphDef def; | |||||
| using (var status = new Status()) | |||||
| using (var buffer = ToGraphDef(status)) | |||||
| { | |||||
| status.Check(true); | |||||
| def = GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
| } | |||||
| // Strip the experimental library field iff it's empty. | // Strip the experimental library field iff it's empty. | ||||
| // if(def.Library.Function.Count == 0) | // if(def.Library.Function.Count == 0) | ||||
| @@ -45,7 +46,7 @@ namespace Tensorflow | |||||
| return def; | return def; | ||||
| } | } | ||||
| public GraphDef as_graph_def(bool add_shapes = false) | |||||
| public GraphDef as_graph_def(bool add_shapes = false) | |||||
| => _as_graph_def(add_shapes); | => _as_graph_def(add_shapes); | ||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -18,6 +18,7 @@ using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using Tensorflow.Util; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -30,7 +31,7 @@ namespace Tensorflow | |||||
| using (var status = new Status()) | using (var status = new Status()) | ||||
| { | { | ||||
| c_api.TF_GraphGetOpDef(_handle, type, buffer, status); | c_api.TF_GraphGetOpDef(_handle, type, buffer, status); | ||||
| return OpDef.Parser.ParseFrom(buffer.Data); | |||||
| return OpDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
| } | } | ||||
| } | } | ||||
| @@ -71,7 +72,7 @@ namespace Tensorflow | |||||
| public ITensorOrOperation[] get_operations() | public ITensorOrOperation[] get_operations() | ||||
| { | { | ||||
| return _nodes_by_name.Values.Select(x => x).ToArray(); | |||||
| return _nodes_by_name.Values.ToArray(); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -85,7 +86,7 @@ namespace Tensorflow | |||||
| public ITensorOrOperation _get_operation_by_name_unsafe(string name) | public ITensorOrOperation _get_operation_by_name_unsafe(string name) | ||||
| { | { | ||||
| return _nodes_by_name.ContainsKey(name) ? _nodes_by_name[name] : null; | |||||
| return _nodes_by_name.TryGetValue(name, out var val) ? val : null; | |||||
| } | } | ||||
| public ITensorOrOperation _get_operation_by_tf_operation(IntPtr tf_oper) | public ITensorOrOperation _get_operation_by_tf_operation(IntPtr tf_oper) | ||||
| @@ -17,7 +17,9 @@ | |||||
| using Google.Protobuf.Collections; | using Google.Protobuf.Collections; | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.IO; | |||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Util; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -226,9 +228,12 @@ namespace Tensorflow | |||||
| using (var status = new Status()) | using (var status = new Status()) | ||||
| using (var buf = new Buffer()) | using (var buf = new Buffer()) | ||||
| { | { | ||||
| c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | |||||
| status.Check(true); | |||||
| x = AttrValue.Parser.ParseFrom(buf); | |||||
| unsafe | |||||
| { | |||||
| c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | |||||
| status.Check(true); | |||||
| x = AttrValue.Parser.ParseFrom(buf.MemoryBlock.Stream()); | |||||
| } | |||||
| } | } | ||||
| string oneof_value = x.ValueCase.ToString(); | string oneof_value = x.ValueCase.ToString(); | ||||
| @@ -259,7 +264,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| c_api.TF_OperationToNodeDef(_handle, buffer, s); | c_api.TF_OperationToNodeDef(_handle, buffer, s); | ||||
| s.Check(); | s.Check(); | ||||
| return NodeDef.Parser.ParseFrom(buffer); | |||||
| return NodeDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
| } | } | ||||
| } | } | ||||
| @@ -299,8 +304,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public TF_Output _tf_output(int output_idx) | public TF_Output _tf_output(int output_idx) | ||||
| { | { | ||||
| var tf_output = new TF_Output(op, output_idx); | |||||
| return tf_output; | |||||
| return new TF_Output(op, output_idx); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -308,8 +312,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public TF_Input _tf_input(int input_idx) | public TF_Input _tf_input(int input_idx) | ||||
| { | { | ||||
| var tf_input = new TF_Input(op, input_idx); | |||||
| return tf_input; | |||||
| return new TF_Input(op, input_idx); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -37,8 +37,8 @@ namespace Tensorflow | |||||
| public void SetConfig(ConfigProto config) | public void SetConfig(ConfigProto config) | ||||
| { | { | ||||
| var bytes = config.ToByteArray(); | |||||
| var proto = Marshal.AllocHGlobal(bytes.Length); | |||||
| var bytes = config.ToByteArray(); //TODO! we can use WriteTo | |||||
| var proto = Marshal.AllocHGlobal(bytes.Length); //TODO! potential memory leak | |||||
| Marshal.Copy(bytes, 0, proto, bytes.Length); | Marshal.Copy(bytes, 0, proto, bytes.Length); | ||||
| using (var status = new Status()) | using (var status = new Status()) | ||||
| @@ -230,8 +230,8 @@ namespace Tensorflow | |||||
| // Add attrs | // Add attrs | ||||
| foreach (var attr in node_def.Attr) | foreach (var attr in node_def.Attr) | ||||
| { | { | ||||
| var bytes = attr.Value.ToByteArray(); | |||||
| var proto = Marshal.AllocHGlobal(bytes.Length); | |||||
| var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. | |||||
| var proto = Marshal.AllocHGlobal(bytes.Length); //TODO: potential memory leak | |||||
| Marshal.Copy(bytes, 0, proto, bytes.Length); | Marshal.Copy(bytes, 0, proto, bytes.Length); | ||||
| uint len = (uint)bytes.Length; | uint len = (uint)bytes.Length; | ||||
| c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); | c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); | ||||
| @@ -64,8 +64,7 @@ namespace Tensorflow | |||||
| public Session Session() | public Session Session() | ||||
| { | { | ||||
| defaultSession = new Session(); | |||||
| return defaultSession; | |||||
| return new Session(); | |||||
| } | } | ||||
| public Session Session(Graph graph) | public Session Session(Graph graph) | ||||
| @@ -102,7 +102,7 @@ namespace TensorFlowNET.Examples | |||||
| // Display logs per epoch step | // Display logs per epoch step | ||||
| if ((epoch + 1) % display_step == 0) | if ((epoch + 1) % display_step == 0) | ||||
| print($"Epoch: {(epoch + 1).ToString("D4")} Cost: {avg_cost.ToString("G9")} Elapse: {sw.ElapsedMilliseconds}ms"); | |||||
| print($"Epoch: {(epoch + 1):D4} Cost: {avg_cost:G9} Elapse: {sw.ElapsedMilliseconds}ms"); | |||||
| sw.Reset(); | sw.Reset(); | ||||
| } | } | ||||
| @@ -114,8 +114,8 @@ namespace TensorFlowNET.Examples | |||||
| var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)); | var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)); | ||||
| // Calculate accuracy | // Calculate accuracy | ||||
| var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)); | var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)); | ||||
| float acc = accuracy.eval((x, mnist.Test.Data), (y, mnist.Test.Labels)); | |||||
| print($"Accuracy: {acc.ToString("F4")}"); | |||||
| float acc = accuracy.eval(sess, (x, mnist.Test.Data), (y, mnist.Test.Labels)); | |||||
| print($"Accuracy: {acc:F4}"); | |||||
| return acc > 0.9; | return acc > 0.9; | ||||
| } | } | ||||
| @@ -1,4 +1,5 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using Tensorflow; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest.Basics | namespace TensorFlowNET.UnitTest.Basics | ||||
| @@ -14,21 +15,22 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
| var expected = new[] { false, true, false, false, true, false, true }; | var expected = new[] { false, true, false, false, true, false, true }; | ||||
| var spike = tf.Variable(false); | var spike = tf.Variable(false); | ||||
| spike.initializer.run(); | |||||
| foreach (var i in range(1, 2)) | |||||
| using (var sess = new Session()) | |||||
| { | { | ||||
| if (raw_data[i] - raw_data[i - 1] > 5d) | |||||
| { | |||||
| var updater = tf.assign(spike, tf.constant(true)); | |||||
| updater.eval(); | |||||
| } | |||||
| else | |||||
| spike.initializer.run(session: sess); | |||||
| foreach (var i in range(1, 2)) | |||||
| { | { | ||||
| tf.assign(spike, tf.constant(true)).eval(); | |||||
| } | |||||
| if (raw_data[i] - raw_data[i - 1] > 5d) | |||||
| { | |||||
| var updater = tf.assign(spike, tf.constant(true)); | |||||
| updater.eval(sess); | |||||
| } else | |||||
| { | |||||
| tf.assign(spike, tf.constant(true)).eval(sess); | |||||
| } | |||||
| Assert.AreEqual((bool)spike.eval(), expected[i - 1]); | |||||
| Assert.AreEqual((bool) spike.eval(), expected[i - 1]); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -2,6 +2,7 @@ | |||||
| using NumSharp; | using NumSharp; | ||||
| using System; | using System; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Util; | |||||
| using Buffer = Tensorflow.Buffer; | using Buffer = Tensorflow.Buffer; | ||||
| namespace TensorFlowNET.UnitTest | namespace TensorFlowNET.UnitTest | ||||
| @@ -45,15 +46,18 @@ namespace TensorFlowNET.UnitTest | |||||
| private bool GetGraphDef(Graph graph, out GraphDef graph_def) | private bool GetGraphDef(Graph graph, out GraphDef graph_def) | ||||
| { | { | ||||
| graph_def = null; | graph_def = null; | ||||
| var s = new Status(); | |||||
| var buffer = new Buffer(); | |||||
| c_api.TF_GraphToGraphDef(graph, buffer, s); | |||||
| bool ret = TF_GetCode(s) == TF_OK; | |||||
| EXPECT_EQ(TF_OK, TF_GetCode(s)); | |||||
| if (ret) graph_def = GraphDef.Parser.ParseFrom(buffer.Data); | |||||
| buffer.Dispose(); | |||||
| s.Dispose(); | |||||
| return ret; | |||||
| using (var s = new Status()) | |||||
| { | |||||
| using (var buffer = new Buffer()) | |||||
| { | |||||
| c_api.TF_GraphToGraphDef(graph, buffer, s); | |||||
| bool ret = TF_GetCode(s) == TF_OK; | |||||
| EXPECT_EQ(TF_OK, TF_GetCode(s)); | |||||
| if (ret) | |||||
| graph_def = GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| } | } | ||||
| private void RunGraphsAndCompareOutputs(TF_Output[] grad_outputs, TF_Output[] expected_grad_outputs) | private void RunGraphsAndCompareOutputs(TF_Output[] grad_outputs, TF_Output[] expected_grad_outputs) | ||||
| @@ -322,7 +322,6 @@ namespace TensorFlowNET.UnitTest | |||||
| EXPECT_EQ(feed2, control_inputs[1]); | EXPECT_EQ(feed2, control_inputs[1]); | ||||
| // Export to a graph def so we can import a graph with control dependencies | // Export to a graph def so we can import a graph with control dependencies | ||||
| graph_def.Dispose(); | |||||
| graph_def = new Buffer(); | graph_def = new Buffer(); | ||||
| c_api.TF_GraphToGraphDef(graph, graph_def, s); | c_api.TF_GraphToGraphDef(graph, graph_def, s); | ||||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | EXPECT_EQ(TF_Code.TF_OK, s.Code); | ||||
| @@ -346,14 +345,10 @@ namespace TensorFlowNET.UnitTest | |||||
| EXPECT_EQ(feed4, control_inputs[1]); | EXPECT_EQ(feed4, control_inputs[1]); | ||||
| c_api.TF_DeleteImportGraphDefOptions(opts); | c_api.TF_DeleteImportGraphDefOptions(opts); | ||||
| c_api.TF_DeleteBuffer(graph_def); | |||||
| // Can add nodes to the imported graph without trouble. | // Can add nodes to the imported graph without trouble. | ||||
| c_test_util.Add(feed, scalar, graph, s); | c_test_util.Add(feed, scalar, graph, s); | ||||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | ASSERT_EQ(TF_Code.TF_OK, s.Code); | ||||
| graph.Dispose(); | |||||
| s.Dispose(); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -1,4 +1,5 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using System; | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using Tensorflow; | using Tensorflow; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -42,5 +43,39 @@ namespace TensorFlowNET.UnitTest | |||||
| Assert.AreEqual("", g._name_stack); | Assert.AreEqual("", g._name_stack); | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void NestedNameScope_Using() | |||||
| { | |||||
| Graph g = tf.Graph().as_default(); | |||||
| using (var name = new ops.NameScope("scope1")) | |||||
| { | |||||
| Assert.AreEqual("scope1", g._name_stack); | |||||
| Assert.AreEqual("scope1/", name); | |||||
| var const1 = tf.constant(1.0); | |||||
| Assert.AreEqual("scope1/Const:0", const1.name); | |||||
| using (var name2 = new ops.NameScope("scope2")) | |||||
| { | |||||
| Assert.AreEqual("scope1/scope2", g._name_stack); | |||||
| Assert.AreEqual("scope1/scope2/", name); | |||||
| var const2 = tf.constant(2.0); | |||||
| Assert.AreEqual("scope1/scope2/Const:0", const2.name); | |||||
| } | |||||
| Assert.AreEqual("scope1", g._name_stack); | |||||
| var const3 = tf.constant(2.0); | |||||
| Assert.AreEqual("scope1/Const_1:0", const3.name); | |||||
| } | |||||
| ; | |||||
| g.Dispose(); | |||||
| Assert.AreEqual("", g._name_stack); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,6 +4,7 @@ using System.Collections.Generic; | |||||
| using System.Linq; | using System.Linq; | ||||
| using NumSharp; | using NumSharp; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Util; | |||||
| using Buffer = Tensorflow.Buffer; | using Buffer = Tensorflow.Buffer; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -21,7 +22,7 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| var handle = c_api.TF_GetAllOpList(); | var handle = c_api.TF_GetAllOpList(); | ||||
| var buffer = new Buffer(handle); | var buffer = new Buffer(handle); | ||||
| var op_list = OpList.Parser.ParseFrom(buffer); | |||||
| var op_list = OpList.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
| var _registered_ops = new Dictionary<string, OpDef>(); | var _registered_ops = new Dictionary<string, OpDef>(); | ||||
| foreach (var op_def in op_list.Op) | foreach (var op_def in op_list.Op) | ||||
| @@ -165,7 +165,7 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| using (var sess = tf.Session()) | using (var sess = tf.Session()) | ||||
| { | { | ||||
| var ndarray=tensor.eval(); | |||||
| var ndarray=tensor.eval(sess); | |||||
| if (typeof(T) == typeof(double)) | if (typeof(T) == typeof(double)) | ||||
| { | { | ||||
| double x = ndarray; | double x = ndarray; | ||||
| @@ -72,8 +72,6 @@ namespace TensorFlowNET.UnitTest | |||||
| // Clean up | // Clean up | ||||
| csession.CloseAndDelete(s); | csession.CloseAndDelete(s); | ||||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | ASSERT_EQ(TF_Code.TF_OK, s.Code); | ||||
| graph.Dispose(); | |||||
| s.Dispose(); | |||||
| } | } | ||||
| [TestMethod] | [TestMethod] | ||||
| @@ -84,7 +82,7 @@ namespace TensorFlowNET.UnitTest | |||||
| var c = math_ops.matmul(a, b, name: "matmul"); | var c = math_ops.matmul(a, b, name: "matmul"); | ||||
| using (var sess = tf.Session()) | using (var sess = tf.Session()) | ||||
| { | { | ||||
| var result = c.eval(); | |||||
| var result = c.eval(sess); | |||||
| Assert.AreEqual(6, result.Data<double>()[0]); | Assert.AreEqual(6, result.Data<double>()[0]); | ||||
| } | } | ||||
| } | } | ||||
| @@ -119,7 +119,7 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| sess.run(init_op); | sess.run(init_op); | ||||
| // o some work with the model. | // o some work with the model. | ||||
| inc_v1.op.run(); | |||||
| inc_v1.op.run(session: sess); | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,4 +1,6 @@ | |||||
| using Tensorflow; | |||||
| using System.Diagnostics.CodeAnalysis; | |||||
| using Tensorflow; | |||||
| using Tensorflow.Util; | |||||
| using Buffer = Tensorflow.Buffer; | using Buffer = Tensorflow.Buffer; | ||||
| namespace TensorFlowNET.UnitTest | namespace TensorFlowNET.UnitTest | ||||
| @@ -26,12 +28,15 @@ namespace TensorFlowNET.UnitTest | |||||
| return op; | return op; | ||||
| } | } | ||||
| [SuppressMessage("ReSharper", "RedundantAssignment")] | |||||
| public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) | public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) | ||||
| { | { | ||||
| var buffer = new Buffer(); | |||||
| c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); | |||||
| attr_value = AttrValue.Parser.ParseFrom(buffer); | |||||
| buffer.Dispose(); | |||||
| using (var buffer = new Buffer()) | |||||
| { | |||||
| c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); | |||||
| attr_value = AttrValue.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
| } | |||||
| return s.Code == TF_Code.TF_OK; | return s.Code == TF_Code.TF_OK; | ||||
| } | } | ||||
| @@ -42,7 +47,7 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| c_api.TF_GraphToGraphDef(graph, buffer, s); | c_api.TF_GraphToGraphDef(graph, buffer, s); | ||||
| s.Check(); | s.Check(); | ||||
| return GraphDef.Parser.ParseFrom(buffer); | |||||
| return GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
| } | } | ||||
| } | } | ||||
| @@ -24,16 +24,17 @@ namespace TensorFlowNET.UnitTest.ops_test | |||||
| [TestMethod] | [TestMethod] | ||||
| public void TestShape() | public void TestShape() | ||||
| { | { | ||||
| var g = tf.Graph().as_default(); | |||||
| var x = constant_op.constant(new[,] { { 1, 2, 3 }, { 4, 5, 6 } }); | |||||
| var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] { x }, new Operation[0]); | |||||
| var op = g._create_op_from_tf_operation(c_op); | |||||
| Assert.AreEqual("myop", op.name); | |||||
| Assert.AreEqual("Identity", op.type); | |||||
| Assert.AreEqual(1, len(op.outputs)); | |||||
| assertItemsEqual(new[] { 2, 3 }, op.outputs[0].shape); | |||||
| using (var g = tf.Graph().as_default()) | |||||
| { | |||||
| var x = constant_op.constant(new[,] {{1, 2, 3}, {4, 5, 6}}); | |||||
| var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]); | |||||
| var op = g._create_op_from_tf_operation(c_op); | |||||
| Assert.AreEqual("myop", op.name); | |||||
| Assert.AreEqual("Identity", op.type); | |||||
| Assert.AreEqual(1, len(op.outputs)); | |||||
| assertItemsEqual(new[] {2, 3}, op.outputs[0].shape); | |||||
| } | |||||
| } | } | ||||
| [TestMethod] | [TestMethod] | ||||