| @@ -14,6 +14,7 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using Tensorflow.Eager; | |||||
| using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -259,7 +260,6 @@ namespace Tensorflow | |||||
| public Tensor sub<Tx, Ty>(Tx a, Ty b, string name = null) | public Tensor sub<Tx, Ty>(Tx a, Ty b, string name = null) | ||||
| => gen_math_ops.sub(a, b, name: name); | => gen_math_ops.sub(a, b, name: name); | ||||
| public Tensor divide(Tensor a, Tensor b) | public Tensor divide(Tensor a, Tensor b) | ||||
| => a / b; | => a / b; | ||||
| @@ -348,6 +348,9 @@ namespace Tensorflow | |||||
| public Tensor minimum<T1, T2>(T1 x, T2 y, string name = null) | public Tensor minimum<T1, T2>(T1 x, T2 y, string name = null) | ||||
| => gen_math_ops.minimum(x, y, name: name); | => gen_math_ops.minimum(x, y, name: name); | ||||
| public Tensor multiply(Tensor x, Tensor y, string name = null) | |||||
| => gen_math_ops.mul(x, y, name: name); | |||||
| /// <summary> | /// <summary> | ||||
| /// return x * y | /// return x * y | ||||
| /// </summary> | /// </summary> | ||||
| @@ -2,6 +2,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Eager | namespace Tensorflow.Eager | ||||
| { | { | ||||
| @@ -9,41 +10,44 @@ namespace Tensorflow.Eager | |||||
| { | { | ||||
| Status status = new Status(); | Status status = new Status(); | ||||
| TFE_TensorHandle tfe_tensor_handle; | TFE_TensorHandle tfe_tensor_handle; | ||||
| public IntPtr EagerTensorHandle { get; set; } | |||||
| public EagerTensor(IntPtr handle) : base(handle) | public EagerTensor(IntPtr handle) : base(handle) | ||||
| { | { | ||||
| tfe_tensor_handle = handle; | tfe_tensor_handle = handle; | ||||
| _handle = c_api.TFE_TensorHandleResolve(handle, status); | _handle = c_api.TFE_TensorHandleResolve(handle, status); | ||||
| _id = ops.uid(); | |||||
| } | } | ||||
| public EagerTensor(string value, string device_name) : base(value) | public EagerTensor(string value, string device_name) : base(value) | ||||
| { | { | ||||
| tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | ||||
| _id = ops.uid(); | |||||
| } | } | ||||
| public EagerTensor(int value, string device_name) : base(value) | public EagerTensor(int value, string device_name) : base(value) | ||||
| { | { | ||||
| tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | ||||
| _id = ops.uid(); | |||||
| EagerTensorHandle = c_api.TFE_EagerTensorFromHandle(tf.context, tfe_tensor_handle); | |||||
| } | |||||
| public EagerTensor(float value, string device_name) : base(value) | |||||
| { | |||||
| tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | |||||
| EagerTensorHandle = c_api.TFE_EagerTensorFromHandle(tf.context, tfe_tensor_handle); | |||||
| } | } | ||||
| public EagerTensor(float[] value, string device_name) : base(value) | public EagerTensor(float[] value, string device_name) : base(value) | ||||
| { | { | ||||
| tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | ||||
| _id = ops.uid(); | |||||
| } | } | ||||
| public EagerTensor(double[] value, string device_name) : base(value) | public EagerTensor(double[] value, string device_name) : base(value) | ||||
| { | { | ||||
| tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | ||||
| _id = ops.uid(); | |||||
| } | } | ||||
| public EagerTensor(NDArray value, string device_name) : base(value) | public EagerTensor(NDArray value, string device_name) : base(value) | ||||
| { | { | ||||
| tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | ||||
| _id = ops.uid(); | |||||
| } | } | ||||
| public override string ToString() | public override string ToString() | ||||
| @@ -7,6 +7,12 @@ namespace Tensorflow | |||||
| { | { | ||||
| public partial class c_api | public partial class c_api | ||||
| { | { | ||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TFE_RegisterGradientFunction(_gradient_function_callback callbackPointer); | |||||
| [UnmanagedFunctionPointer(CallingConvention.StdCall)] | |||||
| public delegate void _gradient_function_callback(string op_name, int num_inputs, IntPtr attrs, int num_attrs); | |||||
| /// <summary> | /// <summary> | ||||
| /// Return a new options object. | /// Return a new options object. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -186,6 +192,9 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern TFE_TensorHandle TFE_NewTensorHandle(IntPtr t, IntPtr status); | public static extern TFE_TensorHandle TFE_NewTensorHandle(IntPtr t, IntPtr status); | ||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern TFE_TensorHandle TFE_EagerTensorFromHandle(IntPtr ctx, IntPtr h); | |||||
| /// <summary> | /// <summary> | ||||
| /// Sets the default execution mode (sync/async). Note that this can be | /// Sets the default execution mode (sync/async). Note that this can be | ||||
| /// overridden per thread using TFE_ContextSetExecutorForThread. | /// overridden per thread using TFE_ContextSetExecutorForThread. | ||||
| @@ -312,15 +321,21 @@ namespace Tensorflow | |||||
| public static extern TFE_Executor TFE_ContextGetExecutorForThread(IntPtr ctx); | public static extern TFE_Executor TFE_ContextGetExecutorForThread(IntPtr ctx); | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFE_Test(); | |||||
| public static extern IntPtr TFE_FastPathExecute(IntPtr ctx, | |||||
| string device_name, | |||||
| string op_name, | |||||
| string name, | |||||
| IntPtr[] args, | |||||
| int input_size, | |||||
| IntPtr status); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TFE_TapeSetNew(bool persistent, bool watch_accessed_variables); | public static extern IntPtr TFE_TapeSetNew(bool persistent, bool watch_accessed_variables); | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFE_TapeWatch(IntPtr tape, IntPtr tensor, int tensor_id); | |||||
| public static extern void TFE_TapeWatch(IntPtr tape, IntPtr tensor); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFE_TapeGradient(IntPtr tape, long[] targetTensorIds, IntPtr[] target, long[] sourcesTensorIds, IntPtr status); | |||||
| public static extern void TFE_TapeGradient(IntPtr tape, IntPtr[] target, IntPtr sources, IntPtr status); | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,6 +1,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Eager; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Gradients | namespace Tensorflow.Gradients | ||||
| @@ -53,14 +54,16 @@ namespace Tensorflow.Gradients | |||||
| /// <param name="x"></param> | /// <param name="x"></param> | ||||
| public void watch(Tensor x) | public void watch(Tensor x) | ||||
| { | { | ||||
| _tape.watch(x); | |||||
| _tape.watch(x as EagerTensor); | |||||
| } | } | ||||
| public Tensor gradient(Tensor target, Tensor sources) | public Tensor gradient(Tensor target, Tensor sources) | ||||
| { | { | ||||
| c_api.TFE_Test(); | |||||
| //using (var status = new Status()) | |||||
| //c_api.TFE_TapeGradient(_tape, new long[] { target.Id }, status); | |||||
| using (var status = new Status()) | |||||
| { | |||||
| c_api.TFE_TapeGradient(_tape, new IntPtr[] { target }, IntPtr.Zero, status); | |||||
| } | |||||
| return null; | return null; | ||||
| } | } | ||||
| @@ -1,6 +1,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Eager; | |||||
| namespace Tensorflow.Gradients | namespace Tensorflow.Gradients | ||||
| { | { | ||||
| @@ -14,9 +15,9 @@ namespace Tensorflow.Gradients | |||||
| _handle = c_api.TFE_TapeSetNew(persistent, watch_accessed_variables); | _handle = c_api.TFE_TapeSetNew(persistent, watch_accessed_variables); | ||||
| } | } | ||||
| public void watch(Tensor x) | |||||
| public void watch(EagerTensor x) | |||||
| { | { | ||||
| c_api.TFE_TapeWatch(_handle, x, x.Id); | |||||
| c_api.TFE_TapeWatch(_handle, x.EagerTensorHandle); | |||||
| } | } | ||||
| public static bool IsDtypeTrainable(DataType dtype) | public static bool IsDtypeTrainable(DataType dtype) | ||||
| @@ -192,6 +192,28 @@ namespace Tensorflow | |||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| } | } | ||||
| public static Tensor add(Tensor x, Tensor y, string name = null) | |||||
| { | |||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| using (var status = new Status()) | |||||
| { | |||||
| var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "Add", name, new IntPtr[] | |||||
| { | |||||
| (x as EagerTensor).EagerTensorHandle, | |||||
| (y as EagerTensor).EagerTensorHandle | |||||
| }, 2, status); | |||||
| status.Check(true); | |||||
| return new EagerTensor(_result); | |||||
| } | |||||
| } | |||||
| var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y }); | |||||
| return _op.output; | |||||
| } | |||||
| public static Tensor add<Tx, Ty>(Tx x, Ty y, string name = null) | public static Tensor add<Tx, Ty>(Tx x, Ty y, string name = null) | ||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
| @@ -593,6 +615,28 @@ namespace Tensorflow | |||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| } | } | ||||
| public static Tensor sub(Tensor x, Tensor y, string name = null) | |||||
| { | |||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| using (var status = new Status()) | |||||
| { | |||||
| var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "Sub", name, new IntPtr[] | |||||
| { | |||||
| (x as EagerTensor).EagerTensorHandle, | |||||
| (y as EagerTensor).EagerTensorHandle | |||||
| }, 2, status); | |||||
| status.Check(true); | |||||
| return new EagerTensor(_result); | |||||
| } | |||||
| } | |||||
| var _op = _op_def_lib._apply_op_helper("Sub", name, args: new { x, y }); | |||||
| return _op.output; | |||||
| } | |||||
| public static Tensor sub<Tx, Ty>(Tx x, Ty y, string name = null) | public static Tensor sub<Tx, Ty>(Tx x, Ty y, string name = null) | ||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
| @@ -667,6 +711,28 @@ namespace Tensorflow | |||||
| return _op.output; | return _op.output; | ||||
| } | } | ||||
| public static Tensor mul(Tensor x, Tensor y, string name = null) | |||||
| { | |||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| using (var status = new Status()) | |||||
| { | |||||
| var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "Mul", name, new IntPtr[] | |||||
| { | |||||
| (x as EagerTensor).EagerTensorHandle, | |||||
| (y as EagerTensor).EagerTensorHandle | |||||
| }, 2, status); | |||||
| status.Check(true); | |||||
| return new EagerTensor(_result); | |||||
| } | |||||
| } | |||||
| var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y }); | |||||
| return _op.output; | |||||
| } | |||||
| public static Tensor mul<Tx, Ty>(Tx x, Ty y, string name = null) | public static Tensor mul<Tx, Ty>(Tx x, Ty y, string name = null) | ||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
| @@ -693,8 +759,17 @@ namespace Tensorflow | |||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
| { | { | ||||
| var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, "", "RealDiv", name, null, x, y); | |||||
| return _result; | |||||
| using (var status = new Status()) | |||||
| { | |||||
| var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "RealDiv", name, new IntPtr[] | |||||
| { | |||||
| (x as EagerTensor).EagerTensorHandle, | |||||
| (y as EagerTensor).EagerTensorHandle | |||||
| }, 2, status); | |||||
| status.Check(true); | |||||
| return new EagerTensor(_result); | |||||
| } | |||||
| } | } | ||||
| var _op = _op_def_lib._apply_op_helper("RealDiv", name, args: new { x, y }); | var _op = _op_def_lib._apply_op_helper("RealDiv", name, args: new { x, y }); | ||||
| @@ -4,7 +4,7 @@ | |||||
| <TargetFramework>netstandard2.0</TargetFramework> | <TargetFramework>netstandard2.0</TargetFramework> | ||||
| <AssemblyName>TensorFlow.NET</AssemblyName> | <AssemblyName>TensorFlow.NET</AssemblyName> | ||||
| <RootNamespace>Tensorflow</RootNamespace> | <RootNamespace>Tensorflow</RootNamespace> | ||||
| <TargetTensorFlow>2.01.0</TargetTensorFlow> | |||||
| <TargetTensorFlow>2.2.0</TargetTensorFlow> | |||||
| <Version>0.20.0</Version> | <Version>0.20.0</Version> | ||||
| <LangVersion>8.0</LangVersion> | <LangVersion>8.0</LangVersion> | ||||
| <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | ||||
| @@ -18,7 +18,7 @@ | |||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="BenchmarkDotNet" Version="0.12.0" /> | |||||
| <PackageReference Include="BenchmarkDotNet" Version="0.12.1" /> | |||||
| <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.1.0" /> | <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.1.0" /> | ||||
| <PackageReference Include="TensorFlow.NET" Version="0.15.1" /> | <PackageReference Include="TensorFlow.NET" Version="0.15.1" /> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| @@ -31,8 +31,8 @@ | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="FluentAssertions" Version="5.10.3" /> | <PackageReference Include="FluentAssertions" Version="5.10.3" /> | ||||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.5.0" /> | <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.5.0" /> | ||||
| <PackageReference Include="MSTest.TestAdapter" Version="2.1.0" /> | |||||
| <PackageReference Include="MSTest.TestFramework" Version="2.1.0" /> | |||||
| <PackageReference Include="MSTest.TestAdapter" Version="2.1.1" /> | |||||
| <PackageReference Include="MSTest.TestFramework" Version="2.1.1" /> | |||||
| <PackageReference Include="NumSharp.Lite" Version="0.1.7" /> | <PackageReference Include="NumSharp.Lite" Version="0.1.7" /> | ||||
| <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.1.0" /> | <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.1.0" /> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| @@ -8,9 +8,9 @@ | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.5.0" /> | <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.5.0" /> | ||||
| <PackageReference Include="MSTest.TestAdapter" Version="2.1.0" /> | |||||
| <PackageReference Include="MSTest.TestFramework" Version="2.1.0" /> | |||||
| <PackageReference Include="coverlet.collector" Version="1.2.0"> | |||||
| <PackageReference Include="MSTest.TestAdapter" Version="2.1.1" /> | |||||
| <PackageReference Include="MSTest.TestFramework" Version="2.1.1" /> | |||||
| <PackageReference Include="coverlet.collector" Version="1.2.1"> | |||||
| <PrivateAssets>all</PrivateAssets> | <PrivateAssets>all</PrivateAssets> | ||||
| <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> | <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> | ||||
| </PackageReference> | </PackageReference> | ||||