| @@ -6,12 +6,9 @@ namespace Tensorflow.Eager | |||||
| { | { | ||||
| public class Execute | public class Execute | ||||
| { | { | ||||
| public void record_gradient(string op_name, Tensor[] inputs, Dictionary<string, object> attrs, Tensor[] results, string name = "") | |||||
| public void record_gradient(string op_name, InputList inputs, Dictionary<string, object> attrs, Tensor[] results, string name = "") | |||||
| { | { | ||||
| if (inputs == null) | |||||
| inputs = new Tensor[0]; | |||||
| pywrap_tfe_src.RecordGradient(op_name, inputs, attrs, results, name); | |||||
| pywrap_tfe_src.RecordGradient(op_name, inputs._inputs, attrs, results, name); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -25,6 +25,9 @@ namespace Tensorflow.Eager | |||||
| } | } | ||||
| } | } | ||||
| if (!should_record) return; | if (!should_record) return; | ||||
| var op_outputs = results; | |||||
| var op_inputs = inputs; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,16 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class InputList | |||||
| { | |||||
| public Tensor[] _inputs; | |||||
| public InputList(Tensor[] inputs) | |||||
| { | |||||
| _inputs = inputs; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -101,21 +101,23 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| private Tensor[] _inputs; | |||||
| public Tensor[] inputs | |||||
| private InputList _inputs; | |||||
| public InputList inputs | |||||
| { | { | ||||
| get | get | ||||
| { | { | ||||
| if(_inputs == null) | if(_inputs == null) | ||||
| { | { | ||||
| _inputs = new Tensor[NumInputs]; | |||||
| var retval = new Tensor[NumInputs]; | |||||
| for (int i = 0; i < NumInputs; i++) | for (int i = 0; i < NumInputs; i++) | ||||
| { | { | ||||
| var tf_outpus = Input(i); | var tf_outpus = Input(i); | ||||
| var op = new Operation(tf_outpus.oper); | var op = new Operation(tf_outpus.oper); | ||||
| _inputs[i] = op.outputs[tf_outpus.index]; | |||||
| retval[i] = op.outputs[tf_outpus.index]; | |||||
| } | } | ||||
| _inputs = new InputList(retval); | |||||
| } | } | ||||
| return _inputs; | return _inputs; | ||||