From 89987e24632063a52d93467219c143a9b6b4b47d Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 11 Jul 2020 07:56:31 -0500 Subject: [PATCH] Fix assign_add_variable_op for Graph mode. --- src/TensorFlowNET.Core/Buffers/Buffer.cs | 2 +- .../Operations/gen_resource_variable_ops.cs | 4 +++- .../Operations/resource_variable_ops.cs | 21 +++++++++++++++++-- .../Tensors/Tensor.Implicit.cs | 2 -- src/TensorFlowNET.Core/tensorflow.memory.cs | 2 ++ 5 files changed, 25 insertions(+), 6 deletions(-) diff --git a/src/TensorFlowNET.Core/Buffers/Buffer.cs b/src/TensorFlowNET.Core/Buffers/Buffer.cs index d810dbf2..a537ce9f 100644 --- a/src/TensorFlowNET.Core/Buffers/Buffer.cs +++ b/src/TensorFlowNET.Core/Buffers/Buffer.cs @@ -84,7 +84,7 @@ namespace Tensorflow unsafe { fixed (byte* src = data) - return TF_NewBufferFromString(new IntPtr(src), (ulong) data.LongLength); + return TF_NewBufferFromString(new IntPtr(src), (ulong)data.LongLength); } } diff --git a/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs index 102e10ee..3063ab35 100644 --- a/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs @@ -57,7 +57,9 @@ namespace Tensorflow return null; } - return null; + var _op = tf._op_def_lib._apply_op_helper("AssignAddVariableOp", name, new { resource, value }); + + return _op; } public static Operation assign_variable_op(Tensor resource, Tensor value, string name = null) diff --git a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs index 7c769ae1..6d5c9d95 100644 --- a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs @@ -18,6 +18,9 @@ using System; using System.Linq; using Tensorflow.Framework; using static Tensorflow.CppShapeInferenceResult.Types; +using static Tensorflow.Binding; +using System.Collections.Generic; +using System.Runtime.InteropServices; namespace Tensorflow { @@ -106,7 +109,7 @@ namespace Tensorflow public static Tensor variable_handle_from_shape_and_dtype(TensorShape shape, TF_DataType dtype, string shared_name, string name, bool graph_mode, Tensor initial_value = null) { - var container = "";// ops.get_default_graph().container; + var container = ops.get_default_graph().Container; var handle = gen_resource_variable_ops.var_handle_op(shape: shape, dtype: dtype, shared_name: shared_name, @@ -153,10 +156,24 @@ namespace Tensorflow /// /// /// - private static void _set_handle_shapes_and_types(Tensor handle, HandleData handle_data, bool graph_mode) + private static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode) { if (!graph_mode) return; + + var size = handle_data.ShapeAndType.Count; + + var shapes = new IntPtr[size]; + var types = new DataType[size]; + var ranks = new int[size]; + + for (int i = 0; i < size; i++) + { + var shapeAndType = handle_data.ShapeAndType[i]; + types[i] = shapeAndType.Dtype; + ranks[i] = shapeAndType.Shape.UnknownRank ? -1 : shapeAndType.Shape.Dim.Count; + var dims = shapeAndType.Shape.Dim.Select(x => x.Size).ToArray(); + } } /// diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs index cabaae24..bcf895db 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs @@ -10,8 +10,6 @@ namespace Tensorflow { public static implicit operator IntPtr(Tensor tensor) { - if (tensor._handle == IntPtr.Zero) - Console.WriteLine("tensor is not allocated."); return tensor._handle; } diff --git a/src/TensorFlowNET.Core/tensorflow.memory.cs b/src/TensorFlowNET.Core/tensorflow.memory.cs index 1c1e8ddd..f9f0e8f3 100644 --- a/src/TensorFlowNET.Core/tensorflow.memory.cs +++ b/src/TensorFlowNET.Core/tensorflow.memory.cs @@ -47,6 +47,8 @@ namespace Tensorflow public unsafe void memcpy(IntPtr dst, T[] src, long size) where T : unmanaged { + if (src.Length == 0) return; + fixed (void* p = &src[0]) System.Buffer.MemoryCopy(p, dst.ToPointer(), size, size); }