From 1989988290d01306d73e84f0e2d64d8225deb7c4 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Mon, 21 Jan 2019 12:23:11 -0600 Subject: [PATCH] fix Operation.inputs value is null. #117 --- src/TensorFlowNET.Core/Graphs/Graph.cs | 18 +++++++- .../Operations/Operation.cs | 44 +++++++++++++++---- .../Variables/gen_state_ops.py.cs | 4 +- 3 files changed, 55 insertions(+), 11 deletions(-) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index cb5784ce..f8ef8900 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -118,11 +118,17 @@ namespace Tensorflow name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name); var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); + if (inputs == null) + inputs = new List(); + + var input_ops = inputs.Select(x => x.op).ToArray(); + var control_inputs = _control_dependencies_for_inputs(input_ops); + var op = new Operation(node_def, this, inputs: inputs, output_types: dtypes, - control_inputs: new object[] { }, + control_inputs: control_inputs, input_types: input_types, original_op: null, op_def: op_def); @@ -131,6 +137,16 @@ namespace Tensorflow return op; } + /// + /// For an op that takes `input_ops` as inputs, compute control inputs. + /// + /// The data input ops for an op to be created. + /// A list of control inputs for the op to be created. + private Operation[] _control_dependencies_for_inputs(Operation[] input_ops) + { + return new Operation[0]; + } + private void _create_op_helper(Operation op, bool compute_device = true) { diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index a4fd5838..00ffb76f 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -85,8 +85,42 @@ namespace Tensorflow } private Tensor[] _outputs; - public Tensor[] outputs => _outputs; - public Tensor[] inputs; + public Tensor[] outputs + { + get + { + if(_outputs == null) + { + _outputs = new Tensor[NumOutputs]; + + for (int i = 0; i < NumOutputs; i++) + _outputs[i] = new Tensor(this, i, OutputType(i)); + } + + return _outputs; + } + } + + private Tensor[] _inputs; + public Tensor[] inputs + { + get + { + if(_inputs == null) + { + _inputs = new Tensor[NumInputs]; + + for (int i = 0; i < NumInputs; i++) + { + var tf_outpus = Input(i); + var op = new Operation(tf_outpus.oper); + _inputs[i] = op.outputs[tf_outpus.index]; + } + } + + return _inputs; + } + } public Operation(IntPtr handle) { @@ -115,14 +149,10 @@ namespace Tensorflow _handle = ops._create_c_op(g, node_def, inputs); - _outputs = new Tensor[NumOutputs]; output_types = new TF_DataType[NumOutputs]; for (int i = 0; i < NumOutputs; i++) - { output_types[i] = OutputType(i); - _outputs[i] = new Tensor(this, i, output_types[i]); - } Graph._add_op(this); } @@ -131,8 +161,6 @@ namespace Tensorflow { AttrValue x = null; - var fields = new string[] { "s", "i", "f", "b", "Type", "Shape", "Tensor", "func" }; - using (var buf = new Buffer()) { c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); diff --git a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs index a36edec6..aa6170e3 100644 --- a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs +++ b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs @@ -40,7 +40,7 @@ namespace Tensorflow _attrs["container"] = _op.get_attr("container"); _attrs["shared_name"] = _op.get_attr("shared_name"); - _execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name); + _execute.record_gradient("VariableV2", _inputs_flat, _attrs, _result, name); return new Tensor(_op, 0, dtype); } @@ -74,7 +74,7 @@ namespace Tensorflow _attrs["validate_shape"] = _op.get_attr("validate_shape"); _attrs["use_locking"] = _op.get_attr("use_locking"); - _execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name); + _execute.record_gradient("Assign", _inputs_flat, _attrs, _result, name); return _result[0]; }