Browse Source

add api of AddControlInput

tags/v0.9
Oceania2018 6 years ago
parent
commit
d65c0d3ec7
9 changed files with 54 additions and 68 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Graphs/c_api.graph.cs
  2. +2
    -1
      src/TensorFlowNET.Core/Operations/Operation.Control.cs
  3. +5
    -6
      src/TensorFlowNET.Core/Operations/Operation.cs
  4. +17
    -0
      src/TensorFlowNET.Core/Operations/c_api.ops.cs
  5. +0
    -1
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  6. +4
    -5
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  7. +14
    -1
      tensorflowlib/README.md
  8. BIN
      tensorflowlib/runtimes/win-x64/native/tensorflow.dll
  9. +11
    -53
      test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs

+ 1
- 1
src/TensorFlowNET.Core/Graphs/c_api.graph.cs View File

@@ -288,6 +288,6 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)]
public static extern void TF_UpdateEdge(IntPtr graph, TF_Output new_src, TF_Input dst, IntPtr status);
public static extern void UpdateEdge(IntPtr graph, TF_Output new_src, TF_Input dst, IntPtr status);
}
}

+ 2
- 1
src/TensorFlowNET.Core/Operations/Operation.Control.cs View File

@@ -29,7 +29,8 @@ namespace Tensorflow

public void _add_control_input(Operation op)
{
c_api.TF_AddControlInput(_operDesc, op);
// c_api.TF_AddControlInput(_operDesc, op);
c_api.AddControlInput(graph, _handle, op);
}

public void _add_control_inputs(Operation[] ops)


+ 5
- 6
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -1,5 +1,4 @@
using Google.Protobuf.Collections;
using Newtonsoft.Json;
//using Newtonsoft.Json;
using System;
using System.Collections.Generic;
@@ -34,15 +33,15 @@ namespace Tensorflow
private readonly IntPtr _operDesc;

private Graph _graph;
[JsonIgnore]
//[JsonIgnore]
public Graph graph => _graph;
[JsonIgnore]
//[JsonIgnore]
public int _id => _id_value;
[JsonIgnore]
//[JsonIgnore]
public int _id_value;

public string type => OpType;
[JsonIgnore]
//[JsonIgnore]
public Operation op => this;
public TF_DataType dtype => TF_DataType.DtInvalid;
private Status status = new Status();
@@ -289,7 +288,7 @@ namespace Tensorflow
_inputs = null;
// after the c_api call next time _inputs is accessed
// the updated inputs are reloaded from the c_api
c_api.TF_UpdateEdge(_graph, output, input, status);
c_api.UpdateEdge(_graph, output, input, status);
//var updated_inputs = inputs;
}


+ 17
- 0
src/TensorFlowNET.Core/Operations/c_api.ops.cs View File

@@ -42,6 +42,23 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern void TF_AddControlInput(IntPtr desc, IntPtr input);

/// <summary>
///
/// </summary>
/// <param name="graph">TF_Graph*</param>
/// <param name="op">TF_Operation*</param>
/// <param name="input">TF_Operation*</param>
[DllImport(TensorFlowLibName)]
public static extern void AddControlInput(IntPtr graph, IntPtr op, IntPtr input);

/// <summary>
///
/// </summary>
/// <param name="graph">TF_Graph*</param>
/// <param name="op">TF_Operation*</param>
[DllImport(TensorFlowLibName)]
public static extern void RemoveAllControlInputs(IntPtr graph, IntPtr op);

/// <summary>
/// For inputs that take a list of tensors.
/// inputs must point to TF_Output[num_inputs].


+ 0
- 1
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -45,7 +45,6 @@ Bug memory leak issue when allocating Tensor.</PackageReleaseNotes>

<ItemGroup>
<PackageReference Include="Google.Protobuf" Version="3.7.0" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.1" />
</ItemGroup>

<ItemGroup>


+ 4
- 5
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -1,5 +1,4 @@
//using Newtonsoft.Json;
using Newtonsoft.Json;
using NumSharp;
using System;
using System.Collections.Generic;
@@ -19,13 +18,13 @@ namespace Tensorflow
private readonly IntPtr _handle;

private int _id;
[JsonIgnore]
//[JsonIgnore]
public int Id => _id;
[JsonIgnore]
//[JsonIgnore]
public Graph graph => op?.graph;
[JsonIgnore]
//[JsonIgnore]
public Operation op { get; }
[JsonIgnore]
//[JsonIgnore]
public Tensor[] outputs => op.outputs;

/// <summary>


+ 14
- 1
tensorflowlib/README.md View File

@@ -36,4 +36,17 @@ pacman -S git patch unzip

4. Install from local wheel file.

`pip install C:/tmp/tensorflow_pkg/tensorflow-1.13.0-cp36-cp36m-win_amd64.whl`
`pip install C:/tmp/tensorflow_pkg/tensorflow-1.13.0-cp36-cp36m-win_amd64.whl`

### Export more APIs

Add more api to `c_api.h`

```c++
TF_CAPI_EXPORT extern void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input);
TF_CAPI_EXPORT extern void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, TF_Status* status);
TF_CAPI_EXPORT extern void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op);
```




BIN
tensorflowlib/runtimes/win-x64/native/tensorflow.dll View File


+ 11
- 53
test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs View File

@@ -15,79 +15,37 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
public void testCondTrue()
{
var graph = tf.Graph().as_default();
// tf.train.import_meta_graph("cond_test.meta");
var json = JsonConvert.SerializeObject(graph._nodes_by_name, Formatting.Indented);
with(tf.Session(graph), sess =>
{
var x = tf.constant(2, name: "x"); // graph.get_operation_by_name("Const").output;
var y = tf.constant(5, name: "y"); // graph.get_operation_by_name("Const_1").output;
var pred = tf.less(x, y); // graph.get_operation_by_name("Less").output;
Func<ITensorOrOperation> if_true = delegate
{
return tf.constant(2, name: "t2");
};
Func<ITensorOrOperation> if_false = delegate
{
return tf.constant(5, name: "f5");
};
var x = tf.constant(2, name: "x");
var y = tf.constant(5, name: "y");
var z = control_flow_ops.cond(pred, if_true, if_false); // graph.get_operation_by_name("cond/Merge").output
var z = control_flow_ops.cond(tf.less(x, y),
() => tf.constant(22, name: "t2"),
() => tf.constant(55, name: "f5"));
json = JsonConvert.SerializeObject(graph._nodes_by_name, Formatting.Indented);
int result = z.eval(sess);
assertEquals(result, 2);
assertEquals(result, 22);
});
}
[TestMethod]
public void testCondFalse()
{
/* python
* import tensorflow as tf
from tensorflow.python.framework import ops
def if_true():
return tf.math.multiply(x, 17)
def if_false():
return tf.math.add(y, 23)
with tf.Session() as sess:
x = tf.constant(2)
y = tf.constant(1)
pred = tf.math.less(x,y)
z = tf.cond(pred, if_true, if_false)
result = z.eval()
print(result == 24) */
var graph = tf.Graph().as_default();
//tf.train.import_meta_graph("cond_test.meta");
//var json = JsonConvert.SerializeObject(graph._nodes_by_name, Formatting.Indented);
with(tf.Session(), sess =>
with(tf.Session(graph), sess =>
{
var x = tf.constant(2, name: "x");
var y = tf.constant(1, name: "y");
var pred = tf.less(x, y);
Func<ITensorOrOperation> if_true = delegate
{
return tf.constant(2, name: "t2");
};
Func<ITensorOrOperation> if_false = delegate
{
return tf.constant(1, name: "f1");
};
var z = control_flow_ops.cond(pred, if_true, if_false);
var z = control_flow_ops.cond(tf.less(x, y),
() => tf.constant(22, name: "t2"),
() => tf.constant(11, name: "f1"));
var json1 = JsonConvert.SerializeObject(graph._nodes_by_name, Formatting.Indented);
int result = z.eval(sess);
assertEquals(result, 1);
assertEquals(result, 11);
});
}


Loading…
Cancel
Save