From 444cc420580a66d004bb728db1498f5483b80be1 Mon Sep 17 00:00:00 2001 From: haiping008 Date: Fri, 25 Jan 2019 15:10:15 -0600 Subject: [PATCH] node_def property in Operation #134 --- .../Operations/Operation.cs | 24 ++++++-- src/TensorFlowNET.Core/Tensors/TF_DataType.cs | 2 + src/TensorFlowNET.Core/Tensors/Tensor.cs | 1 + .../Variables/RefVariable.cs | 59 +++++++++++++++++-- .../Variables/gen_state_ops.py.cs | 2 +- test/TensorFlowNET.UnitTest/GraphTest.cs | 10 ++-- 6 files changed, 81 insertions(+), 17 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index f1306fc0..7189747d 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -36,7 +36,7 @@ namespace Tensorflow var handle = Marshal.AllocHGlobal(size); int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers); var consumers = new TF_Input[num]; - for(int i = 0; i < num; i++) + for (int i = 0; i < num; i++) { consumers[i] = Marshal.PtrToStructure(handle + i * size); } @@ -50,7 +50,7 @@ namespace Tensorflow { var control_inputs = new Operation[NumControlInputs]; - if(NumControlInputs > 0) + if (NumControlInputs > 0) { IntPtr control_input_handle = Marshal.AllocHGlobal(Marshal.SizeOf() * NumControlInputs); c_api.TF_OperationGetControlInputs(_handle, control_input_handle, NumControlInputs); @@ -70,7 +70,7 @@ namespace Tensorflow { var control_outputs = new Operation[NumControlOutputs]; - if(NumControlOutputs > 0) + if (NumControlOutputs > 0) { IntPtr control_output_handle = Marshal.AllocHGlobal(Marshal.SizeOf() * NumControlOutputs); c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlInputs); @@ -89,7 +89,7 @@ namespace Tensorflow { get { - if(_outputs == null) + if (_outputs == null) { _outputs = new Tensor[NumOutputs]; @@ -106,7 +106,7 @@ namespace Tensorflow { get { - if(_inputs == null) + if (_inputs == null) { var retval = new Tensor[NumInputs]; @@ -124,6 +124,18 @@ namespace Tensorflow } } + private NodeDef _node_def; + public NodeDef node_def + { + get + { + if(_node_def == null) + _node_def = GetNodeDef(); + + return _node_def; + } + } + public Operation(IntPtr handle) { if (handle == IntPtr.Zero) @@ -195,7 +207,7 @@ namespace Tensorflow return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s); } - public NodeDef GetNodeDef() + private NodeDef GetNodeDef() { using (var s = new Status()) using (var buffer = new Buffer()) diff --git a/src/TensorFlowNET.Core/Tensors/TF_DataType.cs b/src/TensorFlowNET.Core/Tensors/TF_DataType.cs index b3e5b79b..308f98f1 100644 --- a/src/TensorFlowNET.Core/Tensors/TF_DataType.cs +++ b/src/TensorFlowNET.Core/Tensors/TF_DataType.cs @@ -36,6 +36,8 @@ namespace Tensorflow TF_UINT32 = 22, TF_UINT64 = 23, + DtFloatRef = 101, // DT_FLOAT_REF DtDoubleRef = 102, // DT_DOUBLE_REF + DtInt32Ref = 103, // DT_INT32_REF } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index cf72bb23..24542f4f 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -162,6 +162,7 @@ namespace Tensorflow this.op = op; this.value_index = value_index; this._dtype = dtype; + _id = ops.uid(); } public List consumers() diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index a864c2cd..30f49ebd 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; namespace Tensorflow @@ -69,14 +70,15 @@ namespace Tensorflow { } + // Or get the initial value from a Tensor or Python object. else { _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value"); - } - var shape = _initial_value.shape; - dtype = _initial_value.dtype; - _variable = gen_state_ops.variable_v2(shape, dtype, name); + var shape = _initial_value.shape; + dtype = _initial_value.dtype; + _variable = gen_state_ops.variable_v2(shape, dtype, name); + } // Manually overrides the variable's shape with the initial value's. if (validate_shape) @@ -87,8 +89,9 @@ namespace Tensorflow // If 'initial_value' makes use of other variables, make sure we don't // have an issue if these other variables aren't initialized first by // using their initialized_value() method. + var _initial_value2 = _try_guard_against_uninitialized_dependencies(_initial_value); - _initializer_op = gen_state_ops.assign(_variable, _initial_value, validate_shape).op; + _initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op; if (!String.IsNullOrEmpty(caching_device)) { @@ -112,5 +115,51 @@ namespace Tensorflow { return _snapshot; } + + /// + /// Attempt to guard against dependencies on uninitialized variables. + /// + /// + private Tensor _try_guard_against_uninitialized_dependencies(Tensor initial_value) + { + return _safe_initial_value_from_tensor(initial_value, new Dictionary()); + } + + /// + /// Replace dependencies on variables with their initialized values. + /// + /// A `Tensor`. The tensor to replace. + /// A dict mapping operation names to `Operation`s. + /// A `Tensor` compatible with `tensor`. + private Tensor _safe_initial_value_from_tensor(Tensor tensor, Dictionary op_cache) + { + var op = tensor.op; + var new_op = op_cache.ContainsKey(op.Name) ? op_cache[op.Name] : null; + if(new_op == null) + { + new_op = _safe_initial_value_from_op(op, op_cache); + op_cache[op.Name] = new_op; + } + return new_op.outputs[tensor.value_index]; + } + + private Operation _safe_initial_value_from_op(Operation op, Dictionary op_cache) + { + var op_type = op.node_def.Op; + switch (op_type) + { + case "IsVariableInitialized": + case "VarIsInitializedOp": + case "ReadVariableOp": + return op; + case "Variable": + case "VariableV2": + case "VarHandleOp": + break; + } + + // Recursively build initializer expressions for inputs. + return op; + } } } diff --git a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs index aa6170e3..35e0d1c5 100644 --- a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs +++ b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs @@ -42,7 +42,7 @@ namespace Tensorflow _execute.record_gradient("VariableV2", _inputs_flat, _attrs, _result, name); - return new Tensor(_op, 0, dtype); + return _result[0]; } /// diff --git a/test/TensorFlowNET.UnitTest/GraphTest.cs b/test/TensorFlowNET.UnitTest/GraphTest.cs index 57963660..652378bf 100644 --- a/test/TensorFlowNET.UnitTest/GraphTest.cs +++ b/test/TensorFlowNET.UnitTest/GraphTest.cs @@ -130,7 +130,7 @@ namespace TensorFlowNET.UnitTest EXPECT_EQ(TF_Code.TF_OK, s.Code); // Serialize to NodeDef. - var node_def = neg.GetNodeDef(); + var node_def = neg.node_def; // Validate NodeDef is what we expect. ASSERT_TRUE(c_test_util.IsNeg(node_def, "add")); @@ -145,13 +145,13 @@ namespace TensorFlowNET.UnitTest // Look up some nodes by name. Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg"); EXPECT_EQ(neg, neg2); - var node_def2 = neg2.GetNodeDef(); + var node_def2 = neg2.node_def; EXPECT_EQ(node_def.ToString(), node_def2.ToString()); Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed"); EXPECT_EQ(feed, feed2); - node_def = feed.GetNodeDef(); - node_def2 = feed2.GetNodeDef(); + node_def = feed.node_def; + node_def2 = feed2.node_def; EXPECT_EQ(node_def.ToString(), node_def2.ToString()); // Test iterating through the nodes of a graph. @@ -186,7 +186,7 @@ namespace TensorFlowNET.UnitTest } else { - node_def = oper.GetNodeDef(); + node_def = oper.node_def; Assert.Fail($"Unexpected Node: {node_def.ToString()}"); } }