| @@ -53,8 +53,7 @@ namespace Tensorflow.Eager | |||
| { | |||
| object value = null; | |||
| byte isList = 0; | |||
| using var status = new Status(); | |||
| var attrType = c_api.TFE_OpNameGetAttrType(tf.context, Name, attr_name, ref isList, status); | |||
| var attrType = c_api.TFE_OpNameGetAttrType(tf.context, Name, attr_name, ref isList, tf.status); | |||
| switch (attrType) | |||
| { | |||
| case TF_AttrType.TF_ATTR_BOOL: | |||
| @@ -22,13 +22,13 @@ namespace Tensorflow.Eager | |||
| public EagerTensor(string value, string device_name) : base(value) | |||
| { | |||
| EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, status); | |||
| EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.status); | |||
| Resolve(); | |||
| } | |||
| public EagerTensor(NDArray value, string device_name) : base(value) | |||
| { | |||
| EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, status); | |||
| EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.status); | |||
| Resolve(); | |||
| } | |||
| @@ -37,7 +37,7 @@ namespace Tensorflow.Eager | |||
| _id = get_uid(); | |||
| if (_handle == IntPtr.Zero) | |||
| _handle = c_api.TFE_TensorHandleResolve(EagerTensorHandle, status); | |||
| _handle = c_api.TFE_TensorHandleResolve(EagerTensorHandle, tf.status); | |||
| //print($"new Tensor {Id} {_handle.ToString("x16")}"); | |||
| //print($"new TensorHandle {Id} {EagerTensorHandle.ToString("x16")}"); | |||
| @@ -8,26 +8,23 @@ namespace Tensorflow.Eager | |||
| { | |||
| public partial class EagerTensor : Tensor | |||
| { | |||
| Status status = new Status(); | |||
| public IntPtr EagerTensorHandle; | |||
| public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(EagerTensorHandle, status)); | |||
| public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(EagerTensorHandle, tf.status)); | |||
| public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, status); | |||
| public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.status); | |||
| public static int GetRank(IntPtr handle) | |||
| { | |||
| var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); | |||
| using var status = new Status(); | |||
| return c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, status); | |||
| return c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, tf.status); | |||
| } | |||
| public static int[] GetDims(IntPtr handle) | |||
| { | |||
| var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); | |||
| using var status = new Status(); | |||
| var dims = new int[c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, status)]; | |||
| var dims = new int[c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, tf.status)]; | |||
| for (int i = 0; i < dims.Length; i++) | |||
| dims[i] = c_api.TFE_TensorHandleDim(tfe_tensor_handle, i, status); | |||
| dims[i] = c_api.TFE_TensorHandleDim(tfe_tensor_handle, i, tf.status); | |||
| return dims; | |||
| } | |||
| @@ -512,7 +512,7 @@ namespace Tensorflow | |||
| public TensorShape GetTensorShape(TF_Output output) | |||
| { | |||
| var status = new Status(); | |||
| var status = tf.status; | |||
| var ndim = c_api.TF_GraphGetTensorNumDims(_handle, output, status); | |||
| status.Check(); | |||
| @@ -17,6 +17,7 @@ | |||
| using System; | |||
| using System.Linq; | |||
| using System.Runtime.InteropServices; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -30,11 +31,8 @@ namespace Tensorflow | |||
| public int InputListLength(string name) | |||
| { | |||
| int num = 0; | |||
| using(var status = new Status()) | |||
| { | |||
| num = c_api.TF_OperationInputListLength(_handle, name, status); | |||
| status.Check(true); | |||
| } | |||
| num = c_api.TF_OperationInputListLength(_handle, name, tf.status); | |||
| tf.status.Check(true); | |||
| return num; | |||
| } | |||
| public int NumInputs => c_api.TF_OperationNumInputs(_handle); | |||
| @@ -28,12 +28,8 @@ namespace Tensorflow | |||
| public int OutputListLength(string name) | |||
| { | |||
| int num = 0; | |||
| using (var status = new Status()) | |||
| { | |||
| num = c_api.TF_OperationOutputListLength(_handle, name, status); | |||
| status.Check(true); | |||
| } | |||
| int num = c_api.TF_OperationOutputListLength(_handle, name, tf.status); | |||
| tf.status.Check(true); | |||
| return num; | |||
| } | |||
| @@ -20,6 +20,7 @@ using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -233,14 +234,13 @@ namespace Tensorflow | |||
| AttrValue x = null; | |||
| lock (Locks.ProcessWide) | |||
| using (var status = new Status()) | |||
| using (var buf = new Buffer()) | |||
| { | |||
| c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | |||
| status.Check(true); | |||
| { | |||
| using var buf = new Buffer(); | |||
| c_api.TF_OperationGetAttrValueProto(_handle, name, buf, tf.status); | |||
| tf.status.Check(true); | |||
| x = AttrValue.Parser.ParseFrom(buf.MemoryBlock.Stream()); | |||
| } | |||
| x = AttrValue.Parser.ParseFrom(buf.MemoryBlock.Stream()); | |||
| } | |||
| string oneof_value = x.ValueCase.ToString(); | |||
| if (string.IsNullOrEmpty(oneof_value)) | |||
| @@ -295,11 +295,10 @@ namespace Tensorflow | |||
| // after the c_api call next time _inputs is accessed | |||
| // the updated inputs are reloaded from the c_api | |||
| lock (Locks.ProcessWide) | |||
| using (var status = new Status()) | |||
| { | |||
| c_api.UpdateEdge(_graph, output, input, status); | |||
| c_api.UpdateEdge(_graph, output, input, tf.status); | |||
| //var updated_inputs = inputs; | |||
| status.Check(); | |||
| tf.status.Check(); | |||
| } | |||
| } | |||
| @@ -43,30 +43,26 @@ namespace Tensorflow | |||
| allow_broadcast: false); | |||
| public static Tensor zeros(TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | |||
| { | |||
| dtype = dtype.as_base_dtype(); | |||
| return tf_with(ops.name_scope(name, "zeros", shape), scope => | |||
| => tf_with(ops.name_scope(name, "zeros", shape), scope => | |||
| { | |||
| dtype = dtype.as_base_dtype(); | |||
| name = scope; | |||
| var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape); | |||
| Tensor zeros = null; | |||
| switch (dtype) | |||
| { | |||
| case TF_DataType.TF_BOOL: | |||
| return _constant_if_small(false, shape, dtype, name); | |||
| case TF_DataType.TF_DOUBLE: | |||
| return _constant_if_small(0.0D, shape, dtype, name); | |||
| zeros = constant(0d); | |||
| break; | |||
| case TF_DataType.TF_FLOAT: | |||
| return _constant_if_small(0.0F, shape, dtype, name); | |||
| case TF_DataType.TF_INT64: | |||
| return _constant_if_small(0L, shape, dtype, name); | |||
| case TF_DataType.TF_INT32: | |||
| return _constant_if_small(0, shape, dtype, name); | |||
| case TF_DataType.TF_INT8: | |||
| return _constant_if_small<byte>(0, shape, dtype, name); | |||
| zeros = constant(0f); | |||
| break; | |||
| default: | |||
| throw new TypeError("can't find type for zeros"); | |||
| zeros = constant(0); | |||
| break; | |||
| } | |||
| return fill(shape_tensor, zeros, name: name); | |||
| }); | |||
| } | |||
| public static Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0) | |||
| { | |||
| @@ -22,7 +22,7 @@ using System.Linq; | |||
| using System.Numerics; | |||
| using System.Text; | |||
| using Google.Protobuf; | |||
| using NumSharp.Backends; | |||
| using static Tensorflow.Binding; | |||
| using Tensorflow.Util; | |||
| namespace Tensorflow | |||
| @@ -236,7 +236,7 @@ namespace Tensorflow | |||
| // Ensure any changes to the graph are reflected in the runtime. | |||
| _extend_graph(); | |||
| var status = new Status(); | |||
| var status = tf.status; | |||
| var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); | |||
| @@ -46,7 +46,7 @@ namespace Tensorflow | |||
| lock (Locks.ProcessWide) | |||
| { | |||
| var graph = c_api.TF_NewGraph(); | |||
| var status = new Status(); | |||
| using var status = new Status(); | |||
| var opt = new SessionOptions(); | |||
| var tags = new string[] {"serve"}; | |||
| @@ -66,7 +66,6 @@ namespace Tensorflow | |||
| status.Check(true); | |||
| } catch (TensorflowException ex) when (ex.Message.Contains("Could not find SavedModel")) | |||
| { | |||
| status = new Status(); | |||
| sess = c_api.TF_LoadSessionFromSavedModel(opt, | |||
| IntPtr.Zero, | |||
| Path.GetFullPath(path), | |||
| @@ -13,14 +13,12 @@ namespace Tensorflow | |||
| public class EagerTensorV2 : DisposableObject, ITensor | |||
| { | |||
| IntPtr EagerTensorHandle; | |||
| public string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(EagerTensorHandle, status)); | |||
| static Status status = new Status(); | |||
| public string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(EagerTensorHandle, tf.status)); | |||
| public EagerTensorV2(IntPtr handle) | |||
| { | |||
| EagerTensorHandle = c_api.TFE_EagerTensorHandle(handle); | |||
| _handle = c_api.TFE_TensorHandleResolve(EagerTensorHandle, status); | |||
| _handle = c_api.TFE_TensorHandleResolve(EagerTensorHandle, tf.status); | |||
| } | |||
| public unsafe EagerTensorV2(NDArray nd, string device_name = "") | |||
| @@ -40,7 +38,7 @@ namespace Tensorflow | |||
| }, IntPtr.Zero); | |||
| EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, status); | |||
| EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.status); | |||
| } | |||
| /*public unsafe EagerTensorV2(float[,] value) | |||
| @@ -21,6 +21,7 @@ using System.Globalization; | |||
| using System.Runtime.CompilerServices; | |||
| using System.Text; | |||
| using NumSharp.Utilities; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -69,11 +70,8 @@ namespace Tensorflow | |||
| IntPtr stringStartAddress = IntPtr.Zero; | |||
| UIntPtr dstLen = UIntPtr.Zero; | |||
| using (var status = new Status()) | |||
| { | |||
| c_api.TF_StringDecode((byte*) this.buffer + 8, (UIntPtr) (this.bytesize), (byte**) &stringStartAddress, &dstLen, status); | |||
| status.Check(true); | |||
| } | |||
| c_api.TF_StringDecode((byte*) this.buffer + 8, (UIntPtr) (this.bytesize), (byte**) &stringStartAddress, &dstLen, tf.status); | |||
| tf.status.Check(true); | |||
| var dstLenInt = checked((int) dstLen); | |||
| var value = Encoding.UTF8.GetString((byte*) stringStartAddress, dstLenInt); | |||
| @@ -451,7 +451,6 @@ namespace Tensorflow | |||
| /// </summary> | |||
| public unsafe Tensor(string str) | |||
| { | |||
| var status = new Status(); | |||
| var buffer = Encoding.UTF8.GetBytes(str); | |||
| var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); | |||
| var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); | |||
| @@ -460,9 +459,9 @@ namespace Tensorflow | |||
| IntPtr tensor = c_api.TF_TensorData(handle); | |||
| Marshal.WriteInt64(tensor, 0); | |||
| fixed (byte* src = buffer) | |||
| c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(long)), size, status); | |||
| c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(long)), size, tf.status); | |||
| _handle = handle; | |||
| status.Check(true); | |||
| tf.status.Check(true); | |||
| } | |||
| public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null) | |||
| @@ -483,10 +482,8 @@ namespace Tensorflow | |||
| IntPtr tensor = c_api.TF_TensorData(handle); | |||
| Marshal.WriteInt64(tensor, 0); | |||
| var status = new Status(); | |||
| c_api.TF_StringEncode((byte*) nd.Unsafe.Address, bytesLength, (sbyte*) (tensor + sizeof(Int64)), size, status); | |||
| status.Check(true); | |||
| c_api.TF_StringEncode((byte*) nd.Unsafe.Address, bytesLength, (sbyte*) (tensor + sizeof(Int64)), size, tf.status); | |||
| tf.status.Check(true); | |||
| _handle = handle; | |||
| } else | |||
| { | |||
| @@ -498,11 +495,10 @@ namespace Tensorflow | |||
| IntPtr tensor = c_api.TF_TensorData(handle); | |||
| Marshal.WriteInt64(tensor, 0); | |||
| var status = new Status(); | |||
| fixed (byte* src = buffer) | |||
| c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(Int64)), size, status); | |||
| c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(Int64)), size, tf.status); | |||
| status.Check(true); | |||
| tf.status.Check(true); | |||
| _handle = handle; | |||
| } | |||
| @@ -607,11 +603,10 @@ namespace Tensorflow | |||
| IntPtr tensor = c_api.TF_TensorData(handle); | |||
| Marshal.WriteInt64(tensor, 0); | |||
| var status = new Status(); | |||
| fixed (byte* src = buffer) | |||
| c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(Int64)), size, status); | |||
| c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(long)), size, tf.status); | |||
| status.Check(true); | |||
| tf.status.Check(true); | |||
| return handle; | |||
| } | |||
| @@ -3,7 +3,7 @@ using NumSharp.Backends; | |||
| using NumSharp.Backends.Unmanaged; | |||
| using NumSharp.Utilities; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using static Tensorflow.Binding; | |||
| using System.Runtime.InteropServices; | |||
| using System.Text; | |||
| @@ -237,18 +237,15 @@ namespace Tensorflow | |||
| var src = c_api.TF_TensorData(_handle); | |||
| var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize); | |||
| src += (int)(size * 8); | |||
| using (var status = new Status()) | |||
| for (int i = 0; i < buffer.Length; i++) | |||
| { | |||
| for (int i = 0; i < buffer.Length; i++) | |||
| { | |||
| IntPtr dst = IntPtr.Zero; | |||
| UIntPtr dstLen = UIntPtr.Zero; | |||
| var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, status); | |||
| status.Check(true); | |||
| buffer[i] = new byte[(int)dstLen]; | |||
| Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); | |||
| src += (int)read; | |||
| } | |||
| IntPtr dst = IntPtr.Zero; | |||
| UIntPtr dstLen = UIntPtr.Zero; | |||
| var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, tf.status); | |||
| tf.status.Check(true); | |||
| buffer[i] = new byte[(int)dstLen]; | |||
| Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); | |||
| src += (int)read; | |||
| } | |||
| var _str = new string[buffer.Length]; | |||
| @@ -22,7 +22,7 @@ using System.Globalization; | |||
| using System.Linq; | |||
| using System.Runtime.InteropServices; | |||
| using System.Text; | |||
| using System.Threading.Tasks; | |||
| using static Tensorflow.Binding; | |||
| using Tensorflow.Framework; | |||
| namespace Tensorflow | |||
| @@ -109,11 +109,7 @@ namespace Tensorflow | |||
| if (_handle == IntPtr.Zero) | |||
| { | |||
| using (var status = new Status()) | |||
| { | |||
| c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status); | |||
| status.Check(); | |||
| } | |||
| c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, tf.status); | |||
| } | |||
| else | |||
| { | |||
| @@ -126,15 +122,12 @@ namespace Tensorflow | |||
| set | |||
| { | |||
| using (var status = new Status()) | |||
| { | |||
| if (value == null) | |||
| c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, status); | |||
| else | |||
| c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status); | |||
| if (value == null) | |||
| c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.status); | |||
| else | |||
| c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, tf.status); | |||
| status.Check(true); | |||
| } | |||
| tf.status.Check(true); | |||
| } | |||
| } | |||
| @@ -178,13 +171,9 @@ namespace Tensorflow | |||
| { | |||
| if (_handle == IntPtr.Zero) | |||
| { | |||
| using (var status = new Status()) | |||
| { | |||
| var output = _as_tf_output(); | |||
| int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, status); | |||
| status.Check(); | |||
| return ndim; | |||
| } | |||
| var output = _as_tf_output(); | |||
| int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, tf.status); | |||
| return ndim; | |||
| } | |||
| return c_api.TF_NumDims(_handle); | |||
| @@ -176,30 +176,29 @@ namespace Tensorflow | |||
| throw new NotImplementedException("_create_c_op"); | |||
| } | |||
| using (var status = new Status()) | |||
| var status = tf.status; | |||
| // Add control inputs | |||
| foreach (var control_input in control_inputs) | |||
| c_api.TF_AddControlInput(op_desc, control_input); | |||
| // Add attrs | |||
| foreach (var attr in node_def.Attr) | |||
| { | |||
| // Add control inputs | |||
| foreach (var control_input in control_inputs) | |||
| c_api.TF_AddControlInput(op_desc, control_input); | |||
| // Add attrs | |||
| foreach (var attr in node_def.Attr) | |||
| { | |||
| var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. | |||
| var protoHandle = Marshal.AllocHGlobal(bytes.Length); | |||
| Marshal.Copy(bytes, 0, protoHandle, bytes.Length); | |||
| uint len = (uint)bytes.Length; | |||
| c_api.TF_SetAttrValueProto(op_desc, attr.Key, protoHandle, proto_len: len, status: status); | |||
| status.Check(true); | |||
| Marshal.FreeHGlobal(protoHandle); | |||
| } | |||
| var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. | |||
| var protoHandle = Marshal.AllocHGlobal(bytes.Length); | |||
| Marshal.Copy(bytes, 0, protoHandle, bytes.Length); | |||
| uint len = (uint)bytes.Length; | |||
| c_api.TF_SetAttrValueProto(op_desc, attr.Key, protoHandle, proto_len: len, status: status); | |||
| status.Check(true); | |||
| Marshal.FreeHGlobal(protoHandle); | |||
| } | |||
| var c_op = c_api.TF_FinishOperation(op_desc, status); | |||
| var c_op = c_api.TF_FinishOperation(op_desc, status); | |||
| status.Check(true); | |||
| status.Check(true); | |||
| return c_op; | |||
| } | |||
| return c_op; | |||
| } | |||
| } | |||
| @@ -22,7 +22,6 @@ using System.Runtime.InteropServices; | |||
| using System.Threading; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Gradients; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -42,11 +41,12 @@ namespace Tensorflow | |||
| public delegate Tensor[] BackwardFunction(Tensor[] grads, long[] unneeded_gradients); | |||
| public Status status = new Status(); | |||
| public OpDefLibrary _op_def_lib = new OpDefLibrary(); | |||
| public Context context = new Context(new ContextOptions(), new Status()); | |||
| public Execute _execute = new Execute(); | |||
| public IEagerRunner Runner = new EagerRunner(); | |||
| public Context context = new Context(new ContextOptions(), new Status()); | |||
| public tensorflow() | |||
| { | |||
| enable_eager_execution(); | |||
| @@ -96,14 +96,14 @@ namespace TensorFlowNET.UnitTest.Basics | |||
| public void ZerosConst() | |||
| { | |||
| // small size | |||
| var tensor = tf.zeros(new Shape(3, 2), tf.int32, "small"); | |||
| var tensor = tf.zeros((3, 2), tf.int32, "small"); | |||
| Assert.AreEqual(tensor.shape[0], 3); | |||
| Assert.AreEqual(tensor.shape[1], 2); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 0, 0 }, tensor.numpy().ToArray<int>())); | |||
| // big size | |||
| tensor = tf.zeros(new Shape(200, 100), tf.int32, "big"); | |||
| tensor = tf.zeros((200, 100), tf.int32, "big"); | |||
| Assert.AreEqual(tensor.shape[0], 200); | |||
| Assert.AreEqual(tensor.shape[1], 100); | |||
| @@ -35,7 +35,26 @@ namespace TensorFlowNET.UnitTest.Gradient | |||
| var dz_dx = tape.gradient(z, x); | |||
| var expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f }; | |||
| Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.numpy().ToArray<float>(), expected)); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray<float>(), expected)); | |||
| } | |||
| [TestMethod] | |||
| public void PersistentTape() | |||
| { | |||
| var x = tf.ones((2, 2)); | |||
| using var tape = tf.GradientTape(persistent: true); | |||
| tape.watch(x); | |||
| var y = tf.reduce_sum(x); | |||
| var z = tf.multiply(y, y); | |||
| var dz_dx = tape.gradient(z, x); | |||
| var expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f }; | |||
| Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray<float>(), expected)); | |||
| var dz_dy = tape.gradient(z, y); | |||
| expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f }; | |||
| Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray<float>(), expected)); | |||
| } | |||
| } | |||
| } | |||