diff --git a/src/TensorFlowNET.Core/Operations/Operation.Input.cs b/src/TensorFlowNET.Core/Operations/Operation.Input.cs new file mode 100644 index 00000000..55987262 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Operation.Input.cs @@ -0,0 +1,60 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; + +namespace Tensorflow +{ + public partial class Operation + { + public TF_Output Input(int index) => c_api.TF_OperationInput(new TF_Input(_handle, index)); + public TF_DataType InputType(int index) => c_api.TF_OperationInputType(new TF_Input(_handle, index)); + public int InputListLength(string name) => c_api.TF_OperationInputListLength(_handle, name, status); + public int NumInputs => c_api.TF_OperationNumInputs(_handle); + private TF_DataType[] _input_types => _inputs._inputs.Select(x => x.dtype).ToArray(); + + private InputList _inputs; + public InputList inputs + { + get + { + if (_inputs == null) + { + var retval = new Tensor[NumInputs]; + + for (int i = 0; i < NumInputs; i++) + { + var tf_outpus = Input(i); + var op = new Operation(tf_outpus.oper); + retval[i] = op.outputs[tf_outpus.index]; + } + + _inputs = new InputList(retval); + } + + return _inputs; + } + } + + public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); + + public unsafe Operation[] GetControlInputs() + { + var control_inputs = new Operation[NumControlInputs]; + + if (NumControlInputs > 0) + { + IntPtr control_input_handle = Marshal.AllocHGlobal(Marshal.SizeOf() * NumControlInputs); + c_api.TF_OperationGetControlInputs(_handle, control_input_handle, NumControlInputs); + for (int i = 0; i < NumControlInputs; i++) + { + var handle = control_input_handle + Marshal.SizeOf() * i; + control_inputs[i] = new Operation(*(IntPtr*)handle); + } + } + + return control_inputs; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Operation.Output.cs b/src/TensorFlowNET.Core/Operations/Operation.Output.cs new file mode 100644 index 00000000..5fdc6d47 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Operation.Output.cs @@ -0,0 +1,66 @@ +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Text; + +namespace Tensorflow +{ + public partial class Operation + { + public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); + public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(new TF_Output(_handle, index)); + public int OutputListLength(string name) => c_api.TF_OperationOutputListLength(_handle, name, status); + + private Tensor[] _outputs; + public Tensor[] outputs + { + get + { + if (_outputs == null) + { + _outputs = new Tensor[NumOutputs]; + + for (int i = 0; i < NumOutputs; i++) + _outputs[i] = new Tensor(this, i, OutputType(i)); + } + + return _outputs; + } + } + + public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); + public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index)); + + public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) + { + int size = Marshal.SizeOf(); + var handle = Marshal.AllocHGlobal(size); + int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers); + var consumers = new TF_Input[num]; + for (int i = 0; i < num; i++) + { + consumers[i] = Marshal.PtrToStructure(handle + i * size); + } + + return consumers; + } + + public unsafe Operation[] GetControlOutputs() + { + var control_outputs = new Operation[NumControlOutputs]; + + if (NumControlOutputs > 0) + { + IntPtr control_output_handle = Marshal.AllocHGlobal(Marshal.SizeOf() * NumControlOutputs); + c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlInputs); + for (int i = 0; i < NumControlInputs; i++) + { + var handle = control_output_handle + Marshal.SizeOf() * i; + control_outputs[i] = new Operation(*(IntPtr*)handle); + } + } + + return control_outputs; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 352baf93..ef5d9813 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -6,7 +6,7 @@ using System.Text; namespace Tensorflow { - public class Operation + public partial class Operation { private readonly IntPtr _handle; @@ -20,112 +20,6 @@ namespace Tensorflow public string OpType => c_api.StringPiece(c_api.TF_OperationOpType(_handle)); public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle)); - public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); - public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(new TF_Output(_handle, index)); - public int OutputListLength(string name) => c_api.TF_OperationOutputListLength(_handle, name, status); - - public TF_Output Input(int index) => c_api.TF_OperationInput(new TF_Input(_handle, index)); - public TF_DataType InputType(int index) => c_api.TF_OperationInputType(new TF_Input(_handle, index)); - public int InputListLength(string name) => c_api.TF_OperationInputListLength(_handle, name, status); - public int NumInputs => c_api.TF_OperationNumInputs(_handle); - - public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index)); - public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) - { - int size = Marshal.SizeOf(); - var handle = Marshal.AllocHGlobal(size); - int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers); - var consumers = new TF_Input[num]; - for (int i = 0; i < num; i++) - { - consumers[i] = Marshal.PtrToStructure(handle + i * size); - } - - return consumers; - } - - public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); - - public unsafe Operation[] GetControlInputs() - { - var control_inputs = new Operation[NumControlInputs]; - - if (NumControlInputs > 0) - { - IntPtr control_input_handle = Marshal.AllocHGlobal(Marshal.SizeOf() * NumControlInputs); - c_api.TF_OperationGetControlInputs(_handle, control_input_handle, NumControlInputs); - for (int i = 0; i < NumControlInputs; i++) - { - var handle = control_input_handle + Marshal.SizeOf() * i; - control_inputs[i] = new Operation(*(IntPtr*)handle); - } - } - - return control_inputs; - } - - public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); - - public unsafe Operation[] GetControlOutputs() - { - var control_outputs = new Operation[NumControlOutputs]; - - if (NumControlOutputs > 0) - { - IntPtr control_output_handle = Marshal.AllocHGlobal(Marshal.SizeOf() * NumControlOutputs); - c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlInputs); - for (int i = 0; i < NumControlInputs; i++) - { - var handle = control_output_handle + Marshal.SizeOf() * i; - control_outputs[i] = new Operation(*(IntPtr*)handle); - } - } - - return control_outputs; - } - - private Tensor[] _outputs; - public Tensor[] outputs - { - get - { - if (_outputs == null) - { - _outputs = new Tensor[NumOutputs]; - - for (int i = 0; i < NumOutputs; i++) - _outputs[i] = new Tensor(this, i, OutputType(i)); - } - - return _outputs; - } - } - - private InputList _inputs; - public InputList inputs - { - get - { - if (_inputs == null) - { - var retval = new Tensor[NumInputs]; - - for (int i = 0; i < NumInputs; i++) - { - var tf_outpus = Input(i); - var op = new Operation(tf_outpus.oper); - retval[i] = op.outputs[tf_outpus.index]; - } - - _inputs = new InputList(retval); - } - - return _inputs; - } - } - - private TF_DataType[] _input_types => _inputs._inputs.Select(x => x.dtype).ToArray(); - private NodeDef _node_def; public NodeDef node_def { diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs index 3bf2cea9..69e4f39e 100644 --- a/test/TensorFlowNET.UnitTest/VariableTest.cs +++ b/test/TensorFlowNET.UnitTest/VariableTest.cs @@ -49,7 +49,7 @@ namespace TensorFlowNET.UnitTest using (var session = tf.Session()) { - var sm = session.run(model); + session.run(x.initializer); for(int i = 0; i < 5; i++) { var x1 = x + 1;