Browse Source

node_def property in Operation #134

tags/v0.1.0-Tensor
haiping008 6 years ago
parent
commit
444cc42058
6 changed files with 81 additions and 17 deletions
  1. +18
    -6
      src/TensorFlowNET.Core/Operations/Operation.cs
  2. +2
    -0
      src/TensorFlowNET.Core/Tensors/TF_DataType.cs
  3. +1
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  4. +54
    -5
      src/TensorFlowNET.Core/Variables/RefVariable.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs
  6. +5
    -5
      test/TensorFlowNET.UnitTest/GraphTest.cs

+ 18
- 6
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -36,7 +36,7 @@ namespace Tensorflow
var handle = Marshal.AllocHGlobal(size); var handle = Marshal.AllocHGlobal(size);
int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers); int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers);
var consumers = new TF_Input[num]; var consumers = new TF_Input[num];
for(int i = 0; i < num; i++)
for (int i = 0; i < num; i++)
{ {
consumers[i] = Marshal.PtrToStructure<TF_Input>(handle + i * size); consumers[i] = Marshal.PtrToStructure<TF_Input>(handle + i * size);
} }
@@ -50,7 +50,7 @@ namespace Tensorflow
{ {
var control_inputs = new Operation[NumControlInputs]; var control_inputs = new Operation[NumControlInputs];


if(NumControlInputs > 0)
if (NumControlInputs > 0)
{ {
IntPtr control_input_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlInputs); IntPtr control_input_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlInputs);
c_api.TF_OperationGetControlInputs(_handle, control_input_handle, NumControlInputs); c_api.TF_OperationGetControlInputs(_handle, control_input_handle, NumControlInputs);
@@ -70,7 +70,7 @@ namespace Tensorflow
{ {
var control_outputs = new Operation[NumControlOutputs]; var control_outputs = new Operation[NumControlOutputs];


if(NumControlOutputs > 0)
if (NumControlOutputs > 0)
{ {
IntPtr control_output_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlOutputs); IntPtr control_output_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlOutputs);
c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlInputs); c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlInputs);
@@ -89,7 +89,7 @@ namespace Tensorflow
{ {
get get
{ {
if(_outputs == null)
if (_outputs == null)
{ {
_outputs = new Tensor[NumOutputs]; _outputs = new Tensor[NumOutputs];


@@ -106,7 +106,7 @@ namespace Tensorflow
{ {
get get
{ {
if(_inputs == null)
if (_inputs == null)
{ {
var retval = new Tensor[NumInputs]; var retval = new Tensor[NumInputs];


@@ -124,6 +124,18 @@ namespace Tensorflow
} }
} }


private NodeDef _node_def;
public NodeDef node_def
{
get
{
if(_node_def == null)
_node_def = GetNodeDef();

return _node_def;
}
}

public Operation(IntPtr handle) public Operation(IntPtr handle)
{ {
if (handle == IntPtr.Zero) if (handle == IntPtr.Zero)
@@ -195,7 +207,7 @@ namespace Tensorflow
return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s); return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s);
} }


public NodeDef GetNodeDef()
private NodeDef GetNodeDef()
{ {
using (var s = new Status()) using (var s = new Status())
using (var buffer = new Buffer()) using (var buffer = new Buffer())


+ 2
- 0
src/TensorFlowNET.Core/Tensors/TF_DataType.cs View File

@@ -36,6 +36,8 @@ namespace Tensorflow
TF_UINT32 = 22, TF_UINT32 = 22,
TF_UINT64 = 23, TF_UINT64 = 23,


DtFloatRef = 101, // DT_FLOAT_REF
DtDoubleRef = 102, // DT_DOUBLE_REF DtDoubleRef = 102, // DT_DOUBLE_REF
DtInt32Ref = 103, // DT_INT32_REF
} }
} }

+ 1
- 0
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -162,6 +162,7 @@ namespace Tensorflow
this.op = op; this.op = op;
this.value_index = value_index; this.value_index = value_index;
this._dtype = dtype; this._dtype = dtype;
_id = ops.uid();
} }


public List<Operation> consumers() public List<Operation> consumers()


+ 54
- 5
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -1,5 +1,6 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;
using System.Text; using System.Text;


namespace Tensorflow namespace Tensorflow
@@ -69,14 +70,15 @@ namespace Tensorflow
{ {


} }
// Or get the initial value from a Tensor or Python object.
else else
{ {
_initial_value = ops.convert_to_tensor(initial_value, name: "initial_value"); _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value");
}


var shape = _initial_value.shape;
dtype = _initial_value.dtype;
_variable = gen_state_ops.variable_v2(shape, dtype, name);
var shape = _initial_value.shape;
dtype = _initial_value.dtype;
_variable = gen_state_ops.variable_v2(shape, dtype, name);
}


// Manually overrides the variable's shape with the initial value's. // Manually overrides the variable's shape with the initial value's.
if (validate_shape) if (validate_shape)
@@ -87,8 +89,9 @@ namespace Tensorflow
// If 'initial_value' makes use of other variables, make sure we don't // If 'initial_value' makes use of other variables, make sure we don't
// have an issue if these other variables aren't initialized first by // have an issue if these other variables aren't initialized first by
// using their initialized_value() method. // using their initialized_value() method.
var _initial_value2 = _try_guard_against_uninitialized_dependencies(_initial_value);


_initializer_op = gen_state_ops.assign(_variable, _initial_value, validate_shape).op;
_initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op;


if (!String.IsNullOrEmpty(caching_device)) if (!String.IsNullOrEmpty(caching_device))
{ {
@@ -112,5 +115,51 @@ namespace Tensorflow
{ {
return _snapshot; return _snapshot;
} }

/// <summary>
/// Attempt to guard against dependencies on uninitialized variables.
/// </summary>
/// <param name="initial_value"></param>
private Tensor _try_guard_against_uninitialized_dependencies(Tensor initial_value)
{
return _safe_initial_value_from_tensor(initial_value, new Dictionary<string, Operation>());
}

/// <summary>
/// Replace dependencies on variables with their initialized values.
/// </summary>
/// <param name="tensor">A `Tensor`. The tensor to replace.</param>
/// <param name="op_cache">A dict mapping operation names to `Operation`s.</param>
/// <returns>A `Tensor` compatible with `tensor`.</returns>
private Tensor _safe_initial_value_from_tensor(Tensor tensor, Dictionary<string, Operation> op_cache)
{
var op = tensor.op;
var new_op = op_cache.ContainsKey(op.Name) ? op_cache[op.Name] : null;
if(new_op == null)
{
new_op = _safe_initial_value_from_op(op, op_cache);
op_cache[op.Name] = new_op;
}
return new_op.outputs[tensor.value_index];
}

private Operation _safe_initial_value_from_op(Operation op, Dictionary<string, Operation> op_cache)
{
var op_type = op.node_def.Op;
switch (op_type)
{
case "IsVariableInitialized":
case "VarIsInitializedOp":
case "ReadVariableOp":
return op;
case "Variable":
case "VariableV2":
case "VarHandleOp":
break;
}

// Recursively build initializer expressions for inputs.
return op;
}
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs View File

@@ -42,7 +42,7 @@ namespace Tensorflow


_execute.record_gradient("VariableV2", _inputs_flat, _attrs, _result, name); _execute.record_gradient("VariableV2", _inputs_flat, _attrs, _result, name);


return new Tensor(_op, 0, dtype);
return _result[0];
} }


/// <summary> /// <summary>


+ 5
- 5
test/TensorFlowNET.UnitTest/GraphTest.cs View File

@@ -130,7 +130,7 @@ namespace TensorFlowNET.UnitTest
EXPECT_EQ(TF_Code.TF_OK, s.Code); EXPECT_EQ(TF_Code.TF_OK, s.Code);


// Serialize to NodeDef. // Serialize to NodeDef.
var node_def = neg.GetNodeDef();
var node_def = neg.node_def;


// Validate NodeDef is what we expect. // Validate NodeDef is what we expect.
ASSERT_TRUE(c_test_util.IsNeg(node_def, "add")); ASSERT_TRUE(c_test_util.IsNeg(node_def, "add"));
@@ -145,13 +145,13 @@ namespace TensorFlowNET.UnitTest
// Look up some nodes by name. // Look up some nodes by name.
Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg"); Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg");
EXPECT_EQ(neg, neg2); EXPECT_EQ(neg, neg2);
var node_def2 = neg2.GetNodeDef();
var node_def2 = neg2.node_def;
EXPECT_EQ(node_def.ToString(), node_def2.ToString()); EXPECT_EQ(node_def.ToString(), node_def2.ToString());


Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed"); Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed");
EXPECT_EQ(feed, feed2); EXPECT_EQ(feed, feed2);
node_def = feed.GetNodeDef();
node_def2 = feed2.GetNodeDef();
node_def = feed.node_def;
node_def2 = feed2.node_def;
EXPECT_EQ(node_def.ToString(), node_def2.ToString()); EXPECT_EQ(node_def.ToString(), node_def2.ToString());


// Test iterating through the nodes of a graph. // Test iterating through the nodes of a graph.
@@ -186,7 +186,7 @@ namespace TensorFlowNET.UnitTest
} }
else else
{ {
node_def = oper.GetNodeDef();
node_def = oper.node_def;
Assert.Fail($"Unexpected Node: {node_def.ToString()}"); Assert.Fail($"Unexpected Node: {node_def.ToString()}");
} }
} }


Loading…
Cancel
Save