| @@ -9,7 +9,7 @@ | |||||
| [](https://996.icu/#/en_US) | [](https://996.icu/#/en_US) | ||||
| [](https://mybinder.org/v2/gh/javiercp/BinderTF.NET/master?urlpath=lab) | [](https://mybinder.org/v2/gh/javiercp/BinderTF.NET/master?urlpath=lab) | ||||
| *master branch is based on tensorflow 2.1 now, v0.15-tensorflow1.15 is from tensorflow1.15.* | |||||
| *master branch is based on tensorflow 2.2 now, v0.15-tensorflow1.15 is from tensorflow1.15.* | |||||
| TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp). | TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp). | ||||
| @@ -28,7 +28,7 @@ In comparison to other projects, like for instance TensorFlowSharp which only pr | |||||
| ### How to use | ### How to use | ||||
| | TensorFlow | tf 1.13 | tf 1.14 | tf 1.15 | tf 2.0 | | |||||
| | TensorFlow | tf 1.13 | tf 1.14 | tf 1.15 | tf 2.2 | | |||||
| | ----------- | ------- | ------- | ------- | ------ | | | ----------- | ------- | ------- | ------- | ------ | | ||||
| | tf.net 0.20 | | | x | x | | | tf.net 0.20 | | | x | x | | ||||
| | tf.net 0.15 | | x | x | | | | tf.net 0.15 | | x | x | | | ||||
| @@ -13,31 +13,37 @@ namespace Tensorflow.Eager | |||||
| { | { | ||||
| tfe_tensor_handle = handle; | tfe_tensor_handle = handle; | ||||
| _handle = c_api.TFE_TensorHandleResolve(handle, status); | _handle = c_api.TFE_TensorHandleResolve(handle, status); | ||||
| _id = ops.uid(); | |||||
| } | } | ||||
| public EagerTensor(string value, string device_name) : base(value) | public EagerTensor(string value, string device_name) : base(value) | ||||
| { | { | ||||
| tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | ||||
| _id = ops.uid(); | |||||
| } | } | ||||
| public EagerTensor(int value, string device_name) : base(value) | public EagerTensor(int value, string device_name) : base(value) | ||||
| { | { | ||||
| tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | ||||
| _id = ops.uid(); | |||||
| } | } | ||||
| public EagerTensor(float[] value, string device_name) : base(value) | public EagerTensor(float[] value, string device_name) : base(value) | ||||
| { | { | ||||
| tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | ||||
| _id = ops.uid(); | |||||
| } | } | ||||
| public EagerTensor(double[] value, string device_name) : base(value) | public EagerTensor(double[] value, string device_name) : base(value) | ||||
| { | { | ||||
| tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | ||||
| _id = ops.uid(); | |||||
| } | } | ||||
| public EagerTensor(NDArray value, string device_name) : base(value) | public EagerTensor(NDArray value, string device_name) : base(value) | ||||
| { | { | ||||
| tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | ||||
| _id = ops.uid(); | |||||
| } | } | ||||
| public override string ToString() | public override string ToString() | ||||
| @@ -102,14 +102,20 @@ namespace Tensorflow | |||||
| public static extern TFE_Op TFE_NewOp(IntPtr ctx, string op_or_function_name, IntPtr status); | public static extern TFE_Op TFE_NewOp(IntPtr ctx, string op_or_function_name, IntPtr status); | ||||
| /// <summary> | /// <summary> | ||||
| /// | |||||
| /// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This | |||||
| /// is for performance optimization by reusing an exiting unused op rather than | |||||
| /// creating a new op every time. If `raw_device_name` is `NULL` or empty, it | |||||
| /// does not set the device name. If it's not `NULL`, then it attempts to parse | |||||
| /// and set the device name. It's effectively `TFE_OpSetDevice`, but it is faster | |||||
| /// than separately calling it because if the existing op has the same | |||||
| /// `raw_device_name`, it skips parsing and just leave as it is. | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="ctx">TFE_Context*</param> | |||||
| /// <param name="op_to_reset">TFE_Op*</param> | |||||
| /// <param name="op_or_function_name">const char*</param> | /// <param name="op_or_function_name">const char*</param> | ||||
| /// <param name="raw_device_name">const char*</param> | |||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| /// <param name="op_to_reset">TFE_Op*</param> | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFE_OpReset(IntPtr ctx, string op_or_function_name, IntPtr status, IntPtr op_to_reset); | |||||
| public static extern void TFE_OpReset(IntPtr op_to_reset, string op_or_function_name, string raw_device_name, IntPtr status); | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| @@ -304,5 +310,17 @@ namespace Tensorflow | |||||
| /// <returns>TFE_Executor*</returns> | /// <returns>TFE_Executor*</returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern TFE_Executor TFE_ContextGetExecutorForThread(IntPtr ctx); | public static extern TFE_Executor TFE_ContextGetExecutorForThread(IntPtr ctx); | ||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TFE_Test(); | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern IntPtr TFE_TapeSetNew(bool persistent, bool watch_accessed_variables); | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TFE_TapeWatch(IntPtr tape, IntPtr tensor, int tensor_id); | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TFE_TapeGradient(IntPtr tape, long[] targetTensorIds, IntPtr[] target, long[] sourcesTensorIds, IntPtr status); | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,15 @@ | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System; | |||||
| using static Tensorflow.OpDef.Types; | |||||
| namespace Tensorflow.Eager | |||||
| { | |||||
| /// <summary> | |||||
| /// python\eager\pywrap_tfe_src.cc | |||||
| /// </summary> | |||||
| public partial class wrap_tfe_src | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -110,7 +110,7 @@ namespace Tensorflow.Eager | |||||
| var maybe_op = ReleaseThreadLocalOp(); | var maybe_op = ReleaseThreadLocalOp(); | ||||
| if (maybe_op != IntPtr.Zero) | if (maybe_op != IntPtr.Zero) | ||||
| { | { | ||||
| c_api.TFE_OpReset(ctx, op_or_function_name, status, maybe_op); | |||||
| c_api.TFE_OpReset(maybe_op, op_or_function_name, ctx.device_name, status); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| @@ -23,7 +23,6 @@ namespace Tensorflow.Gradients | |||||
| bool _watch_accessed_variables; | bool _watch_accessed_variables; | ||||
| bool _created_eagerly; | bool _created_eagerly; | ||||
| Tape _tape; | Tape _tape; | ||||
| int tape_nesting_id_counter = 0; | |||||
| public GradientActor(bool persistent = false, | public GradientActor(bool persistent = false, | ||||
| bool watch_accessed_variables = true) | bool watch_accessed_variables = true) | ||||
| @@ -41,18 +40,28 @@ namespace Tensorflow.Gradients | |||||
| "re-enter an already-active tape."); | "re-enter an already-active tape."); | ||||
| if (_tape == null) | if (_tape == null) | ||||
| { | |||||
| _tape = new Tape(); | |||||
| _tape.tape = new GradientTape(_persistent, _watch_accessed_variables); | |||||
| _tape.nesting_id = tape_nesting_id_counter++; | |||||
| } | |||||
| _tape = new Tape(_persistent, _watch_accessed_variables); | |||||
| else | |||||
| throw new NotImplementedException(""); | |||||
| _recording = true; | _recording = true; | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Marks this tensor to be watched by the given tape. | |||||
| /// </summary> | |||||
| /// <param name="x"></param> | |||||
| public void watch(Tensor x) | public void watch(Tensor x) | ||||
| { | { | ||||
| _tape.watch(x); | |||||
| } | |||||
| public Tensor gradient(Tensor target, Tensor sources) | |||||
| { | |||||
| c_api.TFE_Test(); | |||||
| //using (var status = new Status()) | |||||
| //c_api.TFE_TapeGradient(_tape, new long[] { target.Id }, status); | |||||
| return null; | |||||
| } | } | ||||
| public void Dispose() | public void Dispose() | ||||
| @@ -4,11 +4,21 @@ using System.Text; | |||||
| namespace Tensorflow.Gradients | namespace Tensorflow.Gradients | ||||
| { | { | ||||
| public class Tape | |||||
| public class Tape : DisposableObject | |||||
| { | { | ||||
| public GradientTape tape { get; set; } | public GradientTape tape { get; set; } | ||||
| public int nesting_id { get; set; } | public int nesting_id { get; set; } | ||||
| public Tape(bool persistent, bool watch_accessed_variables) | |||||
| { | |||||
| _handle = c_api.TFE_TapeSetNew(persistent, watch_accessed_variables); | |||||
| } | |||||
| public void watch(Tensor x) | |||||
| { | |||||
| c_api.TFE_TapeWatch(_handle, x, x.Id); | |||||
| } | |||||
| public static bool IsDtypeTrainable(DataType dtype) | public static bool IsDtypeTrainable(DataType dtype) | ||||
| { | { | ||||
| switch (dtype) | switch (dtype) | ||||
| @@ -26,5 +36,12 @@ namespace Tensorflow.Gradients | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||||
| { | |||||
| } | |||||
| public static implicit operator IntPtr(Tape tape) | |||||
| => tape._handle; | |||||
| } | } | ||||
| } | } | ||||
| @@ -39,7 +39,7 @@ namespace Tensorflow | |||||
| IPackable<Tensor>, | IPackable<Tensor>, | ||||
| ICanBeFlattened | ICanBeFlattened | ||||
| { | { | ||||
| private readonly int _id; | |||||
| protected int _id; | |||||
| private readonly Operation _op; | private readonly Operation _op; | ||||
| private readonly int _value_index; | private readonly int _value_index; | ||||
| private TF_Output? _tf_output; | private TF_Output? _tf_output; | ||||
| @@ -30,6 +30,8 @@ namespace Tensorflow | |||||
| public static Tensor operator -(ResourceVariable x, double y) => op_helper("sub", x, y); | public static Tensor operator -(ResourceVariable x, double y) => op_helper("sub", x, y); | ||||
| public static Tensor operator -(ResourceVariable x, Tensor y) => op_helper("sub", x, y); | public static Tensor operator -(ResourceVariable x, Tensor y) => op_helper("sub", x, y); | ||||
| public static Tensor operator *(ResourceVariable x, ResourceVariable y) => gen_math_ops.mul(x, y); | |||||
| public static Tensor operator <(ResourceVariable x, Tensor y) => gen_math_ops.less(x.value(), y); | public static Tensor operator <(ResourceVariable x, Tensor y) => gen_math_ops.less(x.value(), y); | ||||
| public static Tensor operator >(ResourceVariable x, Tensor y) => gen_math_ops.greater(x.value(), y); | public static Tensor operator >(ResourceVariable x, Tensor y) => gen_math_ops.greater(x.value(), y); | ||||
| @@ -44,7 +44,9 @@ We can't found official prebuild binaries for each platform since tensorflow 2.0 | |||||
| https://www.tensorflow.org/install/source_windows | https://www.tensorflow.org/install/source_windows | ||||
| Download [Bazel 0.29.1](https://github.com/bazelbuild/bazel/releases/tag/0.29.1) to build tensorflow2.x. We build customized binary to export c_api from this [fork](https://github.com/SciSharp/tensorflow). | |||||
| Download [Bazel 2.0.0](https://github.com/bazelbuild/bazel/releases/tag/2.0.0) to build tensorflow2.x. We build customized binary to export c_api from this [fork](https://github.com/SciSharp/tensorflow). | |||||
| Set ENV `BAZEL_VC=C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC`. | |||||
| `pacman -S git patch unzip` | `pacman -S git patch unzip` | ||||
| @@ -29,7 +29,7 @@ | |||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="FluentAssertions" Version="5.10.2" /> | |||||
| <PackageReference Include="FluentAssertions" Version="5.10.3" /> | |||||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.5.0" /> | <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.5.0" /> | ||||
| <PackageReference Include="MSTest.TestAdapter" Version="2.1.0" /> | <PackageReference Include="MSTest.TestAdapter" Version="2.1.0" /> | ||||
| <PackageReference Include="MSTest.TestFramework" Version="2.1.0" /> | <PackageReference Include="MSTest.TestFramework" Version="2.1.0" /> | ||||