Browse Source

Fix assign_add_variable_op for Graph mode.

tags/v0.20
Oceania2018 5 years ago
parent
commit
89987e2463
5 changed files with 25 additions and 6 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Buffers/Buffer.cs
  2. +3
    -1
      src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs
  3. +19
    -2
      src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
  4. +0
    -2
      src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs
  5. +2
    -0
      src/TensorFlowNET.Core/tensorflow.memory.cs

+ 1
- 1
src/TensorFlowNET.Core/Buffers/Buffer.cs View File

@@ -84,7 +84,7 @@ namespace Tensorflow
unsafe unsafe
{ {
fixed (byte* src = data) fixed (byte* src = data)
return TF_NewBufferFromString(new IntPtr(src), (ulong) data.LongLength);
return TF_NewBufferFromString(new IntPtr(src), (ulong)data.LongLength);
} }
} }




+ 3
- 1
src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs View File

@@ -57,7 +57,9 @@ namespace Tensorflow
return null; return null;
} }


return null;
var _op = tf._op_def_lib._apply_op_helper("AssignAddVariableOp", name, new { resource, value });

return _op;
} }


public static Operation assign_variable_op(Tensor resource, Tensor value, string name = null) public static Operation assign_variable_op(Tensor resource, Tensor value, string name = null)


+ 19
- 2
src/TensorFlowNET.Core/Operations/resource_variable_ops.cs View File

@@ -18,6 +18,9 @@ using System;
using System.Linq; using System.Linq;
using Tensorflow.Framework; using Tensorflow.Framework;
using static Tensorflow.CppShapeInferenceResult.Types; using static Tensorflow.CppShapeInferenceResult.Types;
using static Tensorflow.Binding;
using System.Collections.Generic;
using System.Runtime.InteropServices;


namespace Tensorflow namespace Tensorflow
{ {
@@ -106,7 +109,7 @@ namespace Tensorflow
public static Tensor variable_handle_from_shape_and_dtype(TensorShape shape, TF_DataType dtype, public static Tensor variable_handle_from_shape_and_dtype(TensorShape shape, TF_DataType dtype,
string shared_name, string name, bool graph_mode, Tensor initial_value = null) string shared_name, string name, bool graph_mode, Tensor initial_value = null)
{ {
var container = "";// ops.get_default_graph().container;
var container = ops.get_default_graph().Container;
var handle = gen_resource_variable_ops.var_handle_op(shape: shape, var handle = gen_resource_variable_ops.var_handle_op(shape: shape,
dtype: dtype, dtype: dtype,
shared_name: shared_name, shared_name: shared_name,
@@ -153,10 +156,24 @@ namespace Tensorflow
/// <param name="handle"></param> /// <param name="handle"></param>
/// <param name="handle_data"></param> /// <param name="handle_data"></param>
/// <param name="graph_mode"></param> /// <param name="graph_mode"></param>
private static void _set_handle_shapes_and_types(Tensor handle, HandleData handle_data, bool graph_mode)
private static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode)
{ {
if (!graph_mode) if (!graph_mode)
return; return;

var size = handle_data.ShapeAndType.Count;

var shapes = new IntPtr[size];
var types = new DataType[size];
var ranks = new int[size];

for (int i = 0; i < size; i++)
{
var shapeAndType = handle_data.ShapeAndType[i];
types[i] = shapeAndType.Dtype;
ranks[i] = shapeAndType.Shape.UnknownRank ? -1 : shapeAndType.Shape.Dim.Count;
var dims = shapeAndType.Shape.Dim.Select(x => x.Size).ToArray();
}
} }


/// <summary> /// <summary>


+ 0
- 2
src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs View File

@@ -10,8 +10,6 @@ namespace Tensorflow
{ {
public static implicit operator IntPtr(Tensor tensor) public static implicit operator IntPtr(Tensor tensor)
{ {
if (tensor._handle == IntPtr.Zero)
Console.WriteLine("tensor is not allocated.");
return tensor._handle; return tensor._handle;
} }




+ 2
- 0
src/TensorFlowNET.Core/tensorflow.memory.cs View File

@@ -47,6 +47,8 @@ namespace Tensorflow
public unsafe void memcpy<T>(IntPtr dst, T[] src, long size) public unsafe void memcpy<T>(IntPtr dst, T[] src, long size)
where T : unmanaged where T : unmanaged
{ {
if (src.Length == 0) return;

fixed (void* p = &src[0]) fixed (void* p = &src[0])
System.Buffer.MemoryCopy(p, dst.ToPointer(), size, size); System.Buffer.MemoryCopy(p, dst.ToPointer(), size, size);
} }


Loading…
Cancel
Save