| @@ -36,7 +36,7 @@ namespace Tensorflow | |||||
| var handle = Marshal.AllocHGlobal(size); | var handle = Marshal.AllocHGlobal(size); | ||||
| int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers); | int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers); | ||||
| var consumers = new TF_Input[num]; | var consumers = new TF_Input[num]; | ||||
| for(int i = 0; i < num; i++) | |||||
| for (int i = 0; i < num; i++) | |||||
| { | { | ||||
| consumers[i] = Marshal.PtrToStructure<TF_Input>(handle + i * size); | consumers[i] = Marshal.PtrToStructure<TF_Input>(handle + i * size); | ||||
| } | } | ||||
| @@ -50,7 +50,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| var control_inputs = new Operation[NumControlInputs]; | var control_inputs = new Operation[NumControlInputs]; | ||||
| if(NumControlInputs > 0) | |||||
| if (NumControlInputs > 0) | |||||
| { | { | ||||
| IntPtr control_input_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlInputs); | IntPtr control_input_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlInputs); | ||||
| c_api.TF_OperationGetControlInputs(_handle, control_input_handle, NumControlInputs); | c_api.TF_OperationGetControlInputs(_handle, control_input_handle, NumControlInputs); | ||||
| @@ -70,7 +70,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| var control_outputs = new Operation[NumControlOutputs]; | var control_outputs = new Operation[NumControlOutputs]; | ||||
| if(NumControlOutputs > 0) | |||||
| if (NumControlOutputs > 0) | |||||
| { | { | ||||
| IntPtr control_output_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlOutputs); | IntPtr control_output_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlOutputs); | ||||
| c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlInputs); | c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlInputs); | ||||
| @@ -89,7 +89,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| get | get | ||||
| { | { | ||||
| if(_outputs == null) | |||||
| if (_outputs == null) | |||||
| { | { | ||||
| _outputs = new Tensor[NumOutputs]; | _outputs = new Tensor[NumOutputs]; | ||||
| @@ -106,7 +106,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| get | get | ||||
| { | { | ||||
| if(_inputs == null) | |||||
| if (_inputs == null) | |||||
| { | { | ||||
| var retval = new Tensor[NumInputs]; | var retval = new Tensor[NumInputs]; | ||||
| @@ -124,6 +124,18 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| private NodeDef _node_def; | |||||
| public NodeDef node_def | |||||
| { | |||||
| get | |||||
| { | |||||
| if(_node_def == null) | |||||
| _node_def = GetNodeDef(); | |||||
| return _node_def; | |||||
| } | |||||
| } | |||||
| public Operation(IntPtr handle) | public Operation(IntPtr handle) | ||||
| { | { | ||||
| if (handle == IntPtr.Zero) | if (handle == IntPtr.Zero) | ||||
| @@ -195,7 +207,7 @@ namespace Tensorflow | |||||
| return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s); | return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s); | ||||
| } | } | ||||
| public NodeDef GetNodeDef() | |||||
| private NodeDef GetNodeDef() | |||||
| { | { | ||||
| using (var s = new Status()) | using (var s = new Status()) | ||||
| using (var buffer = new Buffer()) | using (var buffer = new Buffer()) | ||||
| @@ -36,6 +36,8 @@ namespace Tensorflow | |||||
| TF_UINT32 = 22, | TF_UINT32 = 22, | ||||
| TF_UINT64 = 23, | TF_UINT64 = 23, | ||||
| DtFloatRef = 101, // DT_FLOAT_REF | |||||
| DtDoubleRef = 102, // DT_DOUBLE_REF | DtDoubleRef = 102, // DT_DOUBLE_REF | ||||
| DtInt32Ref = 103, // DT_INT32_REF | |||||
| } | } | ||||
| } | } | ||||
| @@ -162,6 +162,7 @@ namespace Tensorflow | |||||
| this.op = op; | this.op = op; | ||||
| this.value_index = value_index; | this.value_index = value_index; | ||||
| this._dtype = dtype; | this._dtype = dtype; | ||||
| _id = ops.uid(); | |||||
| } | } | ||||
| public List<Operation> consumers() | public List<Operation> consumers() | ||||
| @@ -1,5 +1,6 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | |||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -69,14 +70,15 @@ namespace Tensorflow | |||||
| { | { | ||||
| } | } | ||||
| // Or get the initial value from a Tensor or Python object. | |||||
| else | else | ||||
| { | { | ||||
| _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value"); | _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value"); | ||||
| } | |||||
| var shape = _initial_value.shape; | |||||
| dtype = _initial_value.dtype; | |||||
| _variable = gen_state_ops.variable_v2(shape, dtype, name); | |||||
| var shape = _initial_value.shape; | |||||
| dtype = _initial_value.dtype; | |||||
| _variable = gen_state_ops.variable_v2(shape, dtype, name); | |||||
| } | |||||
| // Manually overrides the variable's shape with the initial value's. | // Manually overrides the variable's shape with the initial value's. | ||||
| if (validate_shape) | if (validate_shape) | ||||
| @@ -87,8 +89,9 @@ namespace Tensorflow | |||||
| // If 'initial_value' makes use of other variables, make sure we don't | // If 'initial_value' makes use of other variables, make sure we don't | ||||
| // have an issue if these other variables aren't initialized first by | // have an issue if these other variables aren't initialized first by | ||||
| // using their initialized_value() method. | // using their initialized_value() method. | ||||
| var _initial_value2 = _try_guard_against_uninitialized_dependencies(_initial_value); | |||||
| _initializer_op = gen_state_ops.assign(_variable, _initial_value, validate_shape).op; | |||||
| _initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op; | |||||
| if (!String.IsNullOrEmpty(caching_device)) | if (!String.IsNullOrEmpty(caching_device)) | ||||
| { | { | ||||
| @@ -112,5 +115,51 @@ namespace Tensorflow | |||||
| { | { | ||||
| return _snapshot; | return _snapshot; | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Attempt to guard against dependencies on uninitialized variables. | |||||
| /// </summary> | |||||
| /// <param name="initial_value"></param> | |||||
| private Tensor _try_guard_against_uninitialized_dependencies(Tensor initial_value) | |||||
| { | |||||
| return _safe_initial_value_from_tensor(initial_value, new Dictionary<string, Operation>()); | |||||
| } | |||||
| /// <summary> | |||||
| /// Replace dependencies on variables with their initialized values. | |||||
| /// </summary> | |||||
| /// <param name="tensor">A `Tensor`. The tensor to replace.</param> | |||||
| /// <param name="op_cache">A dict mapping operation names to `Operation`s.</param> | |||||
| /// <returns>A `Tensor` compatible with `tensor`.</returns> | |||||
| private Tensor _safe_initial_value_from_tensor(Tensor tensor, Dictionary<string, Operation> op_cache) | |||||
| { | |||||
| var op = tensor.op; | |||||
| var new_op = op_cache.ContainsKey(op.Name) ? op_cache[op.Name] : null; | |||||
| if(new_op == null) | |||||
| { | |||||
| new_op = _safe_initial_value_from_op(op, op_cache); | |||||
| op_cache[op.Name] = new_op; | |||||
| } | |||||
| return new_op.outputs[tensor.value_index]; | |||||
| } | |||||
| private Operation _safe_initial_value_from_op(Operation op, Dictionary<string, Operation> op_cache) | |||||
| { | |||||
| var op_type = op.node_def.Op; | |||||
| switch (op_type) | |||||
| { | |||||
| case "IsVariableInitialized": | |||||
| case "VarIsInitializedOp": | |||||
| case "ReadVariableOp": | |||||
| return op; | |||||
| case "Variable": | |||||
| case "VariableV2": | |||||
| case "VarHandleOp": | |||||
| break; | |||||
| } | |||||
| // Recursively build initializer expressions for inputs. | |||||
| return op; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -42,7 +42,7 @@ namespace Tensorflow | |||||
| _execute.record_gradient("VariableV2", _inputs_flat, _attrs, _result, name); | _execute.record_gradient("VariableV2", _inputs_flat, _attrs, _result, name); | ||||
| return new Tensor(_op, 0, dtype); | |||||
| return _result[0]; | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -130,7 +130,7 @@ namespace TensorFlowNET.UnitTest | |||||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | EXPECT_EQ(TF_Code.TF_OK, s.Code); | ||||
| // Serialize to NodeDef. | // Serialize to NodeDef. | ||||
| var node_def = neg.GetNodeDef(); | |||||
| var node_def = neg.node_def; | |||||
| // Validate NodeDef is what we expect. | // Validate NodeDef is what we expect. | ||||
| ASSERT_TRUE(c_test_util.IsNeg(node_def, "add")); | ASSERT_TRUE(c_test_util.IsNeg(node_def, "add")); | ||||
| @@ -145,13 +145,13 @@ namespace TensorFlowNET.UnitTest | |||||
| // Look up some nodes by name. | // Look up some nodes by name. | ||||
| Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg"); | Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg"); | ||||
| EXPECT_EQ(neg, neg2); | EXPECT_EQ(neg, neg2); | ||||
| var node_def2 = neg2.GetNodeDef(); | |||||
| var node_def2 = neg2.node_def; | |||||
| EXPECT_EQ(node_def.ToString(), node_def2.ToString()); | EXPECT_EQ(node_def.ToString(), node_def2.ToString()); | ||||
| Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed"); | Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed"); | ||||
| EXPECT_EQ(feed, feed2); | EXPECT_EQ(feed, feed2); | ||||
| node_def = feed.GetNodeDef(); | |||||
| node_def2 = feed2.GetNodeDef(); | |||||
| node_def = feed.node_def; | |||||
| node_def2 = feed2.node_def; | |||||
| EXPECT_EQ(node_def.ToString(), node_def2.ToString()); | EXPECT_EQ(node_def.ToString(), node_def2.ToString()); | ||||
| // Test iterating through the nodes of a graph. | // Test iterating through the nodes of a graph. | ||||
| @@ -186,7 +186,7 @@ namespace TensorFlowNET.UnitTest | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| node_def = oper.GetNodeDef(); | |||||
| node_def = oper.node_def; | |||||
| Assert.Fail($"Unexpected Node: {node_def.ToString()}"); | Assert.Fail($"Unexpected Node: {node_def.ToString()}"); | ||||
| } | } | ||||
| } | } | ||||