| @@ -36,7 +36,7 @@ namespace Tensorflow | |||
| var handle = Marshal.AllocHGlobal(size); | |||
| int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers); | |||
| var consumers = new TF_Input[num]; | |||
| for(int i = 0; i < num; i++) | |||
| for (int i = 0; i < num; i++) | |||
| { | |||
| consumers[i] = Marshal.PtrToStructure<TF_Input>(handle + i * size); | |||
| } | |||
| @@ -50,7 +50,7 @@ namespace Tensorflow | |||
| { | |||
| var control_inputs = new Operation[NumControlInputs]; | |||
| if(NumControlInputs > 0) | |||
| if (NumControlInputs > 0) | |||
| { | |||
| IntPtr control_input_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlInputs); | |||
| c_api.TF_OperationGetControlInputs(_handle, control_input_handle, NumControlInputs); | |||
| @@ -70,7 +70,7 @@ namespace Tensorflow | |||
| { | |||
| var control_outputs = new Operation[NumControlOutputs]; | |||
| if(NumControlOutputs > 0) | |||
| if (NumControlOutputs > 0) | |||
| { | |||
| IntPtr control_output_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlOutputs); | |||
| c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlInputs); | |||
| @@ -89,7 +89,7 @@ namespace Tensorflow | |||
| { | |||
| get | |||
| { | |||
| if(_outputs == null) | |||
| if (_outputs == null) | |||
| { | |||
| _outputs = new Tensor[NumOutputs]; | |||
| @@ -106,7 +106,7 @@ namespace Tensorflow | |||
| { | |||
| get | |||
| { | |||
| if(_inputs == null) | |||
| if (_inputs == null) | |||
| { | |||
| var retval = new Tensor[NumInputs]; | |||
| @@ -124,6 +124,18 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| private NodeDef _node_def; | |||
| public NodeDef node_def | |||
| { | |||
| get | |||
| { | |||
| if(_node_def == null) | |||
| _node_def = GetNodeDef(); | |||
| return _node_def; | |||
| } | |||
| } | |||
| public Operation(IntPtr handle) | |||
| { | |||
| if (handle == IntPtr.Zero) | |||
| @@ -195,7 +207,7 @@ namespace Tensorflow | |||
| return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s); | |||
| } | |||
| public NodeDef GetNodeDef() | |||
| private NodeDef GetNodeDef() | |||
| { | |||
| using (var s = new Status()) | |||
| using (var buffer = new Buffer()) | |||
| @@ -36,6 +36,8 @@ namespace Tensorflow | |||
| TF_UINT32 = 22, | |||
| TF_UINT64 = 23, | |||
| DtFloatRef = 101, // DT_FLOAT_REF | |||
| DtDoubleRef = 102, // DT_DOUBLE_REF | |||
| DtInt32Ref = 103, // DT_INT32_REF | |||
| } | |||
| } | |||
| @@ -162,6 +162,7 @@ namespace Tensorflow | |||
| this.op = op; | |||
| this.value_index = value_index; | |||
| this._dtype = dtype; | |||
| _id = ops.uid(); | |||
| } | |||
| public List<Operation> consumers() | |||
| @@ -1,5 +1,6 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| @@ -69,14 +70,15 @@ namespace Tensorflow | |||
| { | |||
| } | |||
| // Or get the initial value from a Tensor or Python object. | |||
| else | |||
| { | |||
| _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value"); | |||
| } | |||
| var shape = _initial_value.shape; | |||
| dtype = _initial_value.dtype; | |||
| _variable = gen_state_ops.variable_v2(shape, dtype, name); | |||
| var shape = _initial_value.shape; | |||
| dtype = _initial_value.dtype; | |||
| _variable = gen_state_ops.variable_v2(shape, dtype, name); | |||
| } | |||
| // Manually overrides the variable's shape with the initial value's. | |||
| if (validate_shape) | |||
| @@ -87,8 +89,9 @@ namespace Tensorflow | |||
| // If 'initial_value' makes use of other variables, make sure we don't | |||
| // have an issue if these other variables aren't initialized first by | |||
| // using their initialized_value() method. | |||
| var _initial_value2 = _try_guard_against_uninitialized_dependencies(_initial_value); | |||
| _initializer_op = gen_state_ops.assign(_variable, _initial_value, validate_shape).op; | |||
| _initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op; | |||
| if (!String.IsNullOrEmpty(caching_device)) | |||
| { | |||
| @@ -112,5 +115,51 @@ namespace Tensorflow | |||
| { | |||
| return _snapshot; | |||
| } | |||
| /// <summary> | |||
| /// Attempt to guard against dependencies on uninitialized variables. | |||
| /// </summary> | |||
| /// <param name="initial_value"></param> | |||
| private Tensor _try_guard_against_uninitialized_dependencies(Tensor initial_value) | |||
| { | |||
| return _safe_initial_value_from_tensor(initial_value, new Dictionary<string, Operation>()); | |||
| } | |||
| /// <summary> | |||
| /// Replace dependencies on variables with their initialized values. | |||
| /// </summary> | |||
| /// <param name="tensor">A `Tensor`. The tensor to replace.</param> | |||
| /// <param name="op_cache">A dict mapping operation names to `Operation`s.</param> | |||
| /// <returns>A `Tensor` compatible with `tensor`.</returns> | |||
| private Tensor _safe_initial_value_from_tensor(Tensor tensor, Dictionary<string, Operation> op_cache) | |||
| { | |||
| var op = tensor.op; | |||
| var new_op = op_cache.ContainsKey(op.Name) ? op_cache[op.Name] : null; | |||
| if(new_op == null) | |||
| { | |||
| new_op = _safe_initial_value_from_op(op, op_cache); | |||
| op_cache[op.Name] = new_op; | |||
| } | |||
| return new_op.outputs[tensor.value_index]; | |||
| } | |||
| private Operation _safe_initial_value_from_op(Operation op, Dictionary<string, Operation> op_cache) | |||
| { | |||
| var op_type = op.node_def.Op; | |||
| switch (op_type) | |||
| { | |||
| case "IsVariableInitialized": | |||
| case "VarIsInitializedOp": | |||
| case "ReadVariableOp": | |||
| return op; | |||
| case "Variable": | |||
| case "VariableV2": | |||
| case "VarHandleOp": | |||
| break; | |||
| } | |||
| // Recursively build initializer expressions for inputs. | |||
| return op; | |||
| } | |||
| } | |||
| } | |||
| @@ -42,7 +42,7 @@ namespace Tensorflow | |||
| _execute.record_gradient("VariableV2", _inputs_flat, _attrs, _result, name); | |||
| return new Tensor(_op, 0, dtype); | |||
| return _result[0]; | |||
| } | |||
| /// <summary> | |||
| @@ -130,7 +130,7 @@ namespace TensorFlowNET.UnitTest | |||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
| // Serialize to NodeDef. | |||
| var node_def = neg.GetNodeDef(); | |||
| var node_def = neg.node_def; | |||
| // Validate NodeDef is what we expect. | |||
| ASSERT_TRUE(c_test_util.IsNeg(node_def, "add")); | |||
| @@ -145,13 +145,13 @@ namespace TensorFlowNET.UnitTest | |||
| // Look up some nodes by name. | |||
| Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg"); | |||
| EXPECT_EQ(neg, neg2); | |||
| var node_def2 = neg2.GetNodeDef(); | |||
| var node_def2 = neg2.node_def; | |||
| EXPECT_EQ(node_def.ToString(), node_def2.ToString()); | |||
| Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed"); | |||
| EXPECT_EQ(feed, feed2); | |||
| node_def = feed.GetNodeDef(); | |||
| node_def2 = feed2.GetNodeDef(); | |||
| node_def = feed.node_def; | |||
| node_def2 = feed2.node_def; | |||
| EXPECT_EQ(node_def.ToString(), node_def2.ToString()); | |||
| // Test iterating through the nodes of a graph. | |||
| @@ -186,7 +186,7 @@ namespace TensorFlowNET.UnitTest | |||
| } | |||
| else | |||
| { | |||
| node_def = oper.GetNodeDef(); | |||
| node_def = oper.node_def; | |||
| Assert.Fail($"Unexpected Node: {node_def.ToString()}"); | |||
| } | |||
| } | |||