Browse Source

GradientActor

tags/v0.20
Oceania2018 5 years ago
parent
commit
2f3bd61b1b
10 changed files with 126 additions and 25 deletions
  1. +4
    -1
      src/TensorFlowNET.Core/APIs/tf.math.cs
  2. +10
    -6
      src/TensorFlowNET.Core/Eager/EagerTensor.cs
  3. +18
    -3
      src/TensorFlowNET.Core/Eager/c_api.eager.cs
  4. +7
    -4
      src/TensorFlowNET.Core/Gradients/GradientActor.cs
  5. +3
    -2
      src/TensorFlowNET.Core/Gradients/Tape.cs
  6. +77
    -2
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  7. +1
    -1
      src/TensorFlowNET.Core/TensorFlow.Binding.csproj
  8. +1
    -1
      src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj
  9. +2
    -2
      test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj
  10. +3
    -3
      test/Tensorflow.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj

+ 4
- 1
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/

using Tensorflow.Eager;
using Tensorflow.Operations;

namespace Tensorflow
@@ -259,7 +260,6 @@ namespace Tensorflow
public Tensor sub<Tx, Ty>(Tx a, Ty b, string name = null)
=> gen_math_ops.sub(a, b, name: name);


public Tensor divide(Tensor a, Tensor b)
=> a / b;

@@ -348,6 +348,9 @@ namespace Tensorflow
public Tensor minimum<T1, T2>(T1 x, T2 y, string name = null)
=> gen_math_ops.minimum(x, y, name: name);

public Tensor multiply(Tensor x, Tensor y, string name = null)
=> gen_math_ops.mul(x, y, name: name);

/// <summary>
/// return x * y
/// </summary>


+ 10
- 6
src/TensorFlowNET.Core/Eager/EagerTensor.cs View File

@@ -2,6 +2,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow.Eager
{
@@ -9,41 +10,44 @@ namespace Tensorflow.Eager
{
Status status = new Status();
TFE_TensorHandle tfe_tensor_handle;
public IntPtr EagerTensorHandle { get; set; }

public EagerTensor(IntPtr handle) : base(handle)
{
tfe_tensor_handle = handle;
_handle = c_api.TFE_TensorHandleResolve(handle, status);
_id = ops.uid();
}

public EagerTensor(string value, string device_name) : base(value)
{
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status);
_id = ops.uid();
}

public EagerTensor(int value, string device_name) : base(value)
{
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status);
_id = ops.uid();
EagerTensorHandle = c_api.TFE_EagerTensorFromHandle(tf.context, tfe_tensor_handle);
}

public EagerTensor(float value, string device_name) : base(value)
{
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status);
EagerTensorHandle = c_api.TFE_EagerTensorFromHandle(tf.context, tfe_tensor_handle);
}

public EagerTensor(float[] value, string device_name) : base(value)
{
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status);
_id = ops.uid();
}

public EagerTensor(double[] value, string device_name) : base(value)
{
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status);
_id = ops.uid();
}

public EagerTensor(NDArray value, string device_name) : base(value)
{
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status);
_id = ops.uid();
}

public override string ToString()


+ 18
- 3
src/TensorFlowNET.Core/Eager/c_api.eager.cs View File

@@ -7,6 +7,12 @@ namespace Tensorflow
{
public partial class c_api
{
[DllImport(TensorFlowLibName)]
public static extern void TFE_RegisterGradientFunction(_gradient_function_callback callbackPointer);

[UnmanagedFunctionPointer(CallingConvention.StdCall)]
public delegate void _gradient_function_callback(string op_name, int num_inputs, IntPtr attrs, int num_attrs);

/// <summary>
/// Return a new options object.
/// </summary>
@@ -186,6 +192,9 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern TFE_TensorHandle TFE_NewTensorHandle(IntPtr t, IntPtr status);

[DllImport(TensorFlowLibName)]
public static extern TFE_TensorHandle TFE_EagerTensorFromHandle(IntPtr ctx, IntPtr h);

/// <summary>
/// Sets the default execution mode (sync/async). Note that this can be
/// overridden per thread using TFE_ContextSetExecutorForThread.
@@ -312,15 +321,21 @@ namespace Tensorflow
public static extern TFE_Executor TFE_ContextGetExecutorForThread(IntPtr ctx);

[DllImport(TensorFlowLibName)]
public static extern void TFE_Test();
public static extern IntPtr TFE_FastPathExecute(IntPtr ctx,
string device_name,
string op_name,
string name,
IntPtr[] args,
int input_size,
IntPtr status);

[DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_TapeSetNew(bool persistent, bool watch_accessed_variables);

[DllImport(TensorFlowLibName)]
public static extern void TFE_TapeWatch(IntPtr tape, IntPtr tensor, int tensor_id);
public static extern void TFE_TapeWatch(IntPtr tape, IntPtr tensor);

[DllImport(TensorFlowLibName)]
public static extern void TFE_TapeGradient(IntPtr tape, long[] targetTensorIds, IntPtr[] target, long[] sourcesTensorIds, IntPtr status);
public static extern void TFE_TapeGradient(IntPtr tape, IntPtr[] target, IntPtr sources, IntPtr status);
}
}

+ 7
- 4
src/TensorFlowNET.Core/Gradients/GradientActor.cs View File

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Eager;
using static Tensorflow.Binding;

namespace Tensorflow.Gradients
@@ -53,14 +54,16 @@ namespace Tensorflow.Gradients
/// <param name="x"></param>
public void watch(Tensor x)
{
_tape.watch(x);
_tape.watch(x as EagerTensor);
}

public Tensor gradient(Tensor target, Tensor sources)
{
c_api.TFE_Test();
//using (var status = new Status())
//c_api.TFE_TapeGradient(_tape, new long[] { target.Id }, status);
using (var status = new Status())
{
c_api.TFE_TapeGradient(_tape, new IntPtr[] { target }, IntPtr.Zero, status);
}
return null;
}



+ 3
- 2
src/TensorFlowNET.Core/Gradients/Tape.cs View File

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Eager;

namespace Tensorflow.Gradients
{
@@ -14,9 +15,9 @@ namespace Tensorflow.Gradients
_handle = c_api.TFE_TapeSetNew(persistent, watch_accessed_variables);
}

public void watch(Tensor x)
public void watch(EagerTensor x)
{
c_api.TFE_TapeWatch(_handle, x, x.Id);
c_api.TFE_TapeWatch(_handle, x.EagerTensorHandle);
}

public static bool IsDtypeTrainable(DataType dtype)


+ 77
- 2
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -192,6 +192,28 @@ namespace Tensorflow
return _op.outputs[0];
}

public static Tensor add(Tensor x, Tensor y, string name = null)
{
if (tf.context.executing_eagerly())
{
using (var status = new Status())
{
var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Add", name, new IntPtr[]
{
(x as EagerTensor).EagerTensorHandle,
(y as EagerTensor).EagerTensorHandle
}, 2, status);
status.Check(true);
return new EagerTensor(_result);
}
}

var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y });

return _op.output;
}

public static Tensor add<Tx, Ty>(Tx x, Ty y, string name = null)
{
if (tf.context.executing_eagerly())
@@ -593,6 +615,28 @@ namespace Tensorflow
return _op.outputs[0];
}

public static Tensor sub(Tensor x, Tensor y, string name = null)
{
if (tf.context.executing_eagerly())
{
using (var status = new Status())
{
var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Sub", name, new IntPtr[]
{
(x as EagerTensor).EagerTensorHandle,
(y as EagerTensor).EagerTensorHandle
}, 2, status);
status.Check(true);
return new EagerTensor(_result);
}
}

var _op = _op_def_lib._apply_op_helper("Sub", name, args: new { x, y });

return _op.output;
}

public static Tensor sub<Tx, Ty>(Tx x, Ty y, string name = null)
{
if (tf.context.executing_eagerly())
@@ -667,6 +711,28 @@ namespace Tensorflow
return _op.output;
}

public static Tensor mul(Tensor x, Tensor y, string name = null)
{
if (tf.context.executing_eagerly())
{
using (var status = new Status())
{
var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Mul", name, new IntPtr[]
{
(x as EagerTensor).EagerTensorHandle,
(y as EagerTensor).EagerTensorHandle
}, 2, status);
status.Check(true);
return new EagerTensor(_result);
}
}

var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y });

return _op.output;
}

public static Tensor mul<Tx, Ty>(Tx x, Ty y, string name = null)
{
if (tf.context.executing_eagerly())
@@ -693,8 +759,17 @@ namespace Tensorflow
{
if (tf.context.executing_eagerly())
{
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, "", "RealDiv", name, null, x, y);
return _result;
using (var status = new Status())
{
var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"RealDiv", name, new IntPtr[]
{
(x as EagerTensor).EagerTensorHandle,
(y as EagerTensor).EagerTensorHandle
}, 2, status);
status.Check(true);
return new EagerTensor(_result);
}
}

var _op = _op_def_lib._apply_op_helper("RealDiv", name, args: new { x, y });


+ 1
- 1
src/TensorFlowNET.Core/TensorFlow.Binding.csproj View File

@@ -4,7 +4,7 @@
<TargetFramework>netstandard2.0</TargetFramework>
<AssemblyName>TensorFlow.NET</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace>
<TargetTensorFlow>2.01.0</TargetTensorFlow>
<TargetTensorFlow>2.2.0</TargetTensorFlow>
<Version>0.20.0</Version>
<LangVersion>8.0</LangVersion>
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>


+ 1
- 1
src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj View File

@@ -18,7 +18,7 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="BenchmarkDotNet" Version="0.12.0" />
<PackageReference Include="BenchmarkDotNet" Version="0.12.1" />
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.1.0" />
<PackageReference Include="TensorFlow.NET" Version="0.15.1" />
</ItemGroup>


+ 2
- 2
test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj View File

@@ -31,8 +31,8 @@
<ItemGroup>
<PackageReference Include="FluentAssertions" Version="5.10.3" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.5.0" />
<PackageReference Include="MSTest.TestAdapter" Version="2.1.0" />
<PackageReference Include="MSTest.TestFramework" Version="2.1.0" />
<PackageReference Include="MSTest.TestAdapter" Version="2.1.1" />
<PackageReference Include="MSTest.TestFramework" Version="2.1.1" />
<PackageReference Include="NumSharp.Lite" Version="0.1.7" />
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.1.0" />
</ItemGroup>


+ 3
- 3
test/Tensorflow.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj View File

@@ -8,9 +8,9 @@

<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.5.0" />
<PackageReference Include="MSTest.TestAdapter" Version="2.1.0" />
<PackageReference Include="MSTest.TestFramework" Version="2.1.0" />
<PackageReference Include="coverlet.collector" Version="1.2.0">
<PackageReference Include="MSTest.TestAdapter" Version="2.1.1" />
<PackageReference Include="MSTest.TestFramework" Version="2.1.1" />
<PackageReference Include="coverlet.collector" Version="1.2.1">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>


Loading…
Cancel
Save