| @@ -23,11 +23,9 @@ namespace Tensorflow | |||
| var x = tf.placeholder(tf.float64, shape: (1024, 1024)); | |||
| var log = tf.log(x); | |||
| using (var sess = tf.Session()) | |||
| { | |||
| var ones = np.ones((1024, 1024), dtype: np.float64); | |||
| var o = sess.run(log, new FeedItem(x, ones)); | |||
| } | |||
| var sess = tf.Session(); | |||
| var ones = np.ones((1024, 1024), dtype: np.float64); | |||
| var o = sess.run(log, new FeedItem(x, ones)); | |||
| // Thread.Sleep(1); | |||
| } | |||
| @@ -25,15 +25,15 @@ namespace Tensorflow | |||
| /// <summary> | |||
| /// Represents a TF_Buffer that can be passed to Tensorflow. | |||
| /// </summary> | |||
| public sealed class Buffer : IDisposable | |||
| public sealed class Buffer | |||
| { | |||
| public SafeBufferHandle Handle { get; } | |||
| SafeBufferHandle _handle; | |||
| /// <remarks> | |||
| /// <inheritdoc cref="SafeHandleLease" path="/devdoc/usage"/> | |||
| /// </remarks> | |||
| private unsafe ref readonly TF_Buffer DangerousBuffer | |||
| => ref Unsafe.AsRef<TF_Buffer>(Handle.DangerousGetHandle().ToPointer()); | |||
| => ref Unsafe.AsRef<TF_Buffer>(_handle.DangerousGetHandle().ToPointer()); | |||
| /// <summary> | |||
| /// The memory block representing this buffer. | |||
| @@ -59,7 +59,7 @@ namespace Tensorflow | |||
| { | |||
| get | |||
| { | |||
| using (Handle.Lease()) | |||
| using (_handle.Lease()) | |||
| { | |||
| return DangerousBuffer.length; | |||
| } | |||
| @@ -67,13 +67,13 @@ namespace Tensorflow | |||
| } | |||
| public Buffer() | |||
| => Handle = TF_NewBuffer(); | |||
| => _handle = TF_NewBuffer(); | |||
| public Buffer(SafeBufferHandle handle) | |||
| => Handle = handle; | |||
| => _handle = handle; | |||
| public Buffer(byte[] data) | |||
| => Handle = _toBuffer(data); | |||
| => _handle = _toBuffer(data); | |||
| private static SafeBufferHandle _toBuffer(byte[] data) | |||
| { | |||
| @@ -92,7 +92,7 @@ namespace Tensorflow | |||
| /// </summary> | |||
| public unsafe byte[] ToArray() | |||
| { | |||
| using (Handle.Lease()) | |||
| using (_handle.Lease()) | |||
| { | |||
| ref readonly TF_Buffer buffer = ref DangerousBuffer; | |||
| @@ -107,7 +107,12 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| public void Dispose() | |||
| => Handle.Dispose(); | |||
| public override string ToString() | |||
| => $"0x{_handle.DangerousGetHandle():x16}"; | |||
| public static implicit operator SafeBufferHandle(Buffer buffer) | |||
| { | |||
| return buffer._handle; | |||
| } | |||
| } | |||
| } | |||
| @@ -11,7 +11,7 @@ public class CheckpointReader | |||
| Status status = new Status(); | |||
| VariableToDataTypeMap = new Dictionary<string, TF_DataType>(); | |||
| VariableToShapeMap = new Dictionary<string, Shape>(); | |||
| _handle = c_api.TF_NewCheckpointReader(filename, status.Handle); | |||
| _handle = c_api.TF_NewCheckpointReader(filename, status); | |||
| status.Check(true); | |||
| ReadAllShapeAndType(); | |||
| } | |||
| @@ -38,7 +38,7 @@ public class CheckpointReader | |||
| int num_dims = GetVariableNumDims(name); | |||
| long[] dims = new long[num_dims]; | |||
| Status status = new Status(); | |||
| c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status.Handle); | |||
| c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status); | |||
| status.Check(true); | |||
| return new Shape(dims); | |||
| } | |||
| @@ -49,7 +49,7 @@ public class CheckpointReader | |||
| public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid) | |||
| { | |||
| Status status = new Status(); | |||
| var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status.Handle); | |||
| var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status); | |||
| status.Check(true); | |||
| return new Tensor(tensor); | |||
| } | |||
| @@ -37,7 +37,7 @@ namespace Tensorflow.Contexts | |||
| public void log_device_placement(bool enable) | |||
| { | |||
| if (_handle != null) | |||
| c_api.TFE_ContextSetLogDevicePlacement(_handle, enable, tf.Status.Handle); | |||
| c_api.TFE_ContextSetLogDevicePlacement(_handle, enable, tf.Status); | |||
| _log_device_placement = enable; | |||
| // _thread_local_data.function_call_options = null; | |||
| } | |||
| @@ -60,15 +60,15 @@ namespace Tensorflow.Contexts | |||
| public PhysicalDevice[] list_physical_devices(string device_type = null) | |||
| { | |||
| using var opts = c_api.TFE_NewContextOptions(); | |||
| using var ctx = c_api.TFE_NewContext(opts, tf.Status.Handle); | |||
| using var devices = c_api.TFE_ContextListDevices(ctx, tf.Status.Handle); | |||
| using var ctx = c_api.TFE_NewContext(opts, tf.Status); | |||
| using var devices = c_api.TFE_ContextListDevices(ctx, tf.Status); | |||
| tf.Status.Check(true); | |||
| int num_devices = c_api.TF_DeviceListCount(devices); | |||
| var results = new List<PhysicalDevice>(); | |||
| for (int i = 0; i < num_devices; ++i) | |||
| { | |||
| var dev_type = c_api.StringPiece(c_api.TF_DeviceListType(devices, i, tf.Status.Handle)); | |||
| var dev_type = c_api.StringPiece(c_api.TF_DeviceListType(devices, i, tf.Status)); | |||
| tf.Status.Check(true); | |||
| if (dev_type.StartsWith("XLA")) | |||
| @@ -76,7 +76,7 @@ namespace Tensorflow.Contexts | |||
| if (device_type == null || dev_type == device_type) | |||
| { | |||
| var dev_name = c_api.TF_DeviceListName(devices, i, tf.Status.Handle); | |||
| var dev_name = c_api.TF_DeviceListName(devices, i, tf.Status); | |||
| tf.Status.Check(true); | |||
| results.Add(new PhysicalDevice | |||
| @@ -28,7 +28,7 @@ namespace Tensorflow.Contexts | |||
| /// <summary> | |||
| /// Environment in which eager operations execute. | |||
| /// </summary> | |||
| public sealed partial class Context : IDisposable | |||
| public sealed partial class Context | |||
| { | |||
| public const int GRAPH_MODE = 0; | |||
| public const int EAGER_MODE = 1; | |||
| @@ -41,15 +41,7 @@ namespace Tensorflow.Contexts | |||
| public FunctionCallOptions FunctionCallOptions { get; } | |||
| SafeContextHandle _handle; | |||
| public SafeContextHandle Handle | |||
| { | |||
| get | |||
| { | |||
| if (_handle == null) | |||
| ensure_initialized(); | |||
| return _handle; | |||
| } | |||
| } | |||
| int? _seed; | |||
| Random _rng; | |||
| @@ -59,6 +51,7 @@ namespace Tensorflow.Contexts | |||
| context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE, false); | |||
| initialized = false; | |||
| FunctionCallOptions = new FunctionCallOptions(); | |||
| ensure_initialized(); | |||
| } | |||
| /// <summary> | |||
| @@ -72,12 +65,12 @@ namespace Tensorflow.Contexts | |||
| Config = MergeConfig(); | |||
| FunctionCallOptions.Config = Config; | |||
| var config_str = Config.ToByteArray(); | |||
| using var opts = new ContextOptions(); | |||
| using var status = new Status(); | |||
| c_api.TFE_ContextOptionsSetConfig(opts.Handle, config_str, (ulong)config_str.Length, status.Handle); | |||
| var opts = new ContextOptions(); | |||
| var status = new Status(); | |||
| c_api.TFE_ContextOptionsSetConfig(opts, config_str, (ulong)config_str.Length, status); | |||
| status.Check(true); | |||
| c_api.TFE_ContextOptionsSetDevicePlacementPolicy(opts.Handle, _device_policy); | |||
| _handle = c_api.TFE_NewContext(opts.Handle, status.Handle); | |||
| c_api.TFE_ContextOptionsSetDevicePlacementPolicy(opts, _device_policy); | |||
| _handle = c_api.TFE_NewContext(opts, status); | |||
| status.Check(true); | |||
| initialized = true; | |||
| } | |||
| @@ -178,10 +171,14 @@ namespace Tensorflow.Contexts | |||
| tf.Context.ensure_initialized(); | |||
| if (_handle != null) | |||
| { | |||
| c_api.TFE_ContextClearCaches(_handle); | |||
| } | |||
| } | |||
| public void Dispose() | |||
| => _handle.Dispose(); | |||
| public static implicit operator SafeContextHandle(Context ctx) | |||
| { | |||
| return ctx._handle; | |||
| } | |||
| } | |||
| } | |||
| @@ -14,21 +14,21 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using Tensorflow.Eager; | |||
| namespace Tensorflow.Contexts | |||
| namespace Tensorflow.Contexts; | |||
| public sealed class ContextOptions | |||
| { | |||
| public sealed class ContextOptions : IDisposable | |||
| { | |||
| public SafeContextOptionsHandle Handle { get; } | |||
| SafeContextOptionsHandle _handle { get; } | |||
| public ContextOptions() | |||
| { | |||
| Handle = c_api.TFE_NewContextOptions(); | |||
| } | |||
| public ContextOptions() | |||
| { | |||
| _handle = c_api.TFE_NewContextOptions(); | |||
| } | |||
| public void Dispose() | |||
| => Handle.Dispose(); | |||
| public static implicit operator SafeContextOptionsHandle(ContextOptions opt) | |||
| { | |||
| return opt._handle; | |||
| } | |||
| } | |||
| @@ -43,7 +43,7 @@ namespace Tensorflow.Eager | |||
| { | |||
| var status = tf.Status; | |||
| var op = GetOp(ctx, op_name, status); | |||
| c_api.TFE_OpSetDevice(op, device_name, status.Handle); | |||
| c_api.TFE_OpSetDevice(op, device_name, status); | |||
| if (status.ok()) | |||
| { | |||
| for (int i = 0; i < inputs.Length; ++i) | |||
| @@ -54,7 +54,7 @@ namespace Tensorflow.Eager | |||
| Tensor nd => nd.EagerTensorHandle, | |||
| _ => throw new NotImplementedException("Eager tensor handle has not been allocated.") | |||
| }; | |||
| c_api.TFE_OpAddInput(op, tensor_handle, status.Handle); | |||
| c_api.TFE_OpAddInput(op, tensor_handle, status); | |||
| status.Check(true); | |||
| } | |||
| } | |||
| @@ -64,7 +64,7 @@ namespace Tensorflow.Eager | |||
| var outputs = new SafeEagerTensorHandle[num_outputs]; | |||
| if (status.ok()) | |||
| { | |||
| c_api.TFE_Execute(op, outputs, out num_outputs, status.Handle); | |||
| c_api.TFE_Execute(op, outputs, out num_outputs, status); | |||
| status.Check(true); | |||
| } | |||
| return outputs.Select(x => new EagerTensor(x)).ToArray(); | |||
| @@ -104,7 +104,7 @@ namespace Tensorflow.Eager | |||
| var eager_tensor = ops.convert_to_tensor(fast_input_array[j]); | |||
| attr_values[j] = eager_tensor.dtype; | |||
| c_api.TFE_OpAddInput(op, eager_tensor.EagerTensorHandle, status.Handle); | |||
| c_api.TFE_OpAddInput(op, eager_tensor.EagerTensorHandle, status); | |||
| if (op_exec_info.run_callbacks) | |||
| { | |||
| @@ -142,7 +142,7 @@ namespace Tensorflow.Eager | |||
| } | |||
| var retVals = new SafeEagerTensorHandle[num_retvals]; | |||
| c_api.TFE_Execute(op, retVals, out num_retvals, status.Handle); | |||
| c_api.TFE_Execute(op, retVals, out num_retvals, status); | |||
| status.Check(true); | |||
| var flat_result = retVals.Select(x => new EagerTensor(x)).ToArray(); | |||
| @@ -160,10 +160,10 @@ namespace Tensorflow.Eager | |||
| SafeEagerOpHandle GetOp(Context ctx, string op_or_function_name, Status status) | |||
| { | |||
| if (thread_local_eager_operation_map.find(op_or_function_name, out var op)) | |||
| c_api.TFE_OpReset(op, op_or_function_name, ctx.DeviceName, status.Handle); | |||
| c_api.TFE_OpReset(op, op_or_function_name, ctx.DeviceName, status); | |||
| else | |||
| { | |||
| op = c_api.TFE_NewOp(ctx.Handle, op_or_function_name, status.Handle); | |||
| op = c_api.TFE_NewOp(ctx, op_or_function_name, status); | |||
| thread_local_eager_operation_map[op_or_function_name] = op; | |||
| } | |||
| @@ -219,7 +219,7 @@ namespace Tensorflow.Eager | |||
| flattened_attrs.Add(dtype); | |||
| } | |||
| c_api.TFE_OpAddInput(op, tensor.EagerTensorHandle, status.Handle); | |||
| c_api.TFE_OpAddInput(op, tensor.EagerTensorHandle, status); | |||
| status.Check(true); | |||
| return true; | |||
| @@ -235,7 +235,7 @@ namespace Tensorflow.Eager | |||
| var value = attrs[i + 1]; | |||
| byte is_list = 0; | |||
| var type = c_api.TFE_OpGetAttrType(op, key, ref is_list, status.Handle); | |||
| var type = c_api.TFE_OpGetAttrType(op, key, ref is_list, status); | |||
| if (!status.ok()) return; | |||
| if (is_list != 0) | |||
| SetOpAttrList(tf.Context, op, key, value as object[], type, null, status); | |||
| @@ -264,7 +264,7 @@ namespace Tensorflow.Eager | |||
| Status status) | |||
| { | |||
| byte is_list = 0; | |||
| var type = c_api.TFE_OpGetAttrType(op, attr_name, ref is_list, status.Handle); | |||
| var type = c_api.TFE_OpGetAttrType(op, attr_name, ref is_list, status); | |||
| if (status.Code != TF_Code.TF_OK) return; | |||
| if (attr_value == null) | |||
| @@ -305,7 +305,7 @@ namespace Tensorflow.Eager | |||
| tf.memcpy(dims[i], values1[i].dims, values1[i].ndim * sizeof(long)); | |||
| } | |||
| c_api.TFE_OpSetAttrShapeList(op, key, dims, num_dims, num_values, status.Handle); | |||
| c_api.TFE_OpSetAttrShapeList(op, key, dims, num_dims, num_values, status); | |||
| Array.ForEach(dims, x => Marshal.FreeHGlobal(x)); | |||
| } | |||
| else if (type == TF_AttrType.TF_ATTR_TYPE && values is TF_DataType[] values2) | |||
| @@ -353,7 +353,7 @@ namespace Tensorflow.Eager | |||
| break; | |||
| case TF_AttrType.TF_ATTR_SHAPE: | |||
| var dims = (value as long[]).ToArray(); | |||
| c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status.Handle); | |||
| c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status); | |||
| status.Check(true); | |||
| break; | |||
| case TF_AttrType.TF_ATTR_FUNC: | |||
| @@ -54,7 +54,7 @@ namespace Tensorflow.Eager | |||
| void NewEagerTensorHandle(SafeTensorHandle h) | |||
| { | |||
| _id = ops.uid(); | |||
| _eagerTensorHandle = c_api.TFE_NewTensorHandle(h, tf.Status.Handle); | |||
| _eagerTensorHandle = c_api.TFE_NewTensorHandle(h, tf.Status); | |||
| #if TRACK_TENSOR_LIFE | |||
| Console.WriteLine($"New EagerTensor {_eagerTensorHandle}"); | |||
| #endif | |||
| @@ -65,7 +65,7 @@ namespace Tensorflow.Eager | |||
| { | |||
| if (_handle != null) | |||
| return; | |||
| _handle = c_api.TFE_TensorHandleResolve(_eagerTensorHandle, tf.Status.Handle); | |||
| _handle = c_api.TFE_TensorHandleResolve(_eagerTensorHandle, tf.Status); | |||
| tf.Status.Check(true); | |||
| } | |||
| @@ -24,10 +24,10 @@ namespace Tensorflow.Eager | |||
| } | |||
| } | |||
| public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(_eagerTensorHandle, tf.Status.Handle)); | |||
| public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(_eagerTensorHandle, tf.Status)); | |||
| public override TF_DataType dtype => c_api.TFE_TensorHandleDataType(_eagerTensorHandle); | |||
| public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.Status.Handle); | |||
| public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.Status); | |||
| public override ulong bytesize | |||
| { | |||
| @@ -49,9 +49,9 @@ namespace Tensorflow.Eager | |||
| protected override Shape GetShapeInternal() | |||
| { | |||
| var dims = new int[c_api.TFE_TensorHandleNumDims(_eagerTensorHandle, tf.Status.Handle)]; | |||
| var dims = new int[c_api.TFE_TensorHandleNumDims(_eagerTensorHandle, tf.Status)]; | |||
| for (int i = 0; i < dims.Length; i++) | |||
| dims[i] = c_api.TFE_TensorHandleDim(_eagerTensorHandle, i, tf.Status.Handle); | |||
| dims[i] = c_api.TFE_TensorHandleDim(_eagerTensorHandle, i, tf.Status); | |||
| return dims; | |||
| } | |||
| @@ -64,15 +64,15 @@ namespace Tensorflow.Eager | |||
| public static int GetRank(IntPtr handle) | |||
| { | |||
| var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); | |||
| return c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, tf.Status.Handle); | |||
| return c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, tf.Status); | |||
| } | |||
| public static int[] GetDims(IntPtr handle) | |||
| { | |||
| var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); | |||
| var dims = new int[c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, tf.Status.Handle)]; | |||
| var dims = new int[c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, tf.Status)]; | |||
| for (int i = 0; i < dims.Length; i++) | |||
| dims[i] = c_api.TFE_TensorHandleDim(tfe_tensor_handle, i, tf.Status.Handle); | |||
| dims[i] = c_api.TFE_TensorHandleDim(tfe_tensor_handle, i, tf.Status); | |||
| return dims; | |||
| } | |||
| @@ -114,7 +114,7 @@ namespace Tensorflow | |||
| /// <param name="function"></param> | |||
| /// <param name="status"></param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFE_ContextAddFunction(SafeContextHandle ctx, IntPtr function, SafeStatusHandle status); | |||
| public static extern void TFE_ContextAddFunction(SafeContextHandle ctx, SafeFuncGraphHandle function, SafeStatusHandle status); | |||
| /// <summary> | |||
| /// Removes a function from the context. Once removed, you can no longer | |||
| @@ -56,15 +56,14 @@ namespace Tensorflow | |||
| TF_ImportGraphDefResults results = null; | |||
| var bytes = graph_def.ToByteString().ToArray(); | |||
| using (var buffer = c_api_util.tf_buffer(bytes)) | |||
| using (var scoped_options = c_api_util.ScopedTFImportGraphDefOptions()) | |||
| using (var status = new Status()) | |||
| { | |||
| _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); | |||
| // need to create a class ImportGraphDefWithResults with IDisposal | |||
| results = new TF_ImportGraphDefResults(c_api.TF_GraphImportGraphDefWithResults(graph, buffer.Handle, scoped_options.Handle, status.Handle)); | |||
| status.Check(true); | |||
| } | |||
| var buffer = c_api_util.tf_buffer(bytes); | |||
| var scoped_options = c_api_util.ScopedTFImportGraphDefOptions(); | |||
| var status = new Status(); | |||
| _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); | |||
| // need to create a class ImportGraphDefWithResults with IDisposal | |||
| results = new TF_ImportGraphDefResults(c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status)); | |||
| status.Check(true); | |||
| _ProcessNewOps(graph); | |||
| @@ -116,13 +115,13 @@ namespace Tensorflow | |||
| Dictionary<string, Tensor> input_map, | |||
| string[] return_elements) | |||
| { | |||
| c_api.TF_ImportGraphDefOptionsSetPrefix(options.Handle, prefix); | |||
| c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options.Handle, (char)1); | |||
| c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix); | |||
| c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, (char)1); | |||
| foreach (var input in input_map) | |||
| { | |||
| var (src_name, src_index) = _ParseTensorName(input.Key); | |||
| c_api.TF_ImportGraphDefOptionsAddInputMapping(options.Handle, src_name, src_index, input.Value._as_tf_output()); | |||
| c_api.TF_ImportGraphDefOptionsAddInputMapping(options, src_name, src_index, input.Value._as_tf_output()); | |||
| } | |||
| if (return_elements == null) | |||
| @@ -133,11 +132,11 @@ namespace Tensorflow | |||
| if (name.Contains(":")) | |||
| { | |||
| var (op_name, index) = _ParseTensorName(name); | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(options.Handle, op_name, index); | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index); | |||
| } | |||
| else | |||
| { | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOperation(options.Handle, name); | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, name); | |||
| } | |||
| } | |||
| @@ -33,7 +33,7 @@ namespace Tensorflow | |||
| if (_registered_ops.Count > 0) | |||
| return _registered_ops; | |||
| using var buffer = new Buffer(c_api.TF_GetAllOpList()); | |||
| var buffer = new Buffer(c_api.TF_GetAllOpList()); | |||
| var op_list = OpList.Parser.ParseFrom(buffer.ToArray()); | |||
| foreach (var op_def in op_list.Op) | |||
| _registered_ops[op_def.Name] = op_def; | |||
| @@ -56,8 +56,8 @@ namespace Tensorflow.Framework | |||
| if (pred_value is null) | |||
| { | |||
| var result = range(pred.op.NumOutputs).Select(x => IntPtr.Zero).ToArray(); | |||
| var evaluated = c_api.TF_TryEvaluateConstant(pred.graph, pred._as_tf_output(), result, tf.Status.Handle); | |||
| if (!evaluated || c_api.TF_GetCode(tf.Status.Handle) != TF_Code.TF_OK) | |||
| var evaluated = c_api.TF_TryEvaluateConstant(pred.graph, pred._as_tf_output(), result, tf.Status); | |||
| if (!evaluated || c_api.TF_GetCode(tf.Status) != TF_Code.TF_OK) | |||
| return null; | |||
| else | |||
| throw new NotImplementedException(""); | |||
| @@ -34,10 +34,10 @@ namespace Tensorflow | |||
| /// <param name="output_func_def"></param> | |||
| /// <param name="status"></param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_FunctionToFunctionDef(IntPtr func, SafeBufferHandle output_func_def, SafeStatusHandle status); | |||
| public static extern void TF_FunctionToFunctionDef(SafeFuncGraphHandle func, SafeBufferHandle output_func_def, SafeStatusHandle status); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr TF_GraphToFunction(IntPtr fn_body, string fn_name, | |||
| public static extern SafeFuncGraphHandle TF_GraphToFunction(SafeGraphHandle fn_body, string fn_name, | |||
| bool append_hash_to_fn_name, | |||
| int num_opers, IntPtr[] opers, | |||
| int ninputs, TF_Output[] inputs, | |||
| @@ -48,12 +48,12 @@ namespace Tensorflow | |||
| SafeStatusHandle status); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr TF_FunctionSetAttrValueProto(IntPtr func, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status); | |||
| public static extern IntPtr TF_FunctionSetAttrValueProto(SafeFuncGraphHandle func, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr TF_FunctionName(IntPtr func); | |||
| public static extern IntPtr TF_FunctionName(SafeFuncGraphHandle func); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_GraphCopyFunction(IntPtr g, IntPtr func, IntPtr grad, SafeStatusHandle status); | |||
| public static extern void TF_GraphCopyFunction(SafeGraphHandle g, SafeFuncGraphHandle func, IntPtr grad, SafeStatusHandle status); | |||
| } | |||
| } | |||
| @@ -37,7 +37,7 @@ namespace Tensorflow | |||
| /// <param name="status">TF_Status*</param> | |||
| /// <param name="dy">TF_Output*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_AddGradientsWithPrefix(IntPtr g, string prefix, TF_Output[] y, int ny, | |||
| public static extern void TF_AddGradientsWithPrefix(SafeGraphHandle g, string prefix, TF_Output[] y, int ny, | |||
| TF_Output[] x, int nx, TF_Output[] dx, SafeStatusHandle status, IntPtr[] dy); | |||
| } | |||
| } | |||
| @@ -22,21 +22,19 @@ namespace Tensorflow | |||
| var inputs_string = string.Join(",", inputs); | |||
| var outputs_string = string.Join(",", outputs); | |||
| var transforms_string = string.Join(" ", transforms); | |||
| using (var status = new Status()) | |||
| { | |||
| var buffer = new Buffer(); | |||
| var len = c_api.TransformGraphWithStringInputs(input_graph_def_string, | |||
| input_graph_def_string.Length, | |||
| inputs_string, | |||
| outputs_string, | |||
| transforms_string, | |||
| buffer.Handle, | |||
| status.Handle); | |||
| var status = new Status(); | |||
| var buffer = new Buffer(); | |||
| var len = c_api.TransformGraphWithStringInputs(input_graph_def_string, | |||
| input_graph_def_string.Length, | |||
| inputs_string, | |||
| outputs_string, | |||
| transforms_string, | |||
| buffer, | |||
| status); | |||
| status.Check(false); | |||
| var bytes = buffer.ToArray(); | |||
| return GraphDef.Parser.ParseFrom(bytes); | |||
| } | |||
| status.Check(false); | |||
| var bytes = buffer.ToArray(); | |||
| return GraphDef.Parser.ParseFrom(bytes); | |||
| } | |||
| } | |||
| } | |||
| @@ -37,11 +37,9 @@ namespace Tensorflow.Graphs | |||
| 1); | |||
| return result[0]; | |||
| } | |||
| using (var s = tf.Session(input.graph)) | |||
| { | |||
| var output = func(input); | |||
| return output; | |||
| } | |||
| var s = tf.Session(input.graph); | |||
| var output = func(input); | |||
| return output; | |||
| }; | |||
| } | |||
| @@ -75,12 +73,10 @@ namespace Tensorflow.Graphs | |||
| 1); | |||
| return result[0]; | |||
| } | |||
| using (var s = tf.Session(a.graph)) | |||
| { | |||
| Debug.Assert(a.graph == b.graph); | |||
| var output = func(a, b); | |||
| return output; | |||
| } | |||
| var s = tf.Session(a.graph); | |||
| Debug.Assert(a.graph == b.graph); | |||
| var output = func(a, b); | |||
| return output; | |||
| }; | |||
| } | |||
| } | |||
| @@ -1,258 +1,252 @@ | |||
| using Google.Protobuf; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Exceptions; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Graphs | |||
| namespace Tensorflow.Graphs; | |||
| /// <summary> | |||
| /// Graph representing a function body. | |||
| /// </summary> | |||
| public class FuncGraph : Graph, IDisposable | |||
| { | |||
| SafeFuncGraphHandle _func_graph_handle; | |||
| public string FuncName => _graph_key; | |||
| public Tensors Inputs { get; set; } = new Tensors(); | |||
| public Tensors Outputs { get; set; } = new Tensors(); | |||
| public Dictionary<string, string> Attrs { get; set; } | |||
| Dictionary<long, (Tensor, Tensor)> _captures | |||
| = new Dictionary<long, (Tensor, Tensor)>(); | |||
| public Tensor[] external_captures | |||
| => _captures.Select(x => x.Value.Item1).ToArray(); | |||
| public (Tensor, Tensor)[] captures | |||
| => _captures.Values.Select(x => x).ToArray(); | |||
| public Tensor[] internal_captures | |||
| => _captures.Select(x => x.Value.Item2).ToArray(); | |||
| public Tensor[] captured_inputs | |||
| => external_captures; | |||
| /// <summary> | |||
| /// Graph representing a function body. | |||
| /// Construct a new FuncGraph. | |||
| /// </summary> | |||
| public class FuncGraph : Graph | |||
| public FuncGraph(string name) : base() | |||
| { | |||
| IntPtr _func_graph_handle; | |||
| public string FuncName => _graph_key; | |||
| public Tensors Inputs { get; set; } = new Tensors(); | |||
| public Tensors Outputs { get; set; } = new Tensors(); | |||
| public Dictionary<string, string> Attrs { get; set; } | |||
| outer_graph = ops.get_default_graph(); | |||
| while (outer_graph.building_function) | |||
| outer_graph = outer_graph.OuterGraph; | |||
| _graph_key = name; | |||
| building_function = true; | |||
| } | |||
| Dictionary<long, (Tensor, Tensor)> _captures | |||
| = new Dictionary<long, (Tensor, Tensor)>(); | |||
| public FuncGraph(SafeGraphHandle handle, string name, Dictionary<string, string> attrs) : base() | |||
| { | |||
| outer_graph = ops.get_default_graph(); | |||
| while (outer_graph.building_function) | |||
| outer_graph = outer_graph.OuterGraph; | |||
| _graph_key = name; | |||
| building_function = true; | |||
| Attrs = attrs; | |||
| // Will to test if FuncGraph has memory leak | |||
| // c_api.TF_DeleteGraph(_handle); | |||
| _handle = handle; | |||
| } | |||
| public Tensor[] external_captures | |||
| => _captures.Select(x => x.Value.Item1).ToArray(); | |||
| public (Tensor, Tensor)[] captures | |||
| => _captures.Values.Select(x => x).ToArray(); | |||
| public void ToGraph(Operation[] opers, | |||
| Tensor[] inputs, Tensor[] outputs, | |||
| string[] output_names) | |||
| { | |||
| var status = new Status(); | |||
| _func_graph_handle = c_api.TF_GraphToFunction(_handle, | |||
| _graph_key, | |||
| false, | |||
| opers.Length, | |||
| opers.Select(x => (IntPtr)x).ToArray(), | |||
| inputs.Length, | |||
| inputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | |||
| outputs.Length, | |||
| outputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | |||
| output_names, | |||
| IntPtr.Zero, | |||
| null, | |||
| status); | |||
| status.Check(true); | |||
| SetAttrs(); | |||
| // c_api.TF_GraphCopyFunction(outer_graph, _func_graph_handle, IntPtr.Zero, status.Handle); | |||
| // status.Check(true); | |||
| c_api.TFE_ContextAddFunction(tf.Context, _func_graph_handle, status); | |||
| status.Check(true); | |||
| _graph_key = c_api.StringPiece(c_api.TF_FunctionName(_func_graph_handle)); | |||
| Inputs = inputs; | |||
| // mark_as_return | |||
| Outputs = outputs;// .Select(x => array_ops.identity(x)).ToArray(); | |||
| } | |||
| public Tensor[] internal_captures | |||
| => _captures.Select(x => x.Value.Item2).ToArray(); | |||
| public override Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, TF_DataType[] input_types = null, string name = null, Dictionary<string, AttrValue> attrs = null, OpDef op_def = null, bool compute_device = true) | |||
| { | |||
| foreach(var (i, inp) in enumerate(inputs)) | |||
| inputs[i] = capture(inp); | |||
| public Tensor[] captured_inputs | |||
| => external_captures; | |||
| return base.create_op(op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device); | |||
| } | |||
| /// <summary> | |||
| /// Construct a new FuncGraph. | |||
| /// </summary> | |||
| public FuncGraph(string name) : base() | |||
| const int _EAGER_CONST_THRESHOLD = 128; | |||
| public Tensor capture(Tensor tensor, string name = null, Shape shape = null) | |||
| { | |||
| if(tensor is EagerTensor) | |||
| { | |||
| outer_graph = ops.get_default_graph(); | |||
| while (outer_graph.building_function) | |||
| outer_graph = outer_graph.OuterGraph; | |||
| _graph_key = name; | |||
| building_function = true; | |||
| if (name == null) | |||
| name = ops.uid().ToString(); | |||
| // Small EagerTensors are captured with Const ops | |||
| if (dtypes.is_value_dtype(tensor.dtype) | |||
| && (tensor.rank == 0 || tensor.size < _EAGER_CONST_THRESHOLD)) | |||
| return capture_eager_tensor(tensor, name); | |||
| // Large EagerTensors and resources are captured with Placeholder ops | |||
| return _capture_helper(tensor, name, shape: shape); | |||
| } | |||
| public FuncGraph(IntPtr handle, string name, Dictionary<string, string> attrs) : base() | |||
| if(tensor.graph != this) | |||
| { | |||
| outer_graph = ops.get_default_graph(); | |||
| while (outer_graph.building_function) | |||
| outer_graph = outer_graph.OuterGraph; | |||
| _graph_key = name; | |||
| building_function = true; | |||
| Attrs = attrs; | |||
| // Will to test if FuncGraph has memory leak | |||
| // c_api.TF_DeleteGraph(_handle); | |||
| _handle = handle; | |||
| if (name == null) | |||
| name = tensor.op.name; | |||
| var inner_graph = tensor.graph; | |||
| while(inner_graph != null && inner_graph is FuncGraph inner_func_graph) | |||
| { | |||
| if (inner_graph == this) | |||
| throw new InaccessibleTensorError($"The tensor '{tensor.name}' cannot be accessed here: it is defined" + | |||
| " in another function or code block. Use return values," + | |||
| " explicit Python locals or TensorFlow collections to access" + | |||
| $" it. Defined in: {tensor.graph.graph_key}; accessed from: {graph_key}."); | |||
| inner_graph = inner_func_graph.outer_graph; | |||
| } | |||
| return _capture_helper(tensor, name); | |||
| } | |||
| public void ToGraph(Operation[] opers, | |||
| Tensor[] inputs, Tensor[] outputs, | |||
| string[] output_names) | |||
| return tensor; | |||
| } | |||
| Tensor capture_eager_tensor(Tensor tensor, string name) | |||
| { | |||
| Tensor graph_const = null; | |||
| if (!_captures.ContainsKey(tensor.Id)) | |||
| { | |||
| var status = new Status(); | |||
| _func_graph_handle = c_api.TF_GraphToFunction(_handle, | |||
| _graph_key, | |||
| false, | |||
| opers.Length, | |||
| opers.Select(x => (IntPtr)x).ToArray(), | |||
| inputs.Length, | |||
| inputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | |||
| outputs.Length, | |||
| outputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | |||
| output_names == null || output_names.Length == 0 ? null : output_names, | |||
| IntPtr.Zero, | |||
| null, | |||
| status.Handle); | |||
| status.Check(true); | |||
| SetAttrs(); | |||
| // c_api.TF_GraphCopyFunction(outer_graph, _func_graph_handle, IntPtr.Zero, status.Handle); | |||
| // status.Check(true); | |||
| c_api.TFE_ContextAddFunction(tf.Context.Handle, _func_graph_handle, status.Handle); | |||
| status.Check(true); | |||
| _graph_key = c_api.StringPiece(c_api.TF_FunctionName(_func_graph_handle)); | |||
| Inputs = inputs; | |||
| // mark_as_return | |||
| Outputs = outputs;// .Select(x => array_ops.identity(x)).ToArray(); | |||
| graph_const = tf_with(ops.control_dependencies(null), ctl | |||
| => constant_op.constant(tensor.numpy(), dtype: tensor.dtype, shape: tensor.shape, name: name)); | |||
| add_capture(tensor, graph_const); | |||
| } | |||
| public override Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, TF_DataType[] input_types = null, string name = null, Dictionary<string, AttrValue> attrs = null, OpDef op_def = null, bool compute_device = true) | |||
| else | |||
| { | |||
| foreach(var (i, inp) in enumerate(inputs)) | |||
| inputs[i] = capture(inp); | |||
| return base.create_op(op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device); | |||
| graph_const = _captures[tensor.Id].Item2; | |||
| } | |||
| const int _EAGER_CONST_THRESHOLD = 128; | |||
| public Tensor capture(Tensor tensor, string name = null, Shape shape = null) | |||
| BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => | |||
| { | |||
| if(tensor is EagerTensor) | |||
| { | |||
| if (name == null) | |||
| name = ops.uid().ToString(); | |||
| // Small EagerTensors are captured with Const ops | |||
| if (dtypes.is_value_dtype(tensor.dtype) | |||
| && (tensor.rank == 0 || tensor.size < _EAGER_CONST_THRESHOLD)) | |||
| return capture_eager_tensor(tensor, name); | |||
| return output_grads; | |||
| }; | |||
| // Large EagerTensors and resources are captured with Placeholder ops | |||
| return _capture_helper(tensor, name, shape: shape); | |||
| } | |||
| tf.Runner.RecordGradient("captured_value", | |||
| new[] { graph_const }, null, | |||
| new[] { tensor }, | |||
| getBackwardFunction: _backward_function_wrapper | |||
| /*getForwardFunction: forward_function*/); | |||
| if(tensor.graph != this) | |||
| { | |||
| if (name == null) | |||
| name = tensor.op.name; | |||
| var inner_graph = tensor.graph; | |||
| while(inner_graph != null && inner_graph is FuncGraph inner_func_graph) | |||
| { | |||
| if (inner_graph == this) | |||
| throw new InaccessibleTensorError($"The tensor '{tensor.name}' cannot be accessed here: it is defined" + | |||
| " in another function or code block. Use return values," + | |||
| " explicit Python locals or TensorFlow collections to access" + | |||
| $" it. Defined in: {tensor.graph.graph_key}; accessed from: {graph_key}."); | |||
| inner_graph = inner_func_graph.outer_graph; | |||
| } | |||
| return _capture_helper(tensor, name); | |||
| } | |||
| return graph_const; | |||
| } | |||
| return tensor; | |||
| Tensor _capture_helper(Tensor tensor, string name, Shape shape = null) | |||
| { | |||
| Tensor placeholder = null; | |||
| if (!_captures.ContainsKey(tensor.Id)) | |||
| { | |||
| placeholder = _create_substitute_placeholder(tensor, | |||
| name: name, | |||
| dtype: tensor.dtype, | |||
| shape: shape); | |||
| add_capture(tensor, placeholder); | |||
| } | |||
| Tensor capture_eager_tensor(Tensor tensor, string name) | |||
| else | |||
| { | |||
| Tensor graph_const = null; | |||
| if (!_captures.ContainsKey(tensor.Id)) | |||
| { | |||
| graph_const = tf_with(ops.control_dependencies(null), ctl | |||
| => constant_op.constant(tensor.numpy(), dtype: tensor.dtype, shape: tensor.shape, name: name)); | |||
| add_capture(tensor, graph_const); | |||
| } | |||
| else | |||
| { | |||
| graph_const = _captures[tensor.Id].Item2; | |||
| } | |||
| BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => | |||
| { | |||
| return output_grads; | |||
| }; | |||
| tf.Runner.RecordGradient("captured_value", | |||
| new[] { graph_const }, null, | |||
| new[] { tensor }, | |||
| getBackwardFunction: _backward_function_wrapper | |||
| /*getForwardFunction: forward_function*/); | |||
| return graph_const; | |||
| placeholder = _captures[tensor.Id].Item2; | |||
| } | |||
| Tensor _capture_helper(Tensor tensor, string name, Shape shape = null) | |||
| BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => | |||
| { | |||
| Tensor placeholder = null; | |||
| if (!_captures.ContainsKey(tensor.Id)) | |||
| { | |||
| placeholder = _create_substitute_placeholder(tensor, | |||
| name: name, | |||
| dtype: tensor.dtype, | |||
| shape: shape); | |||
| add_capture(tensor, placeholder); | |||
| } | |||
| else | |||
| { | |||
| placeholder = _captures[tensor.Id].Item2; | |||
| } | |||
| return output_grads; | |||
| }; | |||
| BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => | |||
| { | |||
| return output_grads; | |||
| }; | |||
| tf.Runner.RecordGradient("captured_value", | |||
| new[] { placeholder }, null, | |||
| new[] { tensor }, | |||
| getBackwardFunction: _backward_function_wrapper | |||
| /*getForwardFunction: forward_function*/); | |||
| tf.Runner.RecordGradient("captured_value", | |||
| new[] { placeholder }, null, | |||
| new[] { tensor }, | |||
| getBackwardFunction: _backward_function_wrapper | |||
| /*getForwardFunction: forward_function*/); | |||
| return placeholder; | |||
| } | |||
| return placeholder; | |||
| } | |||
| void add_capture(Tensor tensor, Tensor placeholder) | |||
| { | |||
| _captures.Add(tensor.Id, (tensor, placeholder)); | |||
| Inputs.Add(placeholder); | |||
| } | |||
| void add_capture(Tensor tensor, Tensor placeholder) | |||
| { | |||
| _captures.Add(tensor.Id, (tensor, placeholder)); | |||
| Inputs.Add(placeholder); | |||
| } | |||
| Tensor _create_substitute_placeholder(Tensor value, | |||
| string name = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| Shape shape = null) | |||
| { | |||
| if (shape is null) | |||
| shape = value.shape; | |||
| if (dtype == TF_DataType.DtInvalid) | |||
| dtype = value.dtype; | |||
| var placeholder = tf_with(ops.control_dependencies(null), ctl | |||
| => array_ops.placeholder(dtype, shape: shape, name: name)); | |||
| // custom_gradient.copy_handle_data(value, placeholder) | |||
| return placeholder; | |||
| } | |||
| Tensor _create_substitute_placeholder(Tensor value, | |||
| string name = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| Shape shape = null) | |||
| { | |||
| if (shape is null) | |||
| shape = value.shape; | |||
| if (dtype == TF_DataType.DtInvalid) | |||
| dtype = value.dtype; | |||
| var placeholder = tf_with(ops.control_dependencies(null), ctl | |||
| => array_ops.placeholder(dtype, shape: shape, name: name)); | |||
| // custom_gradient.copy_handle_data(value, placeholder) | |||
| return placeholder; | |||
| } | |||
| void SetAttrs() | |||
| { | |||
| if (Attrs == null) | |||
| return; | |||
| void SetAttrs() | |||
| foreach (var (_name, attr_value) in enumerate(Attrs)) | |||
| { | |||
| if (Attrs == null) | |||
| return; | |||
| foreach (var (_name, attr_value) in enumerate(Attrs)) | |||
| var serialized = new AttrValue | |||
| { | |||
| var serialized = new AttrValue | |||
| { | |||
| S = ByteString.CopyFromUtf8(attr_value) | |||
| }.ToByteArray(); | |||
| c_api.TF_FunctionSetAttrValueProto(_func_graph_handle, _name, serialized, serialized.Length, tf.Status.Handle); | |||
| tf.Status.Check(true); | |||
| } | |||
| S = ByteString.CopyFromUtf8(attr_value) | |||
| }.ToByteArray(); | |||
| c_api.TF_FunctionSetAttrValueProto(_func_graph_handle, _name, serialized, serialized.Length, tf.Status); | |||
| tf.Status.Check(true); | |||
| } | |||
| } | |||
| public override Graph as_default() | |||
| { | |||
| tf.Context.graph_mode(isFunc: true); | |||
| ops.set_default_graph(this); | |||
| return this; | |||
| } | |||
| public override Graph as_default() | |||
| { | |||
| tf.Context.graph_mode(isFunc: true); | |||
| ops.set_default_graph(this); | |||
| return this; | |||
| } | |||
| public override void Exit() | |||
| { | |||
| tf.Context.restore_mode(); | |||
| ops.pop_graph(); | |||
| } | |||
| public override void Exit() | |||
| { | |||
| tf.Context.restore_mode(); | |||
| ops.pop_graph(); | |||
| } | |||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||
| { | |||
| c_api.TFE_ContextRemoveFunction(tf.Context.Handle, _graph_key, tf.Status.Handle); | |||
| c_api.TF_DeleteFunction(_func_graph_handle); | |||
| base.DisposeUnmanagedResources(handle); | |||
| } | |||
| public void Dispose() | |||
| { | |||
| c_api.TFE_ContextRemoveFunction(tf.Context, _graph_key, tf.Status); | |||
| } | |||
| } | |||
| @@ -24,7 +24,7 @@ namespace Tensorflow | |||
| public Buffer ToGraphDef(Status s) | |||
| { | |||
| var buffer = new Buffer(); | |||
| c_api.TF_GraphToGraphDef(_handle, buffer.Handle, s.Handle); | |||
| c_api.TF_GraphToGraphDef(_handle, buffer, s); | |||
| s.Check(true); | |||
| return buffer; | |||
| @@ -33,14 +33,12 @@ namespace Tensorflow | |||
| private GraphDef _as_graph_def(bool add_shapes = false) | |||
| { | |||
| GraphDef def; | |||
| using (var status = new Status()) | |||
| using (var buffer = ToGraphDef(status)) | |||
| { | |||
| status.Check(true); | |||
| // limit size to 250M, recursion to max 100 | |||
| var inputStream = CodedInputStream.CreateWithLimits(buffer.DangerousMemoryBlock, 250 * 1024 * 1024, 100); | |||
| def = GraphDef.Parser.ParseFrom(inputStream); | |||
| } | |||
| var status = new Status(); | |||
| var buffer = ToGraphDef(status); | |||
| status.Check(true); | |||
| // limit size to 250M, recursion to max 100 | |||
| var inputStream = CodedInputStream.CreateWithLimits(buffer.DangerousMemoryBlock, 250 * 1024 * 1024, 100); | |||
| def = GraphDef.Parser.ParseFrom(inputStream); | |||
| // Strip the experimental library field iff it's empty. | |||
| // if(def.Library.Function.Count == 0) | |||
| @@ -29,7 +29,7 @@ namespace Tensorflow | |||
| int size = Marshal.SizeOf<TF_Output>(); | |||
| var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs); | |||
| c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def.Handle, opts.Handle, return_output_handle, num_return_outputs, s.Handle); | |||
| c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s); | |||
| var tf_output_ptr = (TF_Output*)return_output_handle; | |||
| for (int i = 0; i < num_return_outputs; i++) | |||
| @@ -48,15 +48,14 @@ namespace Tensorflow | |||
| public bool Import(byte[] bytes, string prefix = "") | |||
| { | |||
| using (var opts = new ImportGraphDefOptions()) | |||
| using (var status = new Status()) | |||
| using (var graph_def = new Buffer(bytes)) | |||
| { | |||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts.Handle, prefix); | |||
| c_api.TF_GraphImportGraphDef(_handle, graph_def.Handle, opts.Handle, status.Handle); | |||
| status.Check(true); | |||
| return status.Code == TF_Code.TF_OK; | |||
| } | |||
| var opts = new ImportGraphDefOptions(); | |||
| var status = new Status(); | |||
| var graph_def = new Buffer(bytes); | |||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, prefix); | |||
| c_api.TF_GraphImportGraphDef(_handle, graph_def, opts, status); | |||
| status.Check(true); | |||
| return status.Code == TF_Code.TF_OK; | |||
| } | |||
| public Graph ImportGraphDef(string file_path, string name = null) | |||
| @@ -75,9 +75,9 @@ namespace Tensorflow | |||
| /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | |||
| /// </summary> | |||
| /// <remarks>https://www.tensorflow.org/guide/graphs <br></br>https://www.tensorflow.org/api_docs/python/tf/Graph</remarks> | |||
| public partial class Graph : DisposableObject | |||
| , IEnumerable<Operation> | |||
| public partial class Graph : IEnumerable<Operation> | |||
| { | |||
| protected new SafeGraphHandle _handle; | |||
| private Dictionary<int, ITensorOrOperation> _nodes_by_id; | |||
| public Dictionary<string, ITensorOrOperation> _nodes_by_name; | |||
| private Dictionary<string, int> _names_in_use; | |||
| @@ -130,15 +130,6 @@ namespace Tensorflow | |||
| _graph_key = $"graph-{ops.GraphUniqueId()}/"; | |||
| } | |||
| public Graph(IntPtr handle) | |||
| { | |||
| _handle = handle; | |||
| _nodes_by_id = new Dictionary<int, ITensorOrOperation>(); | |||
| _nodes_by_name = new Dictionary<string, ITensorOrOperation>(); | |||
| _names_in_use = new Dictionary<string, int>(); | |||
| _graph_key = $"grap-{ops.GraphUniqueId()}/"; | |||
| } | |||
| public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true) | |||
| { | |||
| return _as_graph_element_locked(obj, allow_tensor, allow_operation); | |||
| @@ -486,16 +477,6 @@ namespace Tensorflow | |||
| _unfetchable_ops.Add(op); | |||
| } | |||
| protected override void DisposeManagedResources() | |||
| { | |||
| } | |||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||
| { | |||
| c_api.TF_DeleteGraph(handle); | |||
| } | |||
| public Tensor get_tensor_by_tf_output(TF_Output tf_output) | |||
| { | |||
| var op = _get_operation_by_tf_operation(tf_output.oper); | |||
| @@ -517,14 +498,14 @@ namespace Tensorflow | |||
| public Shape GetTensorShape(TF_Output output) | |||
| { | |||
| var status = tf.Status; | |||
| var ndim = c_api.TF_GraphGetTensorNumDims(_handle, output, status.Handle); | |||
| var ndim = c_api.TF_GraphGetTensorNumDims(_handle, output, status); | |||
| status.Check(); | |||
| if (ndim == -1) | |||
| return Shape.Null; | |||
| var dims = new long[ndim]; | |||
| c_api.TF_GraphGetTensorShape(_handle, output, dims, dims.Length, status.Handle); | |||
| c_api.TF_GraphGetTensorShape(_handle, output, dims, dims.Length, status); | |||
| status.Check(); | |||
| return new Shape(dims.Select(x => (int)x).ToArray()); | |||
| @@ -539,7 +520,7 @@ namespace Tensorflow | |||
| string debugString = string.Empty; | |||
| public override string ToString() | |||
| { | |||
| return $"{graph_key}, 0x{_handle.ToString("x16")}"; | |||
| return $"{graph_key}, 0x{_handle.DangerousGetHandle().ToString("x16")}"; | |||
| /*if (string.IsNullOrEmpty(debugString)) | |||
| { | |||
| int len = 0; | |||
| @@ -558,7 +539,7 @@ namespace Tensorflow | |||
| IEnumerator IEnumerable.GetEnumerator() | |||
| => throw new NotImplementedException(); | |||
| public static implicit operator IntPtr(Graph graph) | |||
| public static implicit operator SafeGraphHandle(Graph graph) | |||
| { | |||
| return graph._handle; | |||
| } | |||
| @@ -14,28 +14,27 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| namespace Tensorflow; | |||
| namespace Tensorflow | |||
| public sealed class ImportGraphDefOptions | |||
| { | |||
| public sealed class ImportGraphDefOptions : IDisposable | |||
| { | |||
| public SafeImportGraphDefOptionsHandle Handle { get; } | |||
| SafeImportGraphDefOptionsHandle _handle { get; } | |||
| public int NumReturnOutputs | |||
| => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(Handle); | |||
| public int NumReturnOutputs | |||
| => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle); | |||
| public ImportGraphDefOptions() | |||
| { | |||
| Handle = c_api.TF_NewImportGraphDefOptions(); | |||
| } | |||
| public ImportGraphDefOptions() | |||
| { | |||
| _handle = c_api.TF_NewImportGraphDefOptions(); | |||
| } | |||
| public void AddReturnOutput(string name, int index) | |||
| { | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(Handle, name, index); | |||
| } | |||
| public void AddReturnOutput(string name, int index) | |||
| { | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); | |||
| } | |||
| public void Dispose() | |||
| => Handle.Dispose(); | |||
| public static implicit operator SafeImportGraphDefOptionsHandle(ImportGraphDefOptions opt) | |||
| { | |||
| return opt._handle; | |||
| } | |||
| } | |||
| @@ -0,0 +1,22 @@ | |||
| using Tensorflow.Util; | |||
| namespace Tensorflow; | |||
| public sealed class SafeFuncGraphHandle : SafeTensorflowHandle | |||
| { | |||
| private SafeFuncGraphHandle() | |||
| { | |||
| } | |||
| public SafeFuncGraphHandle(IntPtr handle) | |||
| : base(handle) | |||
| { | |||
| } | |||
| protected override bool ReleaseHandle() | |||
| { | |||
| c_api.TF_DeleteFunction(handle); | |||
| SetHandle(IntPtr.Zero); | |||
| return true; | |||
| } | |||
| } | |||
| @@ -0,0 +1,22 @@ | |||
| using Tensorflow.Util; | |||
| namespace Tensorflow; | |||
| public sealed class SafeGraphHandle : SafeTensorflowHandle | |||
| { | |||
| private SafeGraphHandle() | |||
| { | |||
| } | |||
| public SafeGraphHandle(IntPtr handle) | |||
| : base(handle) | |||
| { | |||
| } | |||
| protected override bool ReleaseHandle() | |||
| { | |||
| c_api.TF_DeleteGraph(handle); | |||
| SetHandle(IntPtr.Zero); | |||
| return true; | |||
| } | |||
| } | |||
| @@ -60,7 +60,7 @@ namespace Tensorflow | |||
| /// <param name="num_dims"></param> | |||
| /// <param name="status"></param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, SafeStatusHandle status); | |||
| public static extern void TF_GraphGetTensorShape(SafeGraphHandle graph, TF_Output output, long[] dims, int num_dims, SafeStatusHandle status); | |||
| /// <summary> | |||
| /// Import the graph serialized in `graph_def` into `graph`. | |||
| @@ -78,7 +78,7 @@ namespace Tensorflow | |||
| /// <param name="num_return_outputs">int</param> | |||
| /// <param name="status">TF_Status*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, IntPtr return_outputs, int num_return_outputs, SafeStatusHandle status); | |||
| public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(SafeGraphHandle graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, IntPtr return_outputs, int num_return_outputs, SafeStatusHandle status); | |||
| /// <summary> | |||
| /// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and | |||
| @@ -92,7 +92,7 @@ namespace Tensorflow | |||
| /// <param name="status">TF_Status*</param> | |||
| /// <returns>TF_ImportGraphDefResults*</returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern SafeImportGraphDefResultsHandle TF_GraphImportGraphDefWithResults(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status); | |||
| public static extern SafeImportGraphDefResultsHandle TF_GraphImportGraphDefWithResults(SafeGraphHandle graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status); | |||
| /// <summary> | |||
| /// Import the graph serialized in `graph_def` into `graph`. | |||
| @@ -102,7 +102,7 @@ namespace Tensorflow | |||
| /// <param name="options">TF_ImportGraphDefOptions*</param> | |||
| /// <param name="status">TF_Status*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_GraphImportGraphDef(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status); | |||
| public static extern void TF_GraphImportGraphDef(SafeGraphHandle graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status); | |||
| /// <summary> | |||
| /// Iterate through the operations of a graph. | |||
| @@ -111,7 +111,7 @@ namespace Tensorflow | |||
| /// <param name="pos"></param> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr TF_GraphNextOperation(IntPtr graph, ref uint pos); | |||
| public static extern IntPtr TF_GraphNextOperation(SafeGraphHandle graph, ref uint pos); | |||
| /// <summary> | |||
| /// Returns the operation in the graph with `oper_name`. Returns nullptr if | |||
| @@ -121,14 +121,14 @@ namespace Tensorflow | |||
| /// <param name="oper_name"></param> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr TF_GraphOperationByName(IntPtr graph, string oper_name); | |||
| public static extern IntPtr TF_GraphOperationByName(SafeGraphHandle graph, string oper_name); | |||
| /// <summary> | |||
| /// Sets the shape of the Tensor referenced by `output` in `graph` to | |||
| /// the shape described by `dims` and `num_dims`. | |||
| /// </summary> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_GraphSetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, SafeStatusHandle status); | |||
| public static extern void TF_GraphSetTensorShape(SafeGraphHandle graph, TF_Output output, long[] dims, int num_dims, SafeStatusHandle status); | |||
| /// <summary> | |||
| /// Write out a serialized representation of `graph` (as a GraphDef protocol | |||
| @@ -138,7 +138,7 @@ namespace Tensorflow | |||
| /// <param name="output_graph_def">TF_Buffer*</param> | |||
| /// <param name="status">TF_Status*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_GraphToGraphDef(IntPtr graph, SafeBufferHandle output_graph_def, SafeStatusHandle status); | |||
| public static extern void TF_GraphToGraphDef(SafeGraphHandle graph, SafeBufferHandle output_graph_def, SafeStatusHandle status); | |||
| /// <summary> | |||
| /// Returns the number of dimensions of the Tensor referenced by `output` | |||
| @@ -151,7 +151,7 @@ namespace Tensorflow | |||
| /// <param name="status"></param> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern int TF_GraphGetTensorNumDims(IntPtr graph, TF_Output output, SafeStatusHandle status); | |||
| public static extern int TF_GraphGetTensorNumDims(SafeGraphHandle graph, TF_Output output, SafeStatusHandle status); | |||
| /// <summary> | |||
| /// Cause the imported graph to have a control dependency on `oper`. `oper` | |||
| @@ -287,12 +287,12 @@ namespace Tensorflow | |||
| /// <param name="status">TF_Status*</param> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr TF_LoadSessionFromSavedModel(SafeSessionOptionsHandle session_options, IntPtr run_options, | |||
| public static extern SafeSessionHandle TF_LoadSessionFromSavedModel(SafeSessionOptionsHandle session_options, IntPtr run_options, | |||
| string export_dir, string[] tags, int tags_len, | |||
| IntPtr graph, IntPtr meta_graph_def, SafeStatusHandle status); | |||
| SafeGraphHandle graph, IntPtr meta_graph_def, SafeStatusHandle status); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr TF_NewGraph(); | |||
| public static extern SafeGraphHandle TF_NewGraph(); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern SafeImportGraphDefOptionsHandle TF_NewImportGraphDefOptions(); | |||
| @@ -334,6 +334,6 @@ namespace Tensorflow | |||
| /// <param name="status"></param> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern bool TF_TryEvaluateConstant(IntPtr graph, TF_Output output, IntPtr[] result, SafeStatusHandle status); | |||
| public static extern bool TF_TryEvaluateConstant(SafeGraphHandle graph, TF_Output output, IntPtr[] result, SafeStatusHandle status); | |||
| } | |||
| } | |||
| @@ -61,7 +61,7 @@ namespace Tensorflow.NumPy | |||
| { | |||
| if (_handle is not null) | |||
| { | |||
| _eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle); | |||
| _eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status); | |||
| } | |||
| } | |||
| } | |||
| @@ -31,7 +31,7 @@ namespace Tensorflow | |||
| public int InputListLength(string name) | |||
| { | |||
| int num = 0; | |||
| num = c_api.TF_OperationInputListLength(_handle, name, tf.Status.Handle); | |||
| num = c_api.TF_OperationInputListLength(_handle, name, tf.Status); | |||
| tf.Status.Check(true); | |||
| return num; | |||
| } | |||
| @@ -28,7 +28,7 @@ namespace Tensorflow | |||
| public int OutputListLength(string name) | |||
| { | |||
| int num = c_api.TF_OperationOutputListLength(_handle, name, tf.Status.Handle); | |||
| int num = c_api.TF_OperationOutputListLength(_handle, name, tf.Status); | |||
| tf.Status.Check(true); | |||
| return num; | |||
| @@ -187,8 +187,8 @@ namespace Tensorflow | |||
| if (tf.executing_eagerly()) | |||
| return (T[])get_attr(name); | |||
| using var buf = new Buffer(); | |||
| c_api.TF_OperationGetAttrValueProto(_handle, name, buf.Handle, tf.Status.Handle); | |||
| var buf = new Buffer(); | |||
| c_api.TF_OperationGetAttrValueProto(_handle, name, buf, tf.Status); | |||
| tf.Status.Check(true); | |||
| var x = AttrValue.Parser.ParseFrom(buf.ToArray()); | |||
| @@ -210,8 +210,8 @@ namespace Tensorflow | |||
| public virtual object get_attr(string name) | |||
| { | |||
| using var buf = new Buffer(); | |||
| c_api.TF_OperationGetAttrValueProto(_handle, name, buf.Handle, tf.Status.Handle); | |||
| var buf = new Buffer(); | |||
| c_api.TF_OperationGetAttrValueProto(_handle, name, buf, tf.Status); | |||
| tf.Status.Check(true); | |||
| var x = AttrValue.Parser.ParseFrom(buf.ToArray()); | |||
| @@ -235,13 +235,13 @@ namespace Tensorflow | |||
| public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s) | |||
| { | |||
| return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s.Handle); | |||
| return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s); | |||
| } | |||
| private NodeDef GetNodeDef() | |||
| { | |||
| using var buffer = new Buffer(); | |||
| c_api.TF_OperationToNodeDef(_handle, buffer.Handle, tf.Status.Handle); | |||
| var buffer = new Buffer(); | |||
| c_api.TF_OperationToNodeDef(_handle, buffer, tf.Status); | |||
| tf.Status.Check(throwException: true); | |||
| return NodeDef.Parser.ParseFrom(buffer.ToArray()); | |||
| } | |||
| @@ -50,7 +50,7 @@ namespace Tensorflow | |||
| public Operation FinishOperation(Status status) | |||
| { | |||
| return c_api.TF_FinishOperation(_handle, status.Handle); | |||
| return c_api.TF_FinishOperation(_handle, status); | |||
| } | |||
| public static implicit operator OperationDescription(IntPtr handle) | |||
| @@ -96,7 +96,7 @@ namespace Tensorflow | |||
| /// <param name="oper_name">const char*</param> | |||
| /// <returns>TF_OperationDescription*</returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr TF_NewOperation(IntPtr graph, string opType, string oper_name); | |||
| public static extern IntPtr TF_NewOperation(SafeGraphHandle graph, string opType, string oper_name); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr TF_OperationDevice(IntPtr oper); | |||
| @@ -14,281 +14,272 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using Google.Protobuf; | |||
| using Tensorflow.NumPy; | |||
| using System; | |||
| using System.Collections; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Numerics; | |||
| using System.Text; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| namespace Tensorflow; | |||
| public class BaseSession : IDisposable | |||
| { | |||
| public class BaseSession : DisposableObject | |||
| protected SafeSessionHandle _handle; | |||
| protected Graph _graph; | |||
| protected Status _status; | |||
| public Graph graph => _graph; | |||
| public BaseSession(SafeSessionHandle handle, Graph g) | |||
| { | |||
| protected Graph _graph; | |||
| protected Status _status; | |||
| public Graph graph => _graph; | |||
| _handle = handle; | |||
| _graph = g ?? ops.get_default_graph(); | |||
| } | |||
| public BaseSession(IntPtr handle, Graph g) | |||
| public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null) | |||
| { | |||
| _graph = g ?? ops.get_default_graph(); | |||
| if (!_graph.building_function) | |||
| { | |||
| _handle = handle; | |||
| _graph = g ?? ops.get_default_graph(); | |||
| if (ops.get_default_graph() != _graph) | |||
| _graph.as_default(); | |||
| } | |||
| var opts = new SessionOptions(target, config); | |||
| _status = status ?? tf.Status; | |||
| _handle = c_api.TF_NewSession(_graph, opts, _status); | |||
| _status.Check(true); | |||
| } | |||
| public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null) | |||
| { | |||
| _graph = g ?? ops.get_default_graph(); | |||
| if (!_graph.building_function) | |||
| { | |||
| if (ops.get_default_graph() != _graph) | |||
| _graph.as_default(); | |||
| } | |||
| using var opts = new SessionOptions(target, config); | |||
| _status = status ?? tf.Status; | |||
| _handle = c_api.TF_NewSession(_graph, opts.Handle, _status.Handle); | |||
| _status.Check(true); | |||
| } | |||
| public virtual void run(Operation op, params FeedItem[] feed_dict) | |||
| { | |||
| _run(op, feed_dict); | |||
| } | |||
| public virtual void run(Operation op, params FeedItem[] feed_dict) | |||
| { | |||
| _run(op, feed_dict); | |||
| } | |||
| public virtual NDArray run(Tensor fetche, params FeedItem[] feed_dict) | |||
| { | |||
| return _run(fetche, feed_dict)[0]; | |||
| } | |||
| public virtual NDArray run(Tensor fetche, params FeedItem[] feed_dict) | |||
| { | |||
| return _run(fetche, feed_dict)[0]; | |||
| } | |||
| public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict) | |||
| { | |||
| var results = _run(fetche, feed_dict); | |||
| return fetche is Tensor ? results[0] : null; | |||
| } | |||
| public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict) | |||
| { | |||
| var results = _run(fetche, feed_dict); | |||
| return fetche is Tensor ? results[0] : null; | |||
| } | |||
| public virtual (NDArray, NDArray, NDArray, NDArray, NDArray) run( | |||
| (ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, | |||
| params FeedItem[] feed_dict) | |||
| { | |||
| var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4, fetches.Item5 }, feed_dict); | |||
| return (results[0], results[1], results[2], results[3], results[4]); | |||
| } | |||
| public virtual (NDArray, NDArray, NDArray, NDArray, NDArray) run( | |||
| (ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, | |||
| params FeedItem[] feed_dict) | |||
| { | |||
| var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4, fetches.Item5 }, feed_dict); | |||
| return (results[0], results[1], results[2], results[3], results[4]); | |||
| } | |||
| public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||
| { | |||
| var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); | |||
| return (results[0], results[1], results[2], results[3]); | |||
| } | |||
| public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||
| { | |||
| var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); | |||
| return (results[0], results[1], results[2], results[3]); | |||
| } | |||
| public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||
| { | |||
| var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict); | |||
| return (results[0], results[1], results[2]); | |||
| } | |||
| public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||
| { | |||
| var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict); | |||
| return (results[0], results[1], results[2]); | |||
| } | |||
| public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||
| { | |||
| var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict); | |||
| return (results[0], results[1]); | |||
| } | |||
| public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||
| { | |||
| var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict); | |||
| return (results[0], results[1]); | |||
| } | |||
| public virtual NDArray[] run(object fetches, params FeedItem[] feed_dict) | |||
| { | |||
| return _run(fetches, feed_dict); | |||
| } | |||
| public virtual NDArray[] run(object fetches, params FeedItem[] feed_dict) | |||
| { | |||
| return _run(fetches, feed_dict); | |||
| } | |||
| public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) | |||
| { | |||
| var feed_items = feed_dict == null ? new FeedItem[0] : feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); | |||
| return _run(fetches, feed_items); | |||
| } | |||
| public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) | |||
| { | |||
| var feed_items = feed_dict == null ? new FeedItem[0] : feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); | |||
| return _run(fetches, feed_items); | |||
| } | |||
| private NDArray[] _run(object fetches, FeedItem[] feed_dict = null) | |||
| { | |||
| var feed_dict_tensor = new Dictionary<object, object>(); | |||
| //var feed_map = new Dictionary<object, object>(); | |||
| private NDArray[] _run(object fetches, FeedItem[] feed_dict = null) | |||
| // Validate and process feed_dict. | |||
| if (feed_dict != null) | |||
| { | |||
| var feed_dict_tensor = new Dictionary<object, object>(); | |||
| //var feed_map = new Dictionary<object, object>(); | |||
| // Validate and process feed_dict. | |||
| if (feed_dict != null) | |||
| foreach (var subfeed in feed_dict) | |||
| { | |||
| foreach (var subfeed in feed_dict) | |||
| { | |||
| var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false); | |||
| //var target_dtype = subfeed_t.dtype.as_numpy_typecode(); // subfeed_dtype was never used | |||
| feed_dict_tensor[subfeed_t] = subfeed.Value; | |||
| //feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value); | |||
| } | |||
| var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false); | |||
| //var target_dtype = subfeed_t.dtype.as_numpy_typecode(); // subfeed_dtype was never used | |||
| feed_dict_tensor[subfeed_t] = subfeed.Value; | |||
| //feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value); | |||
| } | |||
| } | |||
| // Create a fetch handler to take care of the structure of fetches. | |||
| var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); | |||
| // Create a fetch handler to take care of the structure of fetches. | |||
| var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); | |||
| // Run request and get response. | |||
| // We need to keep the returned movers alive for the following _do_run(). | |||
| // These movers are no longer needed when _do_run() completes, and | |||
| // are deleted when `movers` goes out of scope when this _run() ends. | |||
| var _ = _update_with_movers(); | |||
| var final_fetches = fetch_handler.fetches(); | |||
| var final_targets = fetch_handler.targets(); | |||
| // Run request and get response. | |||
| // We need to keep the returned movers alive for the following _do_run(). | |||
| // These movers are no longer needed when _do_run() completes, and | |||
| // are deleted when `movers` goes out of scope when this _run() ends. | |||
| var _ = _update_with_movers(); | |||
| var final_fetches = fetch_handler.fetches(); | |||
| var final_targets = fetch_handler.targets(); | |||
| // We only want to really perform the run if fetches or targets are provided, | |||
| // or if the call is a partial run that specifies feeds. | |||
| var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor); | |||
| // We only want to really perform the run if fetches or targets are provided, | |||
| // or if the call is a partial run that specifies feeds. | |||
| var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor); | |||
| return fetch_handler.build_results(this, results); | |||
| } | |||
| return fetch_handler.build_results(this, results); | |||
| } | |||
| /// <summary> | |||
| /// Runs a step based on the given fetches and feeds. | |||
| /// </summary> | |||
| /// <param name="target_list">A list of operations to be run, but not fetched.</param> | |||
| /// <param name="fetch_list"></param> | |||
| /// <param name="feed_dict"></param> | |||
| /// <returns> | |||
| /// A list of numpy ndarrays, corresponding to the elements of | |||
| /// `fetch_list`. If the ith element of `fetch_list` contains the | |||
| /// name of an operation, the first Tensor output of that operation | |||
| /// will be returned for that element. | |||
| /// </returns> | |||
| private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) | |||
| /// <summary> | |||
| /// Runs a step based on the given fetches and feeds. | |||
| /// </summary> | |||
| /// <param name="target_list">A list of operations to be run, but not fetched.</param> | |||
| /// <param name="fetch_list"></param> | |||
| /// <param name="feed_dict"></param> | |||
| /// <returns> | |||
| /// A list of numpy ndarrays, corresponding to the elements of | |||
| /// `fetch_list`. If the ith element of `fetch_list` contains the | |||
| /// name of an operation, the first Tensor output of that operation | |||
| /// will be returned for that element. | |||
| /// </returns> | |||
| private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) | |||
| { | |||
| var feeds = new KeyValuePair<TF_Output, Tensor>[feed_dict.Count]; | |||
| int i = 0; | |||
| foreach (var x in feed_dict) | |||
| { | |||
| var feeds = new KeyValuePair<TF_Output, Tensor>[feed_dict.Count]; | |||
| int i = 0; | |||
| foreach (var x in feed_dict) | |||
| if (x.Key is Tensor key) | |||
| { | |||
| if (x.Key is Tensor key) | |||
| switch (x.Value) | |||
| { | |||
| switch (x.Value) | |||
| { | |||
| case Tensor v: | |||
| if (v.dtype != key.dtype) | |||
| throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {v.dtype}"); | |||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), v); | |||
| break; | |||
| case SafeTensorHandle v: | |||
| var tensor = new Tensor(v); | |||
| if (tensor.dtype != key.dtype) | |||
| throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {tensor.dtype}"); | |||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), tensor); | |||
| break; | |||
| case bool v: | |||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||
| break; | |||
| case byte v: | |||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||
| break; | |||
| case int v: | |||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||
| break; | |||
| case long v: | |||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||
| break; | |||
| case float v: | |||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||
| break; | |||
| case double v: | |||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||
| break; | |||
| case string v: | |||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||
| break; | |||
| case Array v: | |||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v, v.GetShape())); | |||
| break; | |||
| default: | |||
| throw new NotImplementedException(""); | |||
| } | |||
| case Tensor v: | |||
| if (v.dtype != key.dtype) | |||
| throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {v.dtype}"); | |||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), v); | |||
| break; | |||
| case SafeTensorHandle v: | |||
| var tensor = new Tensor(v); | |||
| if (tensor.dtype != key.dtype) | |||
| throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {tensor.dtype}"); | |||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), tensor); | |||
| break; | |||
| case bool v: | |||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||
| break; | |||
| case byte v: | |||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||
| break; | |||
| case int v: | |||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||
| break; | |||
| case long v: | |||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||
| break; | |||
| case float v: | |||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||
| break; | |||
| case double v: | |||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||
| break; | |||
| case string v: | |||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||
| break; | |||
| case Array v: | |||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v, v.GetShape())); | |||
| break; | |||
| default: | |||
| throw new NotImplementedException(""); | |||
| } | |||
| else | |||
| throw new NotImplementedException(""); | |||
| } | |||
| var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | |||
| //var targets = target_list; | |||
| return _call_tf_sessionrun(feeds, fetches, target_list); | |||
| else | |||
| throw new NotImplementedException(""); | |||
| } | |||
| var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | |||
| //var targets = target_list; | |||
| return _call_tf_sessionrun(feeds, fetches, target_list); | |||
| } | |||
| private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list) | |||
| { | |||
| // Ensure any changes to the graph are reflected in the runtime. | |||
| _extend_graph(); | |||
| var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); | |||
| private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list) | |||
| { | |||
| // Ensure any changes to the graph are reflected in the runtime. | |||
| _extend_graph(); | |||
| c_api.TF_SessionRun(_handle, | |||
| run_options: null, | |||
| inputs: feed_dict.Select(f => f.Key).ToArray(), | |||
| input_values: feed_dict.Select(f => f.Value.Handle.DangerousGetHandle()).ToArray(), | |||
| ninputs: feed_dict.Length, | |||
| outputs: fetch_list, | |||
| output_values: output_values, | |||
| noutputs: fetch_list.Length, | |||
| target_opers: target_list.Select(f => (IntPtr)f).ToArray(), | |||
| ntargets: target_list.Count, | |||
| run_metadata: IntPtr.Zero, | |||
| status: _status.Handle); | |||
| var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); | |||
| _status.Check(true); | |||
| c_api.TF_SessionRun(_handle, | |||
| run_options: null, | |||
| inputs: feed_dict.Select(f => f.Key).ToArray(), | |||
| input_values: feed_dict.Select(f => f.Value.Handle.DangerousGetHandle()).ToArray(), | |||
| ninputs: feed_dict.Length, | |||
| outputs: fetch_list, | |||
| output_values: output_values, | |||
| noutputs: fetch_list.Length, | |||
| target_opers: target_list.Select(f => (IntPtr)f).ToArray(), | |||
| ntargets: target_list.Count, | |||
| run_metadata: IntPtr.Zero, | |||
| status: _status); | |||
| var result = new NDArray[fetch_list.Length]; | |||
| _status.Check(true); | |||
| for (int i = 0; i < fetch_list.Length; i++) | |||
| result[i] = fetchValue(new SafeTensorHandle(output_values[i])); | |||
| var result = new NDArray[fetch_list.Length]; | |||
| return result; | |||
| } | |||
| for (int i = 0; i < fetch_list.Length; i++) | |||
| result[i] = fetchValue(new SafeTensorHandle(output_values[i])); | |||
| public unsafe Tensor eval(Tensor tensor) | |||
| { | |||
| var output_values = new IntPtr[1]; | |||
| var fetch_list = new[] { tensor._as_tf_output() }; | |||
| c_api.TF_SessionRun(_handle, | |||
| run_options: null, | |||
| inputs: new TF_Output[0], | |||
| input_values: new IntPtr[0], | |||
| ninputs: 0, | |||
| outputs: fetch_list, | |||
| output_values: output_values, | |||
| noutputs: 1, | |||
| target_opers: new IntPtr[0], | |||
| ntargets: 0, | |||
| run_metadata: IntPtr.Zero, | |||
| status: _status.Handle); | |||
| _status.Check(true); | |||
| return new Tensor(new SafeTensorHandle(output_values[0])); | |||
| } | |||
| return result; | |||
| } | |||
| private static unsafe NDArray fetchValue(SafeTensorHandle output) | |||
| { | |||
| var tensor = new Tensor(output); | |||
| return tensor.numpy(); | |||
| } | |||
| public unsafe Tensor eval(Tensor tensor) | |||
| { | |||
| var output_values = new IntPtr[1]; | |||
| var fetch_list = new[] { tensor._as_tf_output() }; | |||
| c_api.TF_SessionRun(_handle, | |||
| run_options: null, | |||
| inputs: new TF_Output[0], | |||
| input_values: new IntPtr[0], | |||
| ninputs: 0, | |||
| outputs: fetch_list, | |||
| output_values: output_values, | |||
| noutputs: 1, | |||
| target_opers: new IntPtr[0], | |||
| ntargets: 0, | |||
| run_metadata: IntPtr.Zero, | |||
| status: _status); | |||
| _status.Check(true); | |||
| return new Tensor(new SafeTensorHandle(output_values[0])); | |||
| } | |||
| /// <summary> | |||
| /// If a tensor handle that is fed to a device incompatible placeholder, | |||
| /// we move the tensor to the right device, generate a new tensor handle, | |||
| /// and update feed_dict to use the new handle. | |||
| /// </summary> | |||
| private List<object> _update_with_movers() | |||
| { | |||
| return new List<object> { }; | |||
| } | |||
| private static unsafe NDArray fetchValue(SafeTensorHandle output) | |||
| { | |||
| var tensor = new Tensor(output); | |||
| return tensor.numpy(); | |||
| } | |||
| private void _extend_graph() | |||
| { } | |||
| /// <summary> | |||
| /// If a tensor handle that is fed to a device incompatible placeholder, | |||
| /// we move the tensor to the right device, generate a new tensor handle, | |||
| /// and update feed_dict to use the new handle. | |||
| /// </summary> | |||
| private List<object> _update_with_movers() | |||
| { | |||
| return new List<object> { }; | |||
| } | |||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||
| { | |||
| // c_api.TF_CloseSession(handle, tf.Status.Handle); | |||
| c_api.TF_DeleteSession(handle, _status.Handle); | |||
| } | |||
| private void _extend_graph() | |||
| { } | |||
| public void Dispose() | |||
| { | |||
| } | |||
| } | |||
| @@ -0,0 +1,46 @@ | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Net.NetworkInformation; | |||
| using Tensorflow.Util; | |||
| namespace Tensorflow | |||
| { | |||
| public sealed class SafeSessionHandle : SafeTensorflowHandle | |||
| { | |||
| private SafeSessionHandle() | |||
| { | |||
| } | |||
| public SafeSessionHandle(IntPtr handle) | |||
| : base(handle) | |||
| { | |||
| } | |||
| public override string ToString() | |||
| => $"0x{handle:x16}"; | |||
| protected override bool ReleaseHandle() | |||
| { | |||
| var status = new Status(); | |||
| // c_api.TF_CloseSession(handle, tf.Status.Handle); | |||
| c_api.TF_DeleteSession(handle, status); | |||
| SetHandle(IntPtr.Zero); | |||
| return true; | |||
| } | |||
| } | |||
| } | |||
| @@ -14,75 +14,49 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.IO; | |||
| using System.Runtime.CompilerServices; | |||
| using Tensorflow.Util; | |||
| namespace Tensorflow; | |||
| namespace Tensorflow | |||
| public class Session : BaseSession | |||
| { | |||
| public class Session : BaseSession | |||
| { | |||
| public Session(string target = "", Graph g = null) : base(target, g, null) | |||
| { } | |||
| public Session(IntPtr handle, Graph g = null) : base(handle, g) | |||
| { } | |||
| public Session(Graph g, ConfigProto config = null, Status s = null) : base("", g, config, s) | |||
| { } | |||
| public Session as_default() | |||
| { | |||
| return ops.set_default_session(this); | |||
| } | |||
| public static Session LoadFromSavedModel(string path) | |||
| { | |||
| var graph = new Graph(); | |||
| using var status = new Status(); | |||
| using var opt = c_api.TF_NewSessionOptions(); | |||
| var tags = new string[] { "serve" }; | |||
| var sess = c_api.TF_LoadSessionFromSavedModel(opt, | |||
| IntPtr.Zero, | |||
| path, | |||
| tags, | |||
| tags.Length, | |||
| graph, | |||
| IntPtr.Zero, | |||
| status.Handle); | |||
| status.Check(true); | |||
| // load graph bytes | |||
| // var data = new byte[buffer.length]; | |||
| // Marshal.Copy(buffer.data, data, 0, (int)buffer.length); | |||
| // var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ | |||
| return new Session(sess, g: graph); | |||
| } | |||
| public static implicit operator IntPtr(Session session) => session._handle; | |||
| public static implicit operator Session(IntPtr handle) => new Session(handle); | |||
| public Session(string target = "", Graph g = null) : base(target, g, null) | |||
| { } | |||
| public void __enter__() | |||
| { | |||
| public Session(SafeSessionHandle handle, Graph g = null) : base(handle, g) | |||
| { } | |||
| } | |||
| public Session(Graph g, ConfigProto config = null, Status s = null) : base("", g, config, s) | |||
| { } | |||
| public void __exit__() | |||
| { | |||
| } | |||
| public void __init__() | |||
| { | |||
| } | |||
| public void __del__() | |||
| { | |||
| public Session as_default() | |||
| { | |||
| return ops.set_default_session(this); | |||
| } | |||
| } | |||
| public static Session LoadFromSavedModel(string path) | |||
| { | |||
| var graph = new Graph(); | |||
| var status = new Status(); | |||
| using var opt = c_api.TF_NewSessionOptions(); | |||
| var tags = new string[] { "serve" }; | |||
| var sess = c_api.TF_LoadSessionFromSavedModel(opt, | |||
| IntPtr.Zero, | |||
| path, | |||
| tags, | |||
| tags.Length, | |||
| graph, | |||
| IntPtr.Zero, | |||
| status); | |||
| status.Check(true); | |||
| // load graph bytes | |||
| // var data = new byte[buffer.length]; | |||
| // Marshal.Copy(buffer.data, data, 0, (int)buffer.length); | |||
| // var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ | |||
| return new Session(sess, g: graph); | |||
| } | |||
| public static implicit operator SafeSessionHandle(Session session) => session._handle; | |||
| public static implicit operator Session(SafeSessionHandle handle) => new Session(handle); | |||
| } | |||
| @@ -19,33 +19,33 @@ using System; | |||
| namespace Tensorflow | |||
| { | |||
| internal sealed class SessionOptions : IDisposable | |||
| internal sealed class SessionOptions | |||
| { | |||
| public SafeSessionOptionsHandle Handle { get; } | |||
| SafeSessionOptionsHandle _handle { get; } | |||
| public SessionOptions(string target = "", ConfigProto config = null) | |||
| { | |||
| Handle = c_api.TF_NewSessionOptions(); | |||
| c_api.TF_SetTarget(Handle, target); | |||
| _handle = c_api.TF_NewSessionOptions(); | |||
| c_api.TF_SetTarget(_handle, target); | |||
| if (config != null) | |||
| SetConfig(config); | |||
| } | |||
| public void Dispose() | |||
| => Handle.Dispose(); | |||
| private unsafe void SetConfig(ConfigProto config) | |||
| { | |||
| var bytes = config.ToByteArray(); | |||
| fixed (byte* proto2 = bytes) | |||
| { | |||
| using (var status = new Status()) | |||
| { | |||
| c_api.TF_SetConfig(Handle, (IntPtr)proto2, (ulong)bytes.Length, status.Handle); | |||
| status.Check(false); | |||
| } | |||
| var status = new Status(); | |||
| c_api.TF_SetConfig(_handle, (IntPtr)proto2, (ulong)bytes.Length, status); | |||
| status.Check(false); | |||
| } | |||
| } | |||
| public static implicit operator SafeSessionOptionsHandle(SessionOptions opt) | |||
| { | |||
| return opt._handle; | |||
| } | |||
| } | |||
| } | |||
| @@ -62,7 +62,7 @@ namespace Tensorflow | |||
| /// <param name="status">TF_Status*</param> | |||
| /// <returns>TF_Session*</returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr TF_NewSession(IntPtr graph, SafeSessionOptionsHandle opts, SafeStatusHandle status); | |||
| public static extern SafeSessionHandle TF_NewSession(SafeGraphHandle graph, SafeSessionOptionsHandle opts, SafeStatusHandle status); | |||
| /// <summary> | |||
| /// Return a new options object. | |||
| @@ -110,7 +110,7 @@ namespace Tensorflow | |||
| /// <param name="run_metadata">TF_Buffer*</param> | |||
| /// <param name="status">TF_Status*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern unsafe void TF_SessionRun(IntPtr session, TF_Buffer* run_options, | |||
| public static extern unsafe void TF_SessionRun(SafeSessionHandle session, TF_Buffer* run_options, | |||
| TF_Output[] inputs, IntPtr[] input_values, int ninputs, | |||
| TF_Output[] outputs, IntPtr[] output_values, int noutputs, | |||
| IntPtr[] target_opers, int ntargets, | |||
| @@ -26,7 +26,7 @@ namespace Tensorflow | |||
| /// TF_Status holds error information. It either has an OK code, or | |||
| /// else an error code with an associated error message. | |||
| /// </summary> | |||
| public sealed class Status : IDisposable | |||
| public sealed class Status | |||
| { | |||
| /// <summary> | |||
| /// Error message | |||
| @@ -35,9 +35,9 @@ namespace Tensorflow | |||
| { | |||
| get | |||
| { | |||
| using (Handle.Lease()) | |||
| using (_handle.Lease()) | |||
| { | |||
| return StringPiece(TF_Message(Handle)); | |||
| return StringPiece(TF_Message(_handle)); | |||
| } | |||
| } | |||
| } | |||
| @@ -45,23 +45,23 @@ namespace Tensorflow | |||
| /// <summary> | |||
| /// Error code | |||
| /// </summary> | |||
| public TF_Code Code => TF_GetCode(Handle); | |||
| public TF_Code Code => TF_GetCode(_handle); | |||
| public SafeStatusHandle Handle { get; } | |||
| SafeStatusHandle _handle { get; } | |||
| public Status() | |||
| { | |||
| Handle = TF_NewStatus(); | |||
| _handle = TF_NewStatus(); | |||
| } | |||
| public Status(SafeStatusHandle handle) | |||
| { | |||
| Handle = handle ?? throw new ArgumentNullException(nameof(handle)); | |||
| _handle = handle ?? throw new ArgumentNullException(nameof(handle)); | |||
| } | |||
| public void SetStatus(TF_Code code, string msg) | |||
| { | |||
| TF_SetStatus(Handle, code, msg); | |||
| TF_SetStatus(_handle, code, msg); | |||
| } | |||
| public bool ok() => Code == TF_Code.TF_OK; | |||
| @@ -94,10 +94,12 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| public void Dispose() | |||
| => Handle.Dispose(); | |||
| public override string ToString() | |||
| => $"{Code} 0x{Handle.DangerousGetHandle():x16}"; | |||
| => $"{Code} 0x{_handle.DangerousGetHandle():x16}"; | |||
| public static implicit operator SafeStatusHandle(Status status) | |||
| { | |||
| return status._handle; | |||
| } | |||
| } | |||
| } | |||
| @@ -121,7 +121,7 @@ namespace Tensorflow | |||
| if (_handle == null) | |||
| { | |||
| c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, tf.Status.Handle); | |||
| c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, tf.Status); | |||
| } | |||
| else | |||
| { | |||
| @@ -135,9 +135,9 @@ namespace Tensorflow | |||
| protected virtual void SetShapeInternal(Shape value) | |||
| { | |||
| if (value == null) | |||
| c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.Status.Handle); | |||
| c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.Status); | |||
| else | |||
| c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.dims, value.ndim, tf.Status.Handle); | |||
| c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.dims, value.ndim, tf.Status); | |||
| } | |||
| public int[] _shape_tuple() | |||
| @@ -176,7 +176,7 @@ namespace Tensorflow | |||
| if (_handle == null) | |||
| { | |||
| var output = _as_tf_output(); | |||
| int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, tf.Status.Handle); | |||
| int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, tf.Status); | |||
| return ndim; | |||
| } | |||
| @@ -94,18 +94,16 @@ namespace Tensorflow | |||
| string output_pb = Path.GetFullPath(Path.Combine(checkpoint_dir, "../", $"{output_pb_name}.pb")); | |||
| using (var graph = tf.Graph()) | |||
| using (var sess = tf.Session(graph)) | |||
| { | |||
| var saver = tf.train.import_meta_graph($"{checkpoint}.meta", clear_devices: true); | |||
| saver.restore(sess, checkpoint); | |||
| var output_graph_def = tf.graph_util.convert_variables_to_constants(sess, | |||
| graph.as_graph_def(), | |||
| output_node_names); | |||
| Binding.tf_output_redirect.WriteLine($"Froze {output_graph_def.Node.Count} nodes."); | |||
| File.WriteAllBytes(output_pb, output_graph_def.ToByteArray()); | |||
| return output_pb; | |||
| } | |||
| var graph = tf.Graph(); | |||
| var sess = tf.Session(graph); | |||
| var saver = tf.train.import_meta_graph($"{checkpoint}.meta", clear_devices: true); | |||
| saver.restore(sess, checkpoint); | |||
| var output_graph_def = tf.graph_util.convert_variables_to_constants(sess, | |||
| graph.as_graph_def(), | |||
| output_node_names); | |||
| Binding.tf_output_redirect.WriteLine($"Froze {output_graph_def.Node.Count} nodes."); | |||
| File.WriteAllBytes(output_pb, output_graph_def.ToByteArray()); | |||
| return output_pb; | |||
| } | |||
| public static Graph load_graph(string freeze_graph_pb, string name = "") | |||
| @@ -164,7 +164,7 @@ namespace Tensorflow | |||
| result._as_tf_output(), | |||
| shape.dims, | |||
| shape.ndim, | |||
| tf.Status.Handle); | |||
| tf.Status); | |||
| tf.Status.Check(true); | |||
| } | |||
| @@ -247,7 +247,7 @@ namespace Tensorflow | |||
| foreach (var attr in node_def.Attr) | |||
| { | |||
| var bytes = attr.Value.ToByteArray(); | |||
| c_api.TF_SetAttrValueProto(op_desc, attr.Key, bytes, proto_len: bytes.Length, status: status.Handle); | |||
| c_api.TF_SetAttrValueProto(op_desc, attr.Key, bytes, proto_len: bytes.Length, status: status); | |||
| status.Check(true); | |||
| } | |||
| @@ -23,16 +23,14 @@ namespace Tensorflow.Benchmark.Leak | |||
| var ClassifierModelPath = Path.Combine(modelDir, "Leak", "TestModel", "saved_model"); | |||
| for (var i = 0; i < 1024; i++) | |||
| { | |||
| using (var sess = Session.LoadFromSavedModel(ClassifierModelPath)) { | |||
| using (var g = sess.graph.as_default()) { | |||
| var inputOp = g.OperationByName("inference_input"); | |||
| var outputOp = g.OperationByName("StatefulPartitionedCall"); | |||
| { | |||
| var sess = Session.LoadFromSavedModel(ClassifierModelPath); | |||
| var g = sess.graph.as_default(); | |||
| var inputOp = g.OperationByName("inference_input"); | |||
| var outputOp = g.OperationByName("StatefulPartitionedCall"); | |||
| var inp = np.zeros(new Shape(new int[] { 1, 2, 96 }), TF_DataType.TF_FLOAT); | |||
| sess.run(outputOp.outputs[0], new FeedItem(inputOp.outputs[0], inp)); | |||
| } | |||
| } | |||
| var inp = np.zeros(new Shape(new int[] { 1, 2, 96 }), TF_DataType.TF_FLOAT); | |||
| sess.run(outputOp.outputs[0], new FeedItem(inputOp.outputs[0], inp)); | |||
| } | |||
| } | |||
| } | |||
| @@ -16,18 +16,16 @@ namespace TensorFlowNET.UnitTest.Basics | |||
| var enqueue = queue.enqueue(numbers); | |||
| var dequeue_many = queue.dequeue_many(n: 3); | |||
| using (var sess = tf.Session()) | |||
| { | |||
| sess.run(enqueue, (numbers, new[] { 1 })); | |||
| sess.run(enqueue, (numbers, new[] { 2, 3 })); | |||
| sess.run(enqueue, (numbers, new[] { 3, 4, 5 })); | |||
| var result = sess.run(dequeue_many[0]); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0 }, result[0].ToArray<int>())); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 3, 0 }, result[1].ToArray<int>())); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 4, 5 }, result[2].ToArray<int>())); | |||
| } | |||
| var sess = tf.Session(); | |||
| sess.run(enqueue, (numbers, new[] { 1 })); | |||
| sess.run(enqueue, (numbers, new[] { 2, 3 })); | |||
| sess.run(enqueue, (numbers, new[] { 3, 4, 5 })); | |||
| var result = sess.run(dequeue_many[0]); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0 }, result[0].ToArray<int>())); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 3, 0 }, result[1].ToArray<int>())); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 4, 5 }, result[2].ToArray<int>())); | |||
| } | |||
| [TestMethod] | |||
| @@ -45,27 +43,25 @@ namespace TensorFlowNET.UnitTest.Basics | |||
| // push back into queue | |||
| var inc = queue.enqueue(y); | |||
| using (var sess = tf.Session()) | |||
| { | |||
| // init queue | |||
| init.run(); | |||
| var sess = tf.Session(); | |||
| // init queue | |||
| init.run(); | |||
| // pop out first element and push back calculated y | |||
| (int dequeued, _) = sess.run((x, inc)); | |||
| Assert.AreEqual(10, dequeued); | |||
| // pop out first element and push back calculated y | |||
| (int dequeued, _) = sess.run((x, inc)); | |||
| Assert.AreEqual(10, dequeued); | |||
| (dequeued, _) = sess.run((x, inc)); | |||
| Assert.AreEqual(20, dequeued); | |||
| (dequeued, _) = sess.run((x, inc)); | |||
| Assert.AreEqual(20, dequeued); | |||
| (dequeued, _) = sess.run((x, inc)); | |||
| Assert.AreEqual(11, dequeued); | |||
| (dequeued, _) = sess.run((x, inc)); | |||
| Assert.AreEqual(11, dequeued); | |||
| (dequeued, _) = sess.run((x, inc)); | |||
| Assert.AreEqual(21, dequeued); | |||
| (dequeued, _) = sess.run((x, inc)); | |||
| Assert.AreEqual(21, dequeued); | |||
| // thread will hang or block if you run sess.run(x) again | |||
| // until queue has more element. | |||
| } | |||
| // thread will hang or block if you run sess.run(x) again | |||
| // until queue has more element. | |||
| } | |||
| [TestMethod] | |||
| @@ -75,19 +71,17 @@ namespace TensorFlowNET.UnitTest.Basics | |||
| var init = queue.enqueue_many(new[] { 2L, 4L, 3L }, new[] { "p1", "p2", "p3" }); | |||
| var x = queue.dequeue(); | |||
| using (var sess = tf.Session()) | |||
| { | |||
| init.run(); | |||
| var sess = tf.Session(); | |||
| init.run(); | |||
| var result = sess.run(x); | |||
| Assert.AreEqual(result[0], 2L); | |||
| var result = sess.run(x); | |||
| Assert.AreEqual(result[0], 2L); | |||
| result = sess.run(x); | |||
| Assert.AreEqual(result[0], 3L); | |||
| result = sess.run(x); | |||
| Assert.AreEqual(result[0], 3L); | |||
| result = sess.run(x); | |||
| Assert.AreEqual(result[0], 4L); | |||
| } | |||
| result = sess.run(x); | |||
| Assert.AreEqual(result[0], 4L); | |||
| } | |||
| [TestMethod] | |||
| @@ -98,16 +92,14 @@ namespace TensorFlowNET.UnitTest.Basics | |||
| var x = queue.dequeue(); | |||
| string results = ""; | |||
| using (var sess = tf.Session()) | |||
| { | |||
| init.run(); | |||
| var sess = tf.Session(); | |||
| init.run(); | |||
| foreach (var i in range(9)) | |||
| results += (int)sess.run(x) + "."; | |||
| foreach (var i in range(9)) | |||
| results += (int)sess.run(x) + "."; | |||
| // output in random order | |||
| Assert.IsFalse(results == "1.2.3.4.5.6.7.8.9."); | |||
| } | |||
| // output in random order | |||
| Assert.IsFalse(results == "1.2.3.4.5.6.7.8.9."); | |||
| } | |||
| } | |||
| } | |||
| @@ -19,11 +19,9 @@ namespace TensorFlowNET.UnitTest.Basics | |||
| var a = constant_op.constant(np.array(3.0).reshape((1, 1))); | |||
| var b = constant_op.constant(np.array(2.0).reshape((1, 1))); | |||
| var c = math_ops.matmul(a, b, name: "matmul"); | |||
| using (var sess = tf.Session()) | |||
| { | |||
| var result = c.eval(sess); | |||
| Assert.AreEqual(result[0], 6.0); | |||
| } | |||
| var sess = tf.Session(); | |||
| var result = c.eval(sess); | |||
| Assert.AreEqual(result[0], 6.0); | |||
| } | |||
| } | |||
| @@ -32,11 +30,9 @@ namespace TensorFlowNET.UnitTest.Basics | |||
| { | |||
| var a = constant_op.constant("123 heythere 123 ", TF_DataType.TF_STRING); | |||
| var c = tf.strings.substr(a, 4, 8); | |||
| using (var sess = tf.Session()) | |||
| { | |||
| var result = c.eval(sess).StringData(); | |||
| Assert.AreEqual(result[0], "heythere"); | |||
| } | |||
| var sess = tf.Session(); | |||
| var result = c.eval(sess).StringData(); | |||
| Assert.AreEqual(result[0], "heythere"); | |||
| } | |||
| [TestMethod] | |||
| @@ -47,11 +43,9 @@ namespace TensorFlowNET.UnitTest.Basics | |||
| const int size = 30_000; | |||
| var a = constant_op.constant(new string('a', size), TF_DataType.TF_STRING); | |||
| var c = tf.strings.substr(a, 0, size - 5000); | |||
| using (var sess = tf.Session()) | |||
| { | |||
| var result = UTF8Encoding.UTF8.GetString(c.eval(sess).ToByteArray()); | |||
| Console.WriteLine(result); | |||
| } | |||
| var sess = tf.Session(); | |||
| var result = UTF8Encoding.UTF8.GetString(c.eval(sess).ToByteArray()); | |||
| Console.WriteLine(result); | |||
| } | |||
| } | |||
| @@ -16,15 +16,13 @@ namespace TensorFlowNET.UnitTest.Basics | |||
| var labels = tf.expand_dims(tf.constant(new[] { 0, 1, 2, 3, 4 }), 1); | |||
| var st = tf.concat(values: new[] { indices, labels }, axis: 1); | |||
| var onehot = tf.sparse_to_dense(st, (5, 5), 1); | |||
| using (var sess = tf.Session()) | |||
| { | |||
| var result = sess.run(onehot); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0, 0, 0 }, result[0].ToArray<int>())); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 1, 0, 0, 0 }, result[1].ToArray<int>())); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 1, 0, 0 }, result[2].ToArray<int>())); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 1, 0 }, result[3].ToArray<int>())); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 1 }, result[4].ToArray<int>())); | |||
| }; | |||
| var sess = tf.Session(); | |||
| var result = sess.run(onehot); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0, 0, 0 }, result[0].ToArray<int>())); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 1, 0, 0, 0 }, result[1].ToArray<int>())); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 1, 0, 0 }, result[2].ToArray<int>())); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 1, 0 }, result[3].ToArray<int>())); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 1 }, result[4].ToArray<int>())); | |||
| } | |||
| [TestMethod, Ignore] | |||
| @@ -39,13 +37,11 @@ namespace TensorFlowNET.UnitTest.Basics | |||
| new[] { 3L, 4L }); | |||
| var onehot = tf.sparse_tensor_to_dense(decoded_list); | |||
| using (var sess = tf.Session()) | |||
| { | |||
| var result = sess.run(onehot); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0, 0 }, result[0].ToArray<int>())); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 2, 0 }, result[1].ToArray<int>())); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0 }, result[2].ToArray<int>())); | |||
| } | |||
| var sess = tf.Session(); | |||
| var result = sess.run(onehot); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0, 0 }, result[0].ToArray<int>())); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 2, 0 }, result[1].ToArray<int>())); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0 }, result[2].ToArray<int>())); | |||
| } | |||
| [TestMethod] | |||
| @@ -56,14 +52,12 @@ namespace TensorFlowNET.UnitTest.Basics | |||
| int[,] crops = { { 0, 0 }, { 0, 0 } }; | |||
| var tensor = tf.batch_to_space_nd(inputs, block_shape, crops); | |||
| using (var sess = tf.Session()) | |||
| { | |||
| var result = sess.run(tensor); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 6, 1, 7, 2, 8 }, result[0, 0].ToArray<int>())); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 12, 18, 13, 19, 14, 20 }, result[0, 1].ToArray<int>())); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 9, 4, 10, 5, 11 }, result[0, 2].ToArray<int>())); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 15, 21, 16, 22, 17, 23 }, result[0, 3].ToArray<int>())); | |||
| } | |||
| var sess = tf.Session(); | |||
| var result = sess.run(tensor); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 6, 1, 7, 2, 8 }, result[0, 0].ToArray<int>())); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 12, 18, 13, 19, 14, 20 }, result[0, 1].ToArray<int>())); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 9, 4, 10, 5, 11 }, result[0, 2].ToArray<int>())); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 15, 21, 16, 22, 17, 23 }, result[0, 3].ToArray<int>())); | |||
| } | |||
| [TestMethod, Ignore] | |||
| @@ -72,11 +66,9 @@ namespace TensorFlowNET.UnitTest.Basics | |||
| var tensor = new[] { 0, 1, 2, 3 }; | |||
| var mask = np.array(new[] { true, false, true, false }); | |||
| var masked = tf.boolean_mask(tensor, mask); | |||
| using (var sess = tf.Session()) | |||
| { | |||
| var result = sess.run(masked); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, masked.ToArray<int>())); | |||
| } | |||
| var sess = tf.Session(); | |||
| var result = sess.run(masked); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, masked.ToArray<int>())); | |||
| } | |||
| } | |||
| } | |||
| @@ -14,7 +14,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||
| var v = tf.Variable(new[] { 1, 2 }); | |||
| var init = tf.compat.v1.global_variables_initializer(); | |||
| using var sess = tf.compat.v1.Session(); | |||
| var sess = tf.compat.v1.Session(); | |||
| sess.run(init); | |||
| // Usage passing the session explicitly. | |||
| print(v.eval(sess)); | |||
| @@ -16,18 +16,16 @@ namespace TensorFlowNET.UnitTest.ControlFlowTest | |||
| { | |||
| var graph = tf.Graph().as_default(); | |||
| using (var sess = tf.Session(graph)) | |||
| { | |||
| var x = tf.constant(2, name: "x"); | |||
| var y = tf.constant(5, name: "y"); | |||
| var z = control_flow_ops.cond(tf.less(x, y), | |||
| () => tf.constant(22, name: "t22"), | |||
| () => tf.constant(55, name: "f55")); | |||
| int result = z.eval(sess); | |||
| assertEquals(result, 22); | |||
| } | |||
| var sess = tf.Session(graph); | |||
| var x = tf.constant(2, name: "x"); | |||
| var y = tf.constant(5, name: "y"); | |||
| var z = control_flow_ops.cond(tf.less(x, y), | |||
| () => tf.constant(22, name: "t22"), | |||
| () => tf.constant(55, name: "f55")); | |||
| int result = z.eval(sess); | |||
| assertEquals(result, 22); | |||
| } | |||
| [TestMethod] | |||
| @@ -35,18 +33,16 @@ namespace TensorFlowNET.UnitTest.ControlFlowTest | |||
| { | |||
| var graph = tf.Graph().as_default(); | |||
| using (var sess = tf.Session(graph)) | |||
| { | |||
| var x = tf.constant(2, name: "x"); | |||
| var y = tf.constant(1, name: "y"); | |||
| var sess = tf.Session(graph); | |||
| var x = tf.constant(2, name: "x"); | |||
| var y = tf.constant(1, name: "y"); | |||
| var z = control_flow_ops.cond(tf.less(x, y), | |||
| () => tf.constant(22, name: "t22"), | |||
| () => tf.constant(11, name: "f11")); | |||
| var z = control_flow_ops.cond(tf.less(x, y), | |||
| () => tf.constant(22, name: "t22"), | |||
| () => tf.constant(11, name: "f11")); | |||
| int result = z.eval(sess); | |||
| assertEquals(result, 11); | |||
| } | |||
| int result = z.eval(sess); | |||
| assertEquals(result, 11); | |||
| } | |||
| [Ignore("Dependent on UpdateEdge")] | |||
| @@ -23,21 +23,19 @@ namespace TensorFlowNET.UnitTest.ControlFlowTest | |||
| private void _testWhileContextHelper(int maximum_iterations) | |||
| { | |||
| // TODO: implement missing code dependencies | |||
| using (var sess = this.cached_session()) | |||
| var sess = this.cached_session(); | |||
| var i = constant_op.constant(0, name: "i"); | |||
| var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, 10, name: "c")); | |||
| var b = new Func<Tensor, Tensor>(x => gen_math_ops.add(x, 1, name: "c")); | |||
| //control_flow_ops.while_loop( | |||
| // c, b, i , maximum_iterations: tf.constant(maximum_iterations)); | |||
| foreach (Operation op in sess.graph.get_operations()) | |||
| { | |||
| var i = constant_op.constant(0, name: "i"); | |||
| var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, 10, name: "c")); | |||
| var b = new Func<Tensor, Tensor>(x => gen_math_ops.add(x, 1, name: "c")); | |||
| //control_flow_ops.while_loop( | |||
| // c, b, i , maximum_iterations: tf.constant(maximum_iterations)); | |||
| foreach (Operation op in sess.graph.get_operations()) | |||
| { | |||
| var control_flow_context = op._get_control_flow_context(); | |||
| /*if (control_flow_context != null) | |||
| self.assertProtoEquals(control_flow_context.to_proto(), | |||
| WhileContext.from_proto( | |||
| control_flow_context.to_proto()).to_proto(), "");*/ | |||
| } | |||
| var control_flow_context = op._get_control_flow_context(); | |||
| /*if (control_flow_context != null) | |||
| self.assertProtoEquals(control_flow_context.to_proto(), | |||
| WhileContext.from_proto( | |||
| control_flow_context.to_proto()).to_proto(), "");*/ | |||
| } | |||
| } | |||
| @@ -18,11 +18,9 @@ namespace TensorFlowNET.UnitTest.Gradient | |||
| var y = tf.broadcast_to(x, (2, 4, 3)); | |||
| var grad = tf.gradients(y, x); | |||
| using (var sess = tf.Session(graph)) | |||
| { | |||
| float result = sess.run(grad[0]); | |||
| Assert.AreEqual(result, 24.0f); | |||
| } | |||
| var sess = tf.Session(graph); | |||
| float result = sess.run(grad[0]); | |||
| Assert.AreEqual(result, 24.0f); | |||
| } | |||
| [TestMethod] | |||
| @@ -33,11 +31,9 @@ namespace TensorFlowNET.UnitTest.Gradient | |||
| var z = tf.cumsum(y, axis: 1); | |||
| var grad = tf.gradients(z, x); | |||
| using (var sess = tf.Session(graph)) | |||
| { | |||
| float result = sess.run(grad[0]); | |||
| Assert.AreEqual(result, 60.0f); | |||
| } | |||
| var sess = tf.Session(graph); | |||
| float result = sess.run(grad[0]); | |||
| Assert.AreEqual(result, 60.0f); | |||
| } | |||
| [TestMethod, Ignore] | |||
| @@ -78,14 +74,12 @@ namespace TensorFlowNET.UnitTest.Gradient | |||
| 42.0f, 42.0f, 42.0f, | |||
| 45.0f, 45.0f, 45.0f | |||
| }; | |||
| using (var sess = tf.Session()) | |||
| { | |||
| var result = sess.run(g); | |||
| var resultList = result[0].ToArray<float>().ToList(); | |||
| resultList.AddRange(result[1].ToArray<float>()); | |||
| Console.WriteLine(result.ToString()); | |||
| CollectionAssert.AreEqual(resultList.ToArray(), checkG); | |||
| } | |||
| var sess = tf.Session(); | |||
| var result = sess.run(g); | |||
| var resultList = result[0].ToArray<float>().ToList(); | |||
| resultList.AddRange(result[1].ToArray<float>()); | |||
| Console.WriteLine(result.ToString()); | |||
| CollectionAssert.AreEqual(resultList.ToArray(), checkG); | |||
| } | |||
| [TestMethod] | |||
| @@ -97,11 +91,9 @@ namespace TensorFlowNET.UnitTest.Gradient | |||
| var y = f(x); | |||
| var g = tf.gradients(y, x); | |||
| using (var session = tf.Session()) | |||
| { | |||
| var result = session.run(new[] { y, g[0] }); | |||
| return (result[0].ToArray<T>()[0], result[1].ToArray<T>()[0]); | |||
| } | |||
| var session = tf.Session(); | |||
| var result = session.run(new[] { y, g[0] }); | |||
| return (result[0].ToArray<T>()[0], result[1].ToArray<T>()[0]); | |||
| } | |||
| void test(string name, Func<Tensor, Tensor> tfF, Func<double, (double, double)> targetF, double[] values) | |||
| @@ -197,13 +189,11 @@ namespace TensorFlowNET.UnitTest.Gradient | |||
| var g1 = tf.gradients(tf.reduce_sum(m, axis: 0)[0], x)[0]; | |||
| var g2 = tf.gradients(tf.reduce_sum(m, axis: 1)[0], x)[0]; | |||
| using (var session = tf.Session()) | |||
| { | |||
| var (r0, r1, r2) = session.run((g0, g1, g2), new FeedItem(x, new[,] { { 1.0 } })); | |||
| self.assertFloat64Equal(6.0, r0[0], $"tf.reduce_sum(...)"); | |||
| self.assertFloat64Equal(2.0, r1[0], $"tf.reduce_sum(..., axis = 0)"); | |||
| self.assertFloat64Equal(3.0, r2[0], $"tf.reduce_sum(..., axis = 1)"); | |||
| } | |||
| var session = tf.Session(); | |||
| var (r0, r1, r2) = session.run((g0, g1, g2), new FeedItem(x, new[,] { { 1.0 } })); | |||
| self.assertFloat64Equal(6.0, r0[0], $"tf.reduce_sum(...)"); | |||
| self.assertFloat64Equal(2.0, r1[0], $"tf.reduce_sum(..., axis = 0)"); | |||
| self.assertFloat64Equal(3.0, r2[0], $"tf.reduce_sum(..., axis = 1)"); | |||
| } | |||
| [TestMethod] | |||
| @@ -212,12 +202,10 @@ namespace TensorFlowNET.UnitTest.Gradient | |||
| var a = tf.constant(1f); | |||
| var b = tf.tanh(a); | |||
| var g = tf.gradients(b, a); | |||
| using (var sess = tf.Session()) | |||
| { | |||
| var result = sess.run(g); | |||
| var actual = result[0]; | |||
| Assert.AreEqual(actual, 0.41997434127f); | |||
| } | |||
| var sess = tf.Session(); | |||
| var result = sess.run(g); | |||
| var actual = result[0]; | |||
| Assert.AreEqual(actual, 0.41997434127f); | |||
| } | |||
| @@ -227,14 +215,12 @@ namespace TensorFlowNET.UnitTest.Gradient | |||
| var a = tf.constant(5f); | |||
| var b = tf.lgamma(a); | |||
| var g = tf.gradients(b, a); | |||
| using (var sess = tf.Session()) | |||
| { | |||
| var result = sess.run(new object[] { g, b }); | |||
| var actualDeriv = result[0]; | |||
| var actual = result[1]; | |||
| Assert.AreEqual(actualDeriv, 1.5061177f); | |||
| Assert.AreEqual(actual, 3.17805386f); | |||
| } | |||
| var sess = tf.Session(); | |||
| var result = sess.run(new object[] { g, b }); | |||
| var actualDeriv = result[0]; | |||
| var actual = result[1]; | |||
| Assert.AreEqual(actualDeriv, 1.5061177f); | |||
| Assert.AreEqual(actual, 3.17805386f); | |||
| } | |||
| [TestMethod] | |||
| @@ -247,14 +233,12 @@ namespace TensorFlowNET.UnitTest.Gradient | |||
| tf.constant(new[] { 1 }, tf.int32, new[] { 1 }) | |||
| ); | |||
| var g = tf.gradients(b, a); | |||
| using (var sess = tf.Session()) | |||
| { | |||
| var result = sess.run(new object[] { g, b }); | |||
| var actualDeriv = np.squeeze(result[0]); | |||
| var actual = np.squeeze(result[1]); | |||
| Assert.AreEqual(actualDeriv, new float[] { 1, 0 }); | |||
| Assert.AreEqual(actual, 0.9640276f); | |||
| } | |||
| var sess = tf.Session(); | |||
| var result = sess.run(new object[] { g, b }); | |||
| var actualDeriv = np.squeeze(result[0]); | |||
| var actual = np.squeeze(result[1]); | |||
| Assert.AreEqual(actualDeriv, new float[] { 1, 0 }); | |||
| Assert.AreEqual(actual, 0.9640276f); | |||
| } | |||
| [TestMethod] | |||
| @@ -264,14 +248,12 @@ namespace TensorFlowNET.UnitTest.Gradient | |||
| var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); | |||
| var a = tf.concat(new List<Tensor>(new[] { a1, a2 }), 0); | |||
| var g = tf.gradients(a, a1); | |||
| using (var sess = tf.Session()) | |||
| { | |||
| var result = sess.run(new object[] { g, a }); | |||
| var actualDeriv = result[0][0]; | |||
| var actual = result[1][0]; | |||
| Assert.AreEqual(actualDeriv, 1f); | |||
| Assert.AreEqual(actual, 2f); | |||
| } | |||
| var sess = tf.Session(); | |||
| var result = sess.run(new object[] { g, a }); | |||
| var actualDeriv = result[0][0]; | |||
| var actual = result[1][0]; | |||
| Assert.AreEqual(actualDeriv, 1f); | |||
| Assert.AreEqual(actual, 2f); | |||
| } | |||
| [TestMethod] | |||
| @@ -280,13 +262,12 @@ namespace TensorFlowNET.UnitTest.Gradient | |||
| var ap = tf.constant(1f); | |||
| var b = tf.tanh(ap) + gen_array_ops.stop_gradient(ap); | |||
| var g = tf.gradients(b, ap); | |||
| using (var sess = tf.Session()) | |||
| { | |||
| var result = sess.run(g); | |||
| var actual = result[0]; | |||
| Assert.AreEqual(actual, 0.41997434127f); | |||
| } | |||
| var sess = tf.Session(); | |||
| var result = sess.run(g); | |||
| var actual = result[0]; | |||
| Assert.AreEqual(actual, 0.41997434127f); | |||
| } | |||
| [Ignore("TODO")] | |||
| [TestMethod] | |||
| public void testUnusedOutput() | |||
| @@ -74,23 +74,21 @@ namespace TensorFlowNET.UnitTest | |||
| var cropSize2_2 = tf.Variable(np.array(4, 4)); | |||
| var init = tf.global_variables_initializer(); | |||
| using (Session sess = tf.Session()) | |||
| { | |||
| sess.run(init); | |||
| var sess = tf.Session(); | |||
| sess.run(init); | |||
| var cropped = tf.image.crop_and_resize(image, box, boxInd, cropSize1_1); | |||
| var cropped = tf.image.crop_and_resize(image, box, boxInd, cropSize1_1); | |||
| var result = sess.run(cropped); | |||
| // check if cropped to 1x1 center was succesfull | |||
| Assert.AreEqual(result.size, 1ul); | |||
| Assert.AreEqual(result[0, 0, 0, 0], 4f); | |||
| var result = sess.run(cropped); | |||
| // check if cropped to 1x1 center was succesfull | |||
| Assert.AreEqual(result.size, 1ul); | |||
| Assert.AreEqual(result[0, 0, 0, 0], 4f); | |||
| cropped = tf.image.crop_and_resize(image2, box, boxInd, cropSize2_2); | |||
| result = sess.run(cropped); | |||
| // check if flipped and no cropping occured | |||
| Assert.AreEqual(result.size, 16ul); | |||
| Assert.AreEqual(result[0, 0, 0, 0], 12f); | |||
| } | |||
| cropped = tf.image.crop_and_resize(image2, box, boxInd, cropSize2_2); | |||
| result = sess.run(cropped); | |||
| // check if flipped and no cropping occured | |||
| Assert.AreEqual(result.size, 16ul); | |||
| Assert.AreEqual(result[0, 0, 0, 0], 12f); | |||
| } | |||
| } | |||
| } | |||
| @@ -24,7 +24,7 @@ namespace TensorFlowNET.UnitTest | |||
| { | |||
| Assert.IsNull(tf.peak_default_graph()); | |||
| using var sess = tf.Session(); | |||
| var sess = tf.Session(); | |||
| var default_graph = tf.get_default_graph(); | |||
| var sess_graph = sess.graph; | |||
| Assert.IsNotNull(default_graph); | |||
| @@ -45,7 +45,7 @@ namespace TensorFlowNET.UnitTest | |||
| { | |||
| Assert.IsNull(tf.peak_default_graph()); | |||
| //tf.Session created an other graph | |||
| using var sess = tf.Session(); | |||
| var sess = tf.Session(); | |||
| var default_graph = tf.get_default_graph(); | |||
| var sess_graph = sess.graph; | |||
| Assert.IsNotNull(default_graph); | |||
| @@ -69,7 +69,7 @@ namespace TensorFlowNET.UnitTest | |||
| beforehand.as_default(); | |||
| Assert.IsNotNull(tf.peak_default_graph()); | |||
| using var sess = tf.Session(); | |||
| var sess = tf.Session(); | |||
| var default_graph = tf.peak_default_graph(); | |||
| var sess_graph = sess.graph; | |||
| Assert.IsNotNull(default_graph); | |||
| @@ -102,7 +102,7 @@ namespace TensorFlowNET.UnitTest | |||
| //the core method | |||
| void Core(int tid) | |||
| { | |||
| using var sess = tf.Session(); | |||
| var sess = tf.Session(); | |||
| for (int i = 0; i < 100; i++) | |||
| { | |||
| var t = new Tensor(1); | |||
| @@ -119,7 +119,7 @@ namespace TensorFlowNET.UnitTest | |||
| void Core(int tid) | |||
| { | |||
| //tf.Session created an other graph | |||
| using var sess = tf.Session(); | |||
| var sess = tf.Session(); | |||
| for (int i = 0; i < 100; i++) | |||
| { | |||
| var t = new Tensor(new int[] { 1, 2, 3 }); | |||
| @@ -142,7 +142,7 @@ namespace TensorFlowNET.UnitTest | |||
| var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | |||
| var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); | |||
| var math = a1 + a2; | |||
| using var sess = tf.Session(graph); | |||
| var sess = tf.Session(graph); | |||
| for (int i = 0; i < 100; i++) | |||
| { | |||
| var result = sess.run(math); | |||
| @@ -162,7 +162,7 @@ namespace TensorFlowNET.UnitTest | |||
| tf.compat.v1.disable_eager_execution(); | |||
| var graph = tf.Graph().as_default(); | |||
| using var sess = tf.Session(graph); | |||
| var sess = tf.Session(graph); | |||
| Assert.IsNotNull(tf.get_default_graph()); | |||
| //graph is created automatically to perform create these operations | |||
| var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | |||
| @@ -182,7 +182,7 @@ namespace TensorFlowNET.UnitTest | |||
| //the core method | |||
| void Core(int tid) | |||
| { | |||
| using var sess = tf.Session(); | |||
| var sess = tf.Session(); | |||
| Assert.IsNotNull(tf.get_default_graph()); | |||
| //graph is created automatically to perform create these operations | |||
| var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | |||
| @@ -182,23 +182,21 @@ namespace TensorFlowNET.UnitTest | |||
| // return self._eval_helper(tensors) | |||
| // else: | |||
| { | |||
| using (var sess = tf.Session()) | |||
| var sess = tf.Session(); | |||
| var ndarray = tensor.eval(sess); | |||
| if (typeof(T) == typeof(double)) | |||
| { | |||
| var ndarray = tensor.eval(sess); | |||
| if (typeof(T) == typeof(double)) | |||
| { | |||
| double x = ndarray; | |||
| result = x; | |||
| } | |||
| else if (typeof(T) == typeof(int)) | |||
| { | |||
| int x = ndarray; | |||
| result = x; | |||
| } | |||
| else | |||
| { | |||
| result = ndarray; | |||
| } | |||
| double x = ndarray; | |||
| result = x; | |||
| } | |||
| else if (typeof(T) == typeof(int)) | |||
| { | |||
| int x = ndarray; | |||
| result = x; | |||
| } | |||
| else | |||
| { | |||
| result = ndarray; | |||
| } | |||
| return (T)result; | |||
| @@ -48,7 +48,7 @@ namespace Tensorflow.Native.UnitTest | |||
| private void EXPECT_TF_META(Operation oper, string attr_name, int expected_list_size, TF_AttrType expected_type, uint expected_total_size) | |||
| { | |||
| var m = c_api.TF_OperationGetAttrMetadata(oper, attr_name, s_.Handle); | |||
| var m = c_api.TF_OperationGetAttrMetadata(oper, attr_name, s_); | |||
| EXPECT_EQ(TF_Code.TF_OK, s_.Code); | |||
| char e = expected_list_size >= 0 ? (char)1 : (char)0; | |||
| /*EXPECT_EQ(e, m.is_list); | |||
| @@ -63,7 +63,7 @@ namespace Tensorflow.Native.UnitTest | |||
| var desc = init("string"); | |||
| c_api.TF_SetAttrString(desc, "v", "bunny", 5); | |||
| var oper = c_api.TF_FinishOperation(desc, s_.Handle); | |||
| var oper = c_api.TF_FinishOperation(desc, s_); | |||
| //ASSERT_EQ(TF_Code.TF_OK, s_.Code); | |||
| //EXPECT_TF_META(oper, "v", -1, TF_AttrType.TF_ATTR_STRING, 5); | |||
| //var value = new char[5]; | |||
| @@ -86,8 +86,6 @@ namespace Tensorflow.Native.UnitTest | |||
| public void Dispose() | |||
| { | |||
| graph_.Dispose(); | |||
| s_.Dispose(); | |||
| } | |||
| } | |||
| } | |||
| @@ -59,7 +59,7 @@ namespace Tensorflow.Native.UnitTest | |||
| private void VerifyCollocation(Operation op, string[] expected) | |||
| { | |||
| var handle = c_api.TF_OperationGetAttrMetadata(op, "_class", s_.Handle); | |||
| var handle = c_api.TF_OperationGetAttrMetadata(op, "_class", s_); | |||
| TF_AttrMetadata m = new TF_AttrMetadata(); | |||
| if (expected.Length == 0) | |||
| { | |||
| @@ -98,8 +98,6 @@ namespace Tensorflow.Native.UnitTest | |||
| public void Dispose() | |||
| { | |||
| graph_.Dispose(); | |||
| s_.Dispose(); | |||
| } | |||
| } | |||
| } | |||
| @@ -45,10 +45,10 @@ namespace Tensorflow.Native.UnitTest | |||
| => c_api.TF_AddInput(desc, input); | |||
| protected Operation TF_FinishOperation(OperationDescription desc, Status s) | |||
| => c_api.TF_FinishOperation(desc, s.Handle); | |||
| => c_api.TF_FinishOperation(desc, s); | |||
| protected void TF_SetAttrTensor(OperationDescription desc, string attrName, Tensor value, Status s) | |||
| => c_api.TF_SetAttrTensor(desc, attrName, value, s.Handle); | |||
| => c_api.TF_SetAttrTensor(desc, attrName, value, s); | |||
| protected void TF_SetAttrType(OperationDescription desc, string attrName, TF_DataType dtype) | |||
| => c_api.TF_SetAttrType(desc, attrName, dtype); | |||
| @@ -18,7 +18,7 @@ namespace Tensorflow.Native.UnitTest | |||
| string func_name_ = "MyFunc"; | |||
| string func_node_name_ = "MyFunc_0"; | |||
| Status s_; | |||
| IntPtr func_; | |||
| SafeFuncGraphHandle func_; | |||
| [TestInitialize] | |||
| public void Initialize() | |||
| @@ -402,7 +402,7 @@ namespace Tensorflow.Native.UnitTest | |||
| inputs.Length, inputs.ToArray(), | |||
| outputs.Length, outputs.ToArray(), | |||
| output_names == null || output_names.Length == 0 ? null : output_names, | |||
| IntPtr.Zero, null, s_.Handle); | |||
| IntPtr.Zero, null, s_); | |||
| if (expect_failure) | |||
| { | |||
| @@ -413,7 +413,7 @@ namespace Tensorflow.Native.UnitTest | |||
| ASSERT_EQ(TF_OK, s_.Code, s_.Message); | |||
| ASSERT_NE(func_, IntPtr.Zero); | |||
| ASSERT_EQ(func_name_, c_api.StringPiece(c_api.TF_FunctionName(func_))); | |||
| c_api.TF_GraphCopyFunction(host_graph_, func_, IntPtr.Zero, s_.Handle); | |||
| c_api.TF_GraphCopyFunction(host_graph_, func_, IntPtr.Zero, s_); | |||
| ASSERT_EQ(TF_OK, s_.Code, s_.Message); | |||
| } | |||
| @@ -44,18 +44,14 @@ namespace Tensorflow.Native.UnitTest | |||
| private bool GetGraphDef(Graph graph, out GraphDef graph_def) | |||
| { | |||
| graph_def = null; | |||
| using (var s = new Status()) | |||
| { | |||
| using (var buffer = new Buffer()) | |||
| { | |||
| c_api.TF_GraphToGraphDef(graph, buffer.Handle, s.Handle); | |||
| bool ret = TF_GetCode(s) == TF_OK; | |||
| EXPECT_EQ(TF_OK, TF_GetCode(s)); | |||
| if (ret) | |||
| graph_def = GraphDef.Parser.ParseFrom(buffer.ToArray()); | |||
| return ret; | |||
| } | |||
| } | |||
| var s = new Status(); | |||
| var buffer = new Buffer(); | |||
| c_api.TF_GraphToGraphDef(graph, buffer, s); | |||
| bool ret = TF_GetCode(s) == TF_OK; | |||
| EXPECT_EQ(TF_OK, TF_GetCode(s)); | |||
| if (ret) | |||
| graph_def = GraphDef.Parser.ParseFrom(buffer.ToArray()); | |||
| return ret; | |||
| } | |||
| private void RunGraphsAndCompareOutputs(TF_Output[] grad_outputs, TF_Output[] expected_grad_outputs) | |||
| @@ -111,9 +107,9 @@ namespace Tensorflow.Native.UnitTest | |||
| IntPtr[] handles = new IntPtr[2] { IntPtr.Zero, IntPtr.Zero }; | |||
| c_api.TF_AddGradientsWithPrefix(graph_, prefix, outputs, noutputs, inputs, | |||
| ninputs, grad_inputs, s_.Handle, handles); | |||
| ninputs, grad_inputs, s_, handles); | |||
| var op = new Operation(handles[0]); | |||
| // var op = new Operation(handles[0]); | |||
| } | |||
| else | |||
| { | |||
| @@ -275,9 +271,6 @@ namespace Tensorflow.Native.UnitTest | |||
| public void Dispose() | |||
| { | |||
| graph_.Dispose(); | |||
| expected_graph_.Dispose(); | |||
| s_.Dispose(); | |||
| } | |||
| } | |||
| } | |||
| @@ -9,7 +9,7 @@ namespace Tensorflow.Native.UnitTest | |||
| [TestMethod, Ignore("Waiting to merge https://github.com/tensorflow/tensorflow/pull/43383")] | |||
| public void UpdateEdge() | |||
| { | |||
| using var graph = new Graph().as_default(); | |||
| var graph = new Graph().as_default(); | |||
| var one = tf.constant(1, name: "one"); | |||
| var two = tf.constant(2, name: "two"); | |||
| @@ -35,7 +35,7 @@ namespace Tensorflow.Native.UnitTest | |||
| EXPECT_EQ(attr_value.Type, DataType.DtInt32); | |||
| // Test not found errors in TF_Operation*() query functions. | |||
| EXPECT_EQ(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s.Handle)); | |||
| EXPECT_EQ(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s)); | |||
| EXPECT_EQ(TF_Code.TF_INVALID_ARGUMENT, s.Code); | |||
| Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s)); | |||
| EXPECT_EQ("Operation 'feed' has no attr named 'missing'.", s.Message); | |||
| @@ -191,9 +191,6 @@ namespace Tensorflow.Native.UnitTest | |||
| ASSERT_TRUE(found_scalar_const); | |||
| ASSERT_TRUE(found_add); | |||
| ASSERT_TRUE(found_neg); | |||
| graph.Dispose(); | |||
| s.Dispose(); | |||
| } | |||
| /// <summary> | |||
| @@ -213,16 +210,15 @@ namespace Tensorflow.Native.UnitTest | |||
| // Export to a GraphDef. | |||
| var graph_def = new Buffer(); | |||
| c_api.TF_GraphToGraphDef(graph, graph_def.Handle, s.Handle); | |||
| c_api.TF_GraphToGraphDef(graph, graph_def, s); | |||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
| // Import it, with a prefix, in a fresh graph. | |||
| graph.Dispose(); | |||
| graph = new Graph().as_default(); | |||
| using (var opts = c_api.TF_NewImportGraphDefOptions()) | |||
| { | |||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported"); | |||
| c_api.TF_GraphImportGraphDef(graph, graph_def.Handle, opts, s.Handle); | |||
| c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); | |||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
| } | |||
| @@ -265,7 +261,7 @@ namespace Tensorflow.Native.UnitTest | |||
| EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts)); | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar"); | |||
| EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts)); | |||
| var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def.Handle, opts, s.Handle); | |||
| var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s); | |||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
| return results; | |||
| @@ -305,7 +301,7 @@ namespace Tensorflow.Native.UnitTest | |||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported3"); | |||
| c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed); | |||
| c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2); | |||
| c_api.TF_GraphImportGraphDef(graph, graph_def.Handle, opts, s.Handle); | |||
| c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); | |||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
| } | |||
| @@ -330,7 +326,7 @@ namespace Tensorflow.Native.UnitTest | |||
| // Export to a graph def so we can import a graph with control dependencies | |||
| graph_def = new Buffer(); | |||
| c_api.TF_GraphToGraphDef(graph, graph_def.Handle, s.Handle); | |||
| c_api.TF_GraphToGraphDef(graph, graph_def, s); | |||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
| // Import again, with remapped control dependency, into the same graph | |||
| @@ -338,7 +334,7 @@ namespace Tensorflow.Native.UnitTest | |||
| { | |||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported4"); | |||
| c_api.TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed); | |||
| c_api.TF_GraphImportGraphDef(graph, graph_def.Handle, opts, s.Handle); | |||
| c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); | |||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||
| } | |||
| @@ -380,7 +376,6 @@ namespace Tensorflow.Native.UnitTest | |||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||
| // Import it in a fresh graph with return outputs. | |||
| graph.Dispose(); | |||
| graph = new Graph().as_default(); | |||
| var opts = new ImportGraphDefOptions(); | |||
| opts.AddReturnOutput("feed", 0); | |||
| @@ -401,11 +396,6 @@ namespace Tensorflow.Native.UnitTest | |||
| EXPECT_EQ(0, return_outputs[0].index); | |||
| EXPECT_EQ(scalar, return_outputs[1].oper); | |||
| EXPECT_EQ(0, return_outputs[1].index); | |||
| opts.Dispose(); | |||
| graph_def.Dispose(); | |||
| graph.Dispose(); | |||
| s.Dispose(); | |||
| } | |||
| /// <summary> | |||
| @@ -422,16 +412,14 @@ namespace Tensorflow.Native.UnitTest | |||
| public void ImportGraphMeta() | |||
| { | |||
| var dir = "my-save-dir/"; | |||
| using (var sess = tf.Session()) | |||
| { | |||
| var new_saver = tf.train.import_meta_graph(dir + "my-model-10000.meta"); | |||
| new_saver.restore(sess, dir + "my-model-10000"); | |||
| var labels = tf.constant(0, dtype: tf.int32, shape: new int[] { 100 }, name: "labels"); | |||
| var batch_size = tf.size(labels); | |||
| var logits = tf.get_collection<ITensorOrOperation>("logits")[0] as Tensor; | |||
| var loss = tf.losses.sparse_softmax_cross_entropy(labels: labels, | |||
| logits: logits); | |||
| } | |||
| var sess = tf.Session(); | |||
| var new_saver = tf.train.import_meta_graph(dir + "my-model-10000.meta"); | |||
| new_saver.restore(sess, dir + "my-model-10000"); | |||
| var labels = tf.constant(0, dtype: tf.int32, shape: new int[] { 100 }, name: "labels"); | |||
| var batch_size = tf.size(labels); | |||
| var logits = tf.get_collection<ITensorOrOperation>("logits")[0] as Tensor; | |||
| var loss = tf.losses.sparse_softmax_cross_entropy(labels: labels, | |||
| logits: logits); | |||
| } | |||
| } | |||
| } | |||
| @@ -11,7 +11,7 @@ namespace Tensorflow.Native.UnitTest | |||
| /// </summary> | |||
| public class CSession | |||
| { | |||
| private IntPtr session_; | |||
| private SafeSessionHandle session_; | |||
| private List<TF_Output> inputs_ = new List<TF_Output>(); | |||
| private List<Tensor> input_values_ = new List<Tensor>(); | |||
| @@ -22,11 +22,8 @@ namespace Tensorflow.Native.UnitTest | |||
| public CSession(Graph graph, Status s, bool user_XLA = false) | |||
| { | |||
| lock (Locks.ProcessWide) | |||
| { | |||
| var config = new ConfigProto { InterOpParallelismThreads = 4 }; | |||
| session_ = new Session(graph, config, s); | |||
| } | |||
| var config = new ConfigProto { InterOpParallelismThreads = 4 }; | |||
| session_ = new Session(graph, config, s); | |||
| } | |||
| public void SetInputs(Dictionary<Operation, Tensor> inputs) | |||
| @@ -85,7 +82,7 @@ namespace Tensorflow.Native.UnitTest | |||
| c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length, | |||
| outputs_ptr, output_values_ptr, outputs_.Count, | |||
| targets_ptr, targets_.Count, | |||
| IntPtr.Zero, s.Handle); | |||
| IntPtr.Zero, s); | |||
| s.Check(); | |||
| @@ -14,8 +14,8 @@ namespace Tensorflow.Native.UnitTest.Sessions | |||
| [TestMethod] | |||
| public void Session() | |||
| { | |||
| using var s = new Status(); | |||
| using var graph = new Graph(); | |||
| var s = new Status(); | |||
| var graph = new Graph(); | |||
| // Make a placeholder operation. | |||
| var feed = c_test_util.Placeholder(graph, s); | |||
| @@ -139,45 +139,45 @@ namespace Tensorflow.Native.UnitTest.Tensors | |||
| var feed_out_0 = new TF_Output(feed, 0); | |||
| // Fetch the shape, it should be completely unknown. | |||
| int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle); | |||
| int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| EXPECT_EQ(-1, num_dims); | |||
| // Set the shape to be unknown, expect no change. | |||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s.Handle); | |||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle); | |||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | |||
| EXPECT_EQ(-1, num_dims); | |||
| // Set the shape to be 2 x Unknown | |||
| long[] dims = { 2, -1 }; | |||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s.Handle); | |||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle); | |||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | |||
| EXPECT_EQ(2, num_dims); | |||
| // Get the dimension vector appropriately. | |||
| var returned_dims = new long[dims.Length]; | |||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); | |||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | |||
| // Set to a new valid shape: [2, 3] | |||
| dims[1] = 3; | |||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s.Handle); | |||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| // Fetch and see that the new value is returned. | |||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); | |||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | |||
| // Try to set 'unknown' with unknown rank on the shape and see that | |||
| // it doesn't change. | |||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s.Handle); | |||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); | |||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| EXPECT_EQ(2, num_dims); | |||
| EXPECT_EQ(2, (int)returned_dims[0]); | |||
| @@ -187,21 +187,21 @@ namespace Tensorflow.Native.UnitTest.Tensors | |||
| // it doesn't change. | |||
| dims[0] = -1; | |||
| dims[1] = -1; | |||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s.Handle); | |||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); | |||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| EXPECT_EQ(2, num_dims); | |||
| EXPECT_EQ(2, (int)returned_dims[0]); | |||
| EXPECT_EQ(3, (int)returned_dims[1]); | |||
| // Try to fetch a shape with the wrong num_dims | |||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s.Handle); | |||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); | |||
| // Try to set an invalid shape (cannot change 2x3 to a 2x5). | |||
| dims[1] = 5; | |||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s.Handle); | |||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); | |||
| // Test for a scalar. | |||
| @@ -209,14 +209,13 @@ namespace Tensorflow.Native.UnitTest.Tensors | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| var three_out_0 = new TF_Output(three, 0); | |||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s.Handle); | |||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| EXPECT_EQ(0, num_dims); | |||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, dims, num_dims, s.Handle); | |||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, dims, num_dims, s); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); | |||
| graph.Exit(); | |||
| s.Dispose(); | |||
| } | |||
| } | |||
| } | |||
| @@ -23,7 +23,7 @@ namespace Tensorflow.Native.UnitTest | |||
| c_api.TF_AddInputList(desc, inputs, inputs.Length); | |||
| var op = c_api.TF_FinishOperation(desc, s.Handle); | |||
| var op = c_api.TF_FinishOperation(desc, s); | |||
| s.Check(); | |||
| return op; | |||
| @@ -33,37 +33,29 @@ namespace Tensorflow.Native.UnitTest | |||
| [SuppressMessage("ReSharper", "RedundantAssignment")] | |||
| public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) | |||
| { | |||
| lock (Locks.ProcessWide) | |||
| { | |||
| using (var buffer = new Buffer()) | |||
| { | |||
| c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer.Handle, s.Handle); | |||
| attr_value = AttrValue.Parser.ParseFrom(buffer.ToArray()); | |||
| } | |||
| var buffer = new Buffer(); | |||
| return s.Code == TF_Code.TF_OK; | |||
| } | |||
| c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); | |||
| attr_value = AttrValue.Parser.ParseFrom(buffer.ToArray()); | |||
| return s.Code == TF_Code.TF_OK; | |||
| } | |||
| public static GraphDef GetGraphDef(Graph graph) | |||
| { | |||
| lock (Locks.ProcessWide) | |||
| { | |||
| using (var s = new Status()) | |||
| using (var buffer = new Buffer()) | |||
| { | |||
| c_api.TF_GraphToGraphDef(graph, buffer.Handle, s.Handle); | |||
| s.Check(); | |||
| return GraphDef.Parser.ParseFrom(buffer.ToArray()); | |||
| } | |||
| } | |||
| var s = new Status(); | |||
| var buffer = new Buffer(); | |||
| c_api.TF_GraphToGraphDef(graph, buffer, s); | |||
| s.Check(); | |||
| return GraphDef.Parser.ParseFrom(buffer.ToArray()); | |||
| } | |||
| public static FunctionDef GetFunctionDef(IntPtr func) | |||
| public static FunctionDef GetFunctionDef(SafeFuncGraphHandle func) | |||
| { | |||
| using var s = new Status(); | |||
| using var buffer = new Buffer(); | |||
| c_api.TF_FunctionToFunctionDef(func, buffer.Handle, s.Handle); | |||
| var s = new Status(); | |||
| var buffer = new Buffer(); | |||
| c_api.TF_FunctionToFunctionDef(func, buffer, s); | |||
| s.Check(true); | |||
| var func_def = FunctionDef.Parser.ParseFrom(buffer.ToArray()); | |||
| return func_def; | |||
| @@ -192,7 +184,7 @@ namespace Tensorflow.Native.UnitTest | |||
| OperationDescription desc = c_api.TF_NewOperation(graph, "Neg", name); | |||
| var neg_input = new TF_Output(n, 0); | |||
| c_api.TF_AddInput(desc, neg_input); | |||
| var op = c_api.TF_FinishOperation(desc, s.Handle); | |||
| var op = c_api.TF_FinishOperation(desc, s); | |||
| s.Check(); | |||
| return op; | |||
| @@ -210,7 +202,7 @@ namespace Tensorflow.Native.UnitTest | |||
| c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length); | |||
| } | |||
| var op = c_api.TF_FinishOperation(desc, s.Handle); | |||
| var op = c_api.TF_FinishOperation(desc, s); | |||
| s.Check(); | |||
| return op; | |||
| @@ -222,10 +214,10 @@ namespace Tensorflow.Native.UnitTest | |||
| lock (Locks.ProcessWide) | |||
| { | |||
| var desc = c_api.TF_NewOperation(graph, "Const", name); | |||
| c_api.TF_SetAttrTensor(desc, "value", t, s.Handle); | |||
| c_api.TF_SetAttrTensor(desc, "value", t, s); | |||
| s.Check(); | |||
| c_api.TF_SetAttrType(desc, "dtype", t.dtype); | |||
| var op = c_api.TF_FinishOperation(desc, s.Handle); | |||
| var op = c_api.TF_FinishOperation(desc, s); | |||
| s.Check(); | |||
| return op; | |||
| @@ -17,10 +17,8 @@ namespace TensorFlowNET.UnitTest.Basics | |||
| public void ImportGraph() | |||
| { | |||
| using (var sess = tf.Session()) | |||
| { | |||
| var new_saver = tf.train.import_meta_graph("C:/tmp/my-model.meta"); | |||
| } | |||
| var sess = tf.Session(); | |||
| var new_saver = tf.train.import_meta_graph("C:/tmp/my-model.meta"); | |||
| //tf.train.export_meta_graph(filename: "linear_regression.meta.bin"); | |||
| // import meta | |||
| @@ -60,14 +58,12 @@ namespace TensorFlowNET.UnitTest.Basics | |||
| // Add ops to save and restore all the variables. | |||
| var saver = tf.train.Saver(); | |||
| using (var sess = tf.Session()) | |||
| { | |||
| sess.run(init_op); | |||
| var sess = tf.Session(); | |||
| sess.run(init_op); | |||
| // Save the variables to disk. | |||
| var save_path = saver.save(sess, "/tmp/model1.ckpt"); | |||
| Console.WriteLine($"Model saved in path: {save_path}"); | |||
| } | |||
| // Save the variables to disk. | |||
| var save_path = saver.save(sess, "/tmp/model1.ckpt"); | |||
| Console.WriteLine($"Model saved in path: {save_path}"); | |||
| } | |||
| public void Save2() | |||
| @@ -84,17 +80,15 @@ namespace TensorFlowNET.UnitTest.Basics | |||
| // Add ops to save and restore all the variables. | |||
| var saver = tf.train.Saver(); | |||
| using (var sess = tf.Session()) | |||
| { | |||
| sess.run(init_op); | |||
| // o some work with the model. | |||
| inc_v1.op.run(); | |||
| dec_v2.op.run(); | |||
| // Save the variables to disk. | |||
| var save_path = saver.save(sess, "/tmp/model2.ckpt"); | |||
| Console.WriteLine($"Model saved in path: {save_path}"); | |||
| } | |||
| var sess = tf.Session(); | |||
| sess.run(init_op); | |||
| // o some work with the model. | |||
| inc_v1.op.run(); | |||
| dec_v2.op.run(); | |||
| // Save the variables to disk. | |||
| var save_path = saver.save(sess, "/tmp/model2.ckpt"); | |||
| Console.WriteLine($"Model saved in path: {save_path}"); | |||
| } | |||
| } | |||
| } | |||
| @@ -57,12 +57,10 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||
| var input = tf.placeholder(TF_DataType.TF_FLOAT, new Shape(6)); | |||
| var scan = tf.scan(fn, input); | |||
| using (var sess = tf.Session()) | |||
| { | |||
| sess.run(tf.global_variables_initializer()); | |||
| var result = sess.run(scan, new FeedItem(input, np.array(1, 2, 3, 4, 5, 6))); | |||
| Assert.AreEqual(new float[] { 1, 3, 6, 10, 15, 21 }, result.ToArray<float>()); | |||
| } | |||
| var sess = tf.Session(); | |||
| sess.run(tf.global_variables_initializer()); | |||
| var result = sess.run(scan, new FeedItem(input, np.array(1, 2, 3, 4, 5, 6))); | |||
| Assert.AreEqual(new float[] { 1, 3, 6, 10, 15, 21 }, result.ToArray<float>()); | |||
| } | |||
| } | |||
| } | |||
| @@ -196,23 +196,21 @@ namespace TensorFlowNET.UnitTest | |||
| // return self._eval_helper(tensors) | |||
| // else: | |||
| { | |||
| using (var sess = tf.Session()) | |||
| var sess = tf.Session(); | |||
| var ndarray = tensor.eval(sess); | |||
| if (typeof(T) == typeof(double)) | |||
| { | |||
| var ndarray = tensor.eval(sess); | |||
| if (typeof(T) == typeof(double)) | |||
| { | |||
| double x = ndarray; | |||
| result = x; | |||
| } | |||
| else if (typeof(T) == typeof(int)) | |||
| { | |||
| int x = ndarray; | |||
| result = x; | |||
| } | |||
| else | |||
| { | |||
| result = ndarray; | |||
| } | |||
| double x = ndarray; | |||
| result = x; | |||
| } | |||
| else if (typeof(T) == typeof(int)) | |||
| { | |||
| int x = ndarray; | |||
| result = x; | |||
| } | |||
| else | |||
| { | |||
| result = ndarray; | |||
| } | |||
| return (T)result; | |||
| @@ -28,7 +28,6 @@ namespace TensorFlowNET.UnitTest.Basics | |||
| public void DeleteStatus() | |||
| { | |||
| var s = new Status(); | |||
| s.Dispose(); | |||
| } | |||
| } | |||
| } | |||