|
|
@@ -18,6 +18,9 @@ using System; |
|
|
using System.Linq; |
|
|
using System.Linq; |
|
|
using Tensorflow.Framework; |
|
|
using Tensorflow.Framework; |
|
|
using static Tensorflow.CppShapeInferenceResult.Types; |
|
|
using static Tensorflow.CppShapeInferenceResult.Types; |
|
|
|
|
|
using static Tensorflow.Binding; |
|
|
|
|
|
using System.Collections.Generic; |
|
|
|
|
|
using System.Runtime.InteropServices; |
|
|
|
|
|
|
|
|
namespace Tensorflow |
|
|
namespace Tensorflow |
|
|
{ |
|
|
{ |
|
|
@@ -106,7 +109,7 @@ namespace Tensorflow |
|
|
public static Tensor variable_handle_from_shape_and_dtype(TensorShape shape, TF_DataType dtype, |
|
|
public static Tensor variable_handle_from_shape_and_dtype(TensorShape shape, TF_DataType dtype, |
|
|
string shared_name, string name, bool graph_mode, Tensor initial_value = null) |
|
|
string shared_name, string name, bool graph_mode, Tensor initial_value = null) |
|
|
{ |
|
|
{ |
|
|
var container = "";// ops.get_default_graph().container; |
|
|
|
|
|
|
|
|
var container = ops.get_default_graph().Container; |
|
|
var handle = gen_resource_variable_ops.var_handle_op(shape: shape, |
|
|
var handle = gen_resource_variable_ops.var_handle_op(shape: shape, |
|
|
dtype: dtype, |
|
|
dtype: dtype, |
|
|
shared_name: shared_name, |
|
|
shared_name: shared_name, |
|
|
@@ -153,10 +156,24 @@ namespace Tensorflow |
|
|
/// <param name="handle"></param> |
|
|
/// <param name="handle"></param> |
|
|
/// <param name="handle_data"></param> |
|
|
/// <param name="handle_data"></param> |
|
|
/// <param name="graph_mode"></param> |
|
|
/// <param name="graph_mode"></param> |
|
|
private static void _set_handle_shapes_and_types(Tensor handle, HandleData handle_data, bool graph_mode) |
|
|
|
|
|
|
|
|
private static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode) |
|
|
{ |
|
|
{ |
|
|
if (!graph_mode) |
|
|
if (!graph_mode) |
|
|
return; |
|
|
return; |
|
|
|
|
|
|
|
|
|
|
|
var size = handle_data.ShapeAndType.Count; |
|
|
|
|
|
|
|
|
|
|
|
var shapes = new IntPtr[size]; |
|
|
|
|
|
var types = new DataType[size]; |
|
|
|
|
|
var ranks = new int[size]; |
|
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < size; i++) |
|
|
|
|
|
{ |
|
|
|
|
|
var shapeAndType = handle_data.ShapeAndType[i]; |
|
|
|
|
|
types[i] = shapeAndType.Dtype; |
|
|
|
|
|
ranks[i] = shapeAndType.Shape.UnknownRank ? -1 : shapeAndType.Shape.Dim.Count; |
|
|
|
|
|
var dims = shapeAndType.Shape.Dim.Select(x => x.Size).ToArray(); |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
/// <summary> |
|
|
/// <summary> |
|
|
|