| @@ -51,17 +51,13 @@ namespace Tensorflow | |||||
| return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle); | return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle); | ||||
| } | } | ||||
| public unsafe static byte[] ByteStringPiece(IntPtr handle) | |||||
| public unsafe static byte[] ByteStringPiece(Buffer? handle) | |||||
| { | { | ||||
| byte* str_data = (byte*)handle.ToPointer(); | |||||
| List<byte> bytes = new List<byte>(); | |||||
| byte current = 255; | |||||
| while (current != ((byte)'\0')) | |||||
| { | |||||
| current = *(str_data++); | |||||
| bytes.Add(current); | |||||
| if(handle is null){ | |||||
| return new byte[0]; | |||||
| } | } | ||||
| return bytes.Take(bytes.Count - 1).ToArray(); | |||||
| var data = handle.ToArray(); | |||||
| return data; | |||||
| } | } | ||||
| [UnmanagedFunctionPointer(CallingConvention.Winapi)] | [UnmanagedFunctionPointer(CallingConvention.Winapi)] | ||||
| @@ -10,7 +10,7 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status); | public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status); | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); | |||||
| public static extern SafeBufferHandle TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status); | public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status); | ||||
| } | } | ||||
| @@ -0,0 +1,25 @@ | |||||
| using Tensorflow; | |||||
| internal static class GraphOnlyOps | |||||
| { | |||||
| /// <summary> | |||||
| /// Graph-only version of tf.compat.v1.placeholder(), for internal use only. | |||||
| /// </summary> | |||||
| /// <param name="dtyype"></param> | |||||
| /// <param name="shape"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| internal static Tensor graph_placeholder(TF_DataType dtype, Shape shape, string? name = null) | |||||
| { | |||||
| var dtype_value = new AttrValue() { Type = dtype.as_datatype_enum() }; | |||||
| var shape_value = new AttrValue() { Shape = shape.as_proto() }; | |||||
| var g = ops.get_default_graph(); | |||||
| Dictionary<string, AttrValue> attrs = new(); | |||||
| attrs["dtype"] = dtype_value; | |||||
| attrs["shape"] = shape_value; | |||||
| var op = g.create_op("Placeholder", new Tensor[0], new TF_DataType[] { dtype }, | |||||
| new TF_DataType[0], attrs: attrs, name: name); | |||||
| var result = op.outputs[0]; | |||||
| return result; | |||||
| } | |||||
| } | |||||
| @@ -544,12 +544,12 @@ public class FuncGraph : Graph, IDisposable | |||||
| Tensor placeholder; | Tensor placeholder; | ||||
| try | try | ||||
| { | { | ||||
| placeholder = tf.placeholder(tensor.dtype, tensor.shape, name); | |||||
| placeholder = GraphOnlyOps.graph_placeholder(tensor.dtype, tensor.shape, name); | |||||
| } | } | ||||
| catch (ValueError) | |||||
| catch (ValueError ex) | |||||
| { | { | ||||
| // TODO(Rinne): Add warning here. | |||||
| placeholder = tf.placeholder(tensor.dtype, tensor.shape); | |||||
| tf.Logger.Warning(ex.ToString()); | |||||
| placeholder = GraphOnlyOps.graph_placeholder(tensor.dtype, tensor.shape); | |||||
| } | } | ||||
| handle_data_util.copy_handle_data(tensor, placeholder); | handle_data_util.copy_handle_data(tensor, placeholder); | ||||
| if (name is not null) | if (name is not null) | ||||
| @@ -575,12 +575,12 @@ public class FuncGraph : Graph, IDisposable | |||||
| Tensor placeholder; | Tensor placeholder; | ||||
| try | try | ||||
| { | { | ||||
| placeholder = tf.placeholder(spec.dtype, spec.shape, requested_name); | |||||
| placeholder = GraphOnlyOps.graph_placeholder(spec.dtype, spec.shape, requested_name); | |||||
| } | } | ||||
| catch (ValueError) | catch (ValueError) | ||||
| { | { | ||||
| // TODO(Rinne): Add warning here. | // TODO(Rinne): Add warning here. | ||||
| placeholder = tf.placeholder(spec.dtype, spec.shape); | |||||
| placeholder = GraphOnlyOps.graph_placeholder(spec.dtype, spec.shape); | |||||
| } | } | ||||
| if (name is not null) | if (name is not null) | ||||
| { | { | ||||
| @@ -31,7 +31,7 @@ namespace Tensorflow.Operations | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| return ops.convert_to_tensor(shape); | |||||
| return ops.convert_to_tensor(shape, dtype: dtypes.int32); | |||||
| } | } | ||||
| } | } | ||||
| @@ -38,9 +38,9 @@ namespace Tensorflow.Operations | |||||
| int len_orig_loop_vars = orig_loop_vars.Length; | int len_orig_loop_vars = orig_loop_vars.Length; | ||||
| loop_vars = _tensor_array_to_flow(loop_vars); | loop_vars = _tensor_array_to_flow(loop_vars); | ||||
| loop_vars = Nest.MapStructure(x => _convert_to_tensor_or_indexed_slices(x, TF_DataType.DtInvalid, null), loop_vars).ToTensors(); | |||||
| loop_vars = Nest.MapStructure(x => _convert_to_tensor_or_indexed_slices(x), loop_vars).ToTensors(); | |||||
| var loop_vars_signature = Nest.MapStructure(x => new TensorSpec(x.shape, x.dtype), _tensor_array_to_flow(loop_vars)); | |||||
| var loop_vars_signature = Nest.MapStructure(x => new TensorSpec(x.shape, x.dtype), loop_vars); | |||||
| var flat_shape_invariants = Nest.Flatten(loop_vars_signature).Select(x => x.shape).ToArray(); | var flat_shape_invariants = Nest.Flatten(loop_vars_signature).Select(x => x.shape).ToArray(); | ||||
| @@ -379,10 +379,9 @@ namespace Tensorflow.Operations | |||||
| return cond_graph.unique_name(cond_graph.Name + "___redundant_placeholder"); | return cond_graph.unique_name(cond_graph.Name + "___redundant_placeholder"); | ||||
| } | } | ||||
| private static Tensor _convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype, | |||||
| string name) | |||||
| private static Tensor _convert_to_tensor_or_indexed_slices(Tensor value) | |||||
| { | { | ||||
| return ops.convert_to_tensor(value, dtype, name, false); | |||||
| return ops.convert_to_tensor(value, as_ref: false); | |||||
| } | } | ||||
| private static Tensor _build_maximum_iterations_loop_var(int maximum_iterations = -1) | private static Tensor _build_maximum_iterations_loop_var(int maximum_iterations = -1) | ||||
| @@ -576,7 +576,8 @@ namespace Tensorflow | |||||
| public static HandleData get_resource_handle_data(Tensor graph_op) | public static HandleData get_resource_handle_data(Tensor graph_op) | ||||
| { | { | ||||
| var handle_data = c_api.TFC_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output()); | var handle_data = c_api.TFC_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output()); | ||||
| return HandleData.Parser.ParseFrom(c_api.ByteStringPiece(handle_data)); | |||||
| var handle_str = c_api.ByteStringPiece(handle_data.DangerousGetHandle() == IntPtr.Zero ? null : new Buffer(handle_data)); | |||||
| return HandleData.Parser.ParseFrom(handle_str); | |||||
| } | } | ||||
| public static void dismantle_graph(Graph graph) | public static void dismantle_graph(Graph graph) | ||||