diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index 8c6248e3..97a95e95 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using Tensorflow.Eager; using Tensorflow.Operations; namespace Tensorflow @@ -259,7 +260,6 @@ namespace Tensorflow public Tensor sub(Tx a, Ty b, string name = null) => gen_math_ops.sub(a, b, name: name); - public Tensor divide(Tensor a, Tensor b) => a / b; @@ -348,6 +348,9 @@ namespace Tensorflow public Tensor minimum(T1 x, T2 y, string name = null) => gen_math_ops.minimum(x, y, name: name); + public Tensor multiply(Tensor x, Tensor y, string name = null) + => gen_math_ops.mul(x, y, name: name); + /// /// return x * y /// diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.cs index d7474e7d..d85ea4c8 100644 --- a/src/TensorFlowNET.Core/Eager/EagerTensor.cs +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; using System.Text; +using static Tensorflow.Binding; namespace Tensorflow.Eager { @@ -9,41 +10,44 @@ namespace Tensorflow.Eager { Status status = new Status(); TFE_TensorHandle tfe_tensor_handle; + public IntPtr EagerTensorHandle { get; set; } + public EagerTensor(IntPtr handle) : base(handle) { tfe_tensor_handle = handle; _handle = c_api.TFE_TensorHandleResolve(handle, status); - _id = ops.uid(); } public EagerTensor(string value, string device_name) : base(value) { tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); - _id = ops.uid(); } public EagerTensor(int value, string device_name) : base(value) { tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); - _id = ops.uid(); + EagerTensorHandle = c_api.TFE_EagerTensorFromHandle(tf.context, tfe_tensor_handle); + } + + public EagerTensor(float value, string device_name) : base(value) + { + tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); + EagerTensorHandle = c_api.TFE_EagerTensorFromHandle(tf.context, tfe_tensor_handle); } public EagerTensor(float[] value, string device_name) : base(value) { tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); - _id = ops.uid(); } public EagerTensor(double[] value, string device_name) : base(value) { tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); - _id = ops.uid(); } public EagerTensor(NDArray value, string device_name) : base(value) { tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); - _id = ops.uid(); } public override string ToString() diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs index 4d9c2f32..6e5a81bb 100644 --- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs +++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs @@ -7,6 +7,12 @@ namespace Tensorflow { public partial class c_api { + [DllImport(TensorFlowLibName)] + public static extern void TFE_RegisterGradientFunction(_gradient_function_callback callbackPointer); + + [UnmanagedFunctionPointer(CallingConvention.StdCall)] + public delegate void _gradient_function_callback(string op_name, int num_inputs, IntPtr attrs, int num_attrs); + /// /// Return a new options object. /// @@ -186,6 +192,9 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern TFE_TensorHandle TFE_NewTensorHandle(IntPtr t, IntPtr status); + [DllImport(TensorFlowLibName)] + public static extern TFE_TensorHandle TFE_EagerTensorFromHandle(IntPtr ctx, IntPtr h); + /// /// Sets the default execution mode (sync/async). Note that this can be /// overridden per thread using TFE_ContextSetExecutorForThread. @@ -312,15 +321,21 @@ namespace Tensorflow public static extern TFE_Executor TFE_ContextGetExecutorForThread(IntPtr ctx); [DllImport(TensorFlowLibName)] - public static extern void TFE_Test(); + public static extern IntPtr TFE_FastPathExecute(IntPtr ctx, + string device_name, + string op_name, + string name, + IntPtr[] args, + int input_size, + IntPtr status); [DllImport(TensorFlowLibName)] public static extern IntPtr TFE_TapeSetNew(bool persistent, bool watch_accessed_variables); [DllImport(TensorFlowLibName)] - public static extern void TFE_TapeWatch(IntPtr tape, IntPtr tensor, int tensor_id); + public static extern void TFE_TapeWatch(IntPtr tape, IntPtr tensor); [DllImport(TensorFlowLibName)] - public static extern void TFE_TapeGradient(IntPtr tape, long[] targetTensorIds, IntPtr[] target, long[] sourcesTensorIds, IntPtr status); + public static extern void TFE_TapeGradient(IntPtr tape, IntPtr[] target, IntPtr sources, IntPtr status); } } diff --git a/src/TensorFlowNET.Core/Gradients/GradientActor.cs b/src/TensorFlowNET.Core/Gradients/GradientActor.cs index f650aa9e..e8be5dae 100644 --- a/src/TensorFlowNET.Core/Gradients/GradientActor.cs +++ b/src/TensorFlowNET.Core/Gradients/GradientActor.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Eager; using static Tensorflow.Binding; namespace Tensorflow.Gradients @@ -53,14 +54,16 @@ namespace Tensorflow.Gradients /// public void watch(Tensor x) { - _tape.watch(x); + _tape.watch(x as EagerTensor); } public Tensor gradient(Tensor target, Tensor sources) { - c_api.TFE_Test(); - //using (var status = new Status()) - //c_api.TFE_TapeGradient(_tape, new long[] { target.Id }, status); + using (var status = new Status()) + { + c_api.TFE_TapeGradient(_tape, new IntPtr[] { target }, IntPtr.Zero, status); + } + return null; } diff --git a/src/TensorFlowNET.Core/Gradients/Tape.cs b/src/TensorFlowNET.Core/Gradients/Tape.cs index a61898fe..f47616dd 100644 --- a/src/TensorFlowNET.Core/Gradients/Tape.cs +++ b/src/TensorFlowNET.Core/Gradients/Tape.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Eager; namespace Tensorflow.Gradients { @@ -14,9 +15,9 @@ namespace Tensorflow.Gradients _handle = c_api.TFE_TapeSetNew(persistent, watch_accessed_variables); } - public void watch(Tensor x) + public void watch(EagerTensor x) { - c_api.TFE_TapeWatch(_handle, x, x.Id); + c_api.TFE_TapeWatch(_handle, x.EagerTensorHandle); } public static bool IsDtypeTrainable(DataType dtype) diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 5597bfc8..84914a76 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -192,6 +192,28 @@ namespace Tensorflow return _op.outputs[0]; } + public static Tensor add(Tensor x, Tensor y, string name = null) + { + if (tf.context.executing_eagerly()) + { + using (var status = new Status()) + { + var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Add", name, new IntPtr[] + { + (x as EagerTensor).EagerTensorHandle, + (y as EagerTensor).EagerTensorHandle + }, 2, status); + status.Check(true); + return new EagerTensor(_result); + } + } + + var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y }); + + return _op.output; + } + public static Tensor add(Tx x, Ty y, string name = null) { if (tf.context.executing_eagerly()) @@ -593,6 +615,28 @@ namespace Tensorflow return _op.outputs[0]; } + public static Tensor sub(Tensor x, Tensor y, string name = null) + { + if (tf.context.executing_eagerly()) + { + using (var status = new Status()) + { + var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Sub", name, new IntPtr[] + { + (x as EagerTensor).EagerTensorHandle, + (y as EagerTensor).EagerTensorHandle + }, 2, status); + status.Check(true); + return new EagerTensor(_result); + } + } + + var _op = _op_def_lib._apply_op_helper("Sub", name, args: new { x, y }); + + return _op.output; + } + public static Tensor sub(Tx x, Ty y, string name = null) { if (tf.context.executing_eagerly()) @@ -667,6 +711,28 @@ namespace Tensorflow return _op.output; } + public static Tensor mul(Tensor x, Tensor y, string name = null) + { + if (tf.context.executing_eagerly()) + { + using (var status = new Status()) + { + var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Mul", name, new IntPtr[] + { + (x as EagerTensor).EagerTensorHandle, + (y as EagerTensor).EagerTensorHandle + }, 2, status); + status.Check(true); + return new EagerTensor(_result); + } + } + + var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y }); + + return _op.output; + } + public static Tensor mul(Tx x, Ty y, string name = null) { if (tf.context.executing_eagerly()) @@ -693,8 +759,17 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { - var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, "", "RealDiv", name, null, x, y); - return _result; + using (var status = new Status()) + { + var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "RealDiv", name, new IntPtr[] + { + (x as EagerTensor).EagerTensorHandle, + (y as EagerTensor).EagerTensorHandle + }, 2, status); + status.Check(true); + return new EagerTensor(_result); + } } var _op = _op_def_lib._apply_op_helper("RealDiv", name, args: new { x, y }); diff --git a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj index 57d1d83c..520ff9e4 100644 --- a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj +++ b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj @@ -4,7 +4,7 @@ netstandard2.0 TensorFlow.NET Tensorflow - 2.01.0 + 2.2.0 0.20.0 8.0 Haiping Chen, Meinrad Recheis, Eli Belash diff --git a/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj b/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj index 5dcad04b..266684d8 100644 --- a/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj +++ b/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj @@ -18,7 +18,7 @@ - + diff --git a/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj b/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj index 0211b584..a7430e7e 100644 --- a/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj @@ -31,8 +31,8 @@ - - + + diff --git a/test/Tensorflow.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj b/test/Tensorflow.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj index b646a28b..b52a923b 100644 --- a/test/Tensorflow.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj +++ b/test/Tensorflow.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj @@ -8,9 +8,9 @@ - - - + + + all runtime; build; native; contentfiles; analyzers; buildtransitive