Browse Source

Merge pull request #210 from henon/master

Ported a test case from ops_test.py and fixed a bug that was revealed by it
tags/v0.9
Haiping GitHub 6 years ago
parent
commit
3cb6c237a3
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 111 additions and 8 deletions
  1. +7
    -0
      src/TensorFlowNET.Core/Graphs/Graph.Control.cs
  2. +27
    -1
      src/TensorFlowNET.Core/Operations/Operation.cs
  3. +18
    -1
      src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs
  4. +18
    -3
      src/TensorFlowNET.Core/ops.py.cs
  5. +1
    -1
      src/TensorFlowNET.Visualization/TensorFlowNET.Visualization.csproj
  6. +1
    -1
      test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
  7. +38
    -0
      test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs
  8. +1
    -1
      test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj

+ 7
- 0
src/TensorFlowNET.Core/Graphs/Graph.Control.cs View File

@@ -54,6 +54,13 @@ namespace Tensorflow
return ret;
}

/// <summary>
/// Returns a context manager that specifies control dependencies.
///
/// Use with the `with` keyword to specify that all operations constructed
/// within the context should have control dependencies on
/// `control_inputs`.
/// </summary>
public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs)
{
if (control_inputs == null)


+ 27
- 1
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -7,7 +7,26 @@ using System.Runtime.InteropServices;
using System.Text;

namespace Tensorflow
{
{
/// <summary>
/// Represents a graph node that performs computation on tensors.
///
/// An `Operation` is a node in a TensorFlow `Graph` that takes zero or
/// more `Tensor` objects as input, and produces zero or more `Tensor`
/// objects as output. Objects of type `Operation` are created by
/// calling an op constructor(such as `tf.matmul`)
/// or `tf.Graph.create_op`.
///
/// For example `c = tf.matmul(a, b)` creates an `Operation` of type
/// "MatMul" that takes tensors `a` and `b` as input, and produces `c`
/// as output.
///
/// After the graph has been launched in a session, an `Operation` can
/// be executed by passing it to
/// `tf.Session.run`.
/// `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`.
/// </summary>
public partial class Operation : ITensorOrOperation
{
private readonly IntPtr _handle; // _c_op in python
@@ -98,6 +117,13 @@ namespace Tensorflow
case Operation c1:
control_input_ops.Add(c1);
break;
case Tensor tensor:
control_input_ops.Add(tensor.op);
break;
// TODO: IndexedSlices don't yet exist, but once they do, this needs to be uncommented
//case IndexedSlices islices:
// control_input_ops.Add(islices.op);
// break;
default:
throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}");
}


+ 18
- 1
src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs View File

@@ -4,8 +4,25 @@ using System.Text;

namespace Tensorflow.Train
{
/// <summary>
/// Optimizer that implements the gradient descent algorithm.
/// </summary>
public class GradientDescentOptimizer : Optimizer
{
{
/// <summary>
/// Construct a new gradient descent optimizer.
/// </summary>
/// <param name="learning_rate">A Tensor or a floating point value. The learning
/// rate to use.</param>
/// <param name="use_locking">If true use locks for update operations.</param>
/// <param name="name">Optional name prefix for the operations created when applying
/// gradients.Defaults to "GradientDescent".</param>
/// <remarks>
/// When eager execution is enabled, `learning_rate` can be a callable that
/// takes no arguments and returns the actual value to use.This can be useful
/// for changing these values across different invocations of optimizer
/// functions.
/// </remarks>
public GradientDescentOptimizer(float learning_rate, bool use_locking = false, string name = "GradientDescent")
: base(learning_rate, use_locking, name)
{


+ 18
- 3
src/TensorFlowNET.Core/ops.py.cs View File

@@ -98,12 +98,27 @@ namespace Tensorflow
public static Tensor internal_convert_to_tensor_or_composite(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false)
{
return internal_convert_to_tensor(value, dtype: dtype, name: name, as_ref: as_ref);
}
}
/// <summary>
/// Wrapper for `Graph.control_dependencies()` using the default graph.
///
/// See `tf.Graph.control_dependencies` for more details.

/// When eager execution is enabled, any callable object in the `control_inputs`
/// list will be called.
/// </summary>
/// <param name="control_inputs"></param>
/// <param name="control_inputs">
/// A list of `Operation` or `Tensor` objects which
/// must be executed or computed before running the operations
/// defined in the context.Can also be `None` to clear the control
/// dependencies.If eager execution is enabled, any callable object in the
/// `control_inputs` list will be called.
/// </param>
/// <returns>
/// A context manager that specifies control dependencies for all
/// operations constructed within the context.
/// </returns>
public static _ControlDependenciesController control_dependencies(Operation[] control_inputs)
{
return get_default_graph().control_dependencies(control_inputs);


+ 1
- 1
src/TensorFlowNET.Visualization/TensorFlowNET.Visualization.csproj View File

@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk.Web">

<PropertyGroup>
<TargetFramework>netcoreapp2.2</TargetFramework>
<TargetFramework>netcoreapp2.1</TargetFramework>
<AspNetCoreHostingModel>InProcess</AspNetCoreHostingModel>
</PropertyGroup>



+ 1
- 1
test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj View File

@@ -2,7 +2,7 @@

<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>netcoreapp2.2</TargetFramework>
<TargetFramework>netcoreapp2.1</TargetFramework>
<GeneratePackageOnBuild>false</GeneratePackageOnBuild>
</PropertyGroup>



+ 38
- 0
test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs View File

@@ -0,0 +1,38 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
namespace TensorFlowNET.UnitTest
{
/// <summary>
/// tensorflow/python/framework/ops_test.py
/// </summary>
[TestClass]
public class ControlDependenciesTest : Python
{
[TestMethod]
public void TestBasic()
{
var graph = tf.Graph().as_default();
Tensor a=null, b = null, c = null, d = null, e = null;
with<Graph>(graph, g =>
{
a = constant_op.constant(1.0);
b = constant_op.constant(1.0);
with(g.control_dependencies(new ITensorOrOperation[] {a}), x =>
{
c = constant_op.constant(1.0);
d = array_ops.identity(b);
e = array_ops.identity(c);
});
});
Assert.IsTrue(Enumerable.SequenceEqual(c.op.control_inputs, new[] {a.op}));
Assert.IsTrue(Enumerable.SequenceEqual(d.op.control_inputs, new[] {a.op}));
// e should be dominated by c.
Assert.AreEqual(0, e.op.control_inputs.Length);
}
}
}

+ 1
- 1
test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj View File

@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>netcoreapp2.2</TargetFramework>
<TargetFramework>netcoreapp2.1</TargetFramework>

<IsPackable>false</IsPackable>
</PropertyGroup>


Loading…
Cancel
Save