| @@ -118,11 +118,17 @@ namespace Tensorflow | |||||
| name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name); | name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name); | ||||
| var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); | var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); | ||||
| if (inputs == null) | |||||
| inputs = new List<Tensor>(); | |||||
| var input_ops = inputs.Select(x => x.op).ToArray(); | |||||
| var control_inputs = _control_dependencies_for_inputs(input_ops); | |||||
| var op = new Operation(node_def, | var op = new Operation(node_def, | ||||
| this, | this, | ||||
| inputs: inputs, | inputs: inputs, | ||||
| output_types: dtypes, | output_types: dtypes, | ||||
| control_inputs: new object[] { }, | |||||
| control_inputs: control_inputs, | |||||
| input_types: input_types, | input_types: input_types, | ||||
| original_op: null, | original_op: null, | ||||
| op_def: op_def); | op_def: op_def); | ||||
| @@ -131,6 +137,16 @@ namespace Tensorflow | |||||
| return op; | return op; | ||||
| } | } | ||||
| /// <summary> | |||||
| /// For an op that takes `input_ops` as inputs, compute control inputs. | |||||
| /// </summary> | |||||
| /// <param name="input_ops">The data input ops for an op to be created.</param> | |||||
| /// <returns>A list of control inputs for the op to be created.</returns> | |||||
| private Operation[] _control_dependencies_for_inputs(Operation[] input_ops) | |||||
| { | |||||
| return new Operation[0]; | |||||
| } | |||||
| private void _create_op_helper(Operation op, bool compute_device = true) | private void _create_op_helper(Operation op, bool compute_device = true) | ||||
| { | { | ||||
| @@ -85,8 +85,42 @@ namespace Tensorflow | |||||
| } | } | ||||
| private Tensor[] _outputs; | private Tensor[] _outputs; | ||||
| public Tensor[] outputs => _outputs; | |||||
| public Tensor[] inputs; | |||||
| public Tensor[] outputs | |||||
| { | |||||
| get | |||||
| { | |||||
| if(_outputs == null) | |||||
| { | |||||
| _outputs = new Tensor[NumOutputs]; | |||||
| for (int i = 0; i < NumOutputs; i++) | |||||
| _outputs[i] = new Tensor(this, i, OutputType(i)); | |||||
| } | |||||
| return _outputs; | |||||
| } | |||||
| } | |||||
| private Tensor[] _inputs; | |||||
| public Tensor[] inputs | |||||
| { | |||||
| get | |||||
| { | |||||
| if(_inputs == null) | |||||
| { | |||||
| _inputs = new Tensor[NumInputs]; | |||||
| for (int i = 0; i < NumInputs; i++) | |||||
| { | |||||
| var tf_outpus = Input(i); | |||||
| var op = new Operation(tf_outpus.oper); | |||||
| _inputs[i] = op.outputs[tf_outpus.index]; | |||||
| } | |||||
| } | |||||
| return _inputs; | |||||
| } | |||||
| } | |||||
| public Operation(IntPtr handle) | public Operation(IntPtr handle) | ||||
| { | { | ||||
| @@ -115,14 +149,10 @@ namespace Tensorflow | |||||
| _handle = ops._create_c_op(g, node_def, inputs); | _handle = ops._create_c_op(g, node_def, inputs); | ||||
| _outputs = new Tensor[NumOutputs]; | |||||
| output_types = new TF_DataType[NumOutputs]; | output_types = new TF_DataType[NumOutputs]; | ||||
| for (int i = 0; i < NumOutputs; i++) | for (int i = 0; i < NumOutputs; i++) | ||||
| { | |||||
| output_types[i] = OutputType(i); | output_types[i] = OutputType(i); | ||||
| _outputs[i] = new Tensor(this, i, output_types[i]); | |||||
| } | |||||
| Graph._add_op(this); | Graph._add_op(this); | ||||
| } | } | ||||
| @@ -131,8 +161,6 @@ namespace Tensorflow | |||||
| { | { | ||||
| AttrValue x = null; | AttrValue x = null; | ||||
| var fields = new string[] { "s", "i", "f", "b", "Type", "Shape", "Tensor", "func" }; | |||||
| using (var buf = new Buffer()) | using (var buf = new Buffer()) | ||||
| { | { | ||||
| c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | ||||
| @@ -40,7 +40,7 @@ namespace Tensorflow | |||||
| _attrs["container"] = _op.get_attr<string>("container"); | _attrs["container"] = _op.get_attr<string>("container"); | ||||
| _attrs["shared_name"] = _op.get_attr<string>("shared_name"); | _attrs["shared_name"] = _op.get_attr<string>("shared_name"); | ||||
| _execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name); | |||||
| _execute.record_gradient("VariableV2", _inputs_flat, _attrs, _result, name); | |||||
| return new Tensor(_op, 0, dtype); | return new Tensor(_op, 0, dtype); | ||||
| } | } | ||||
| @@ -74,7 +74,7 @@ namespace Tensorflow | |||||
| _attrs["validate_shape"] = _op.get_attr<bool>("validate_shape"); | _attrs["validate_shape"] = _op.get_attr<bool>("validate_shape"); | ||||
| _attrs["use_locking"] = _op.get_attr<bool>("use_locking"); | _attrs["use_locking"] = _op.get_attr<bool>("use_locking"); | ||||
| _execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name); | |||||
| _execute.record_gradient("Assign", _inputs_flat, _attrs, _result, name); | |||||
| return _result[0]; | return _result[0]; | ||||
| } | } | ||||