Browse Source

fix Operation.inputs value is null. #117

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
1989988290
3 changed files with 55 additions and 11 deletions
  1. +17
    -1
      src/TensorFlowNET.Core/Graphs/Graph.cs
  2. +36
    -8
      src/TensorFlowNET.Core/Operations/Operation.cs
  3. +2
    -2
      src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs

+ 17
- 1
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -118,11 +118,17 @@ namespace Tensorflow
name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name); name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name);
var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs);


if (inputs == null)
inputs = new List<Tensor>();

var input_ops = inputs.Select(x => x.op).ToArray();
var control_inputs = _control_dependencies_for_inputs(input_ops);

var op = new Operation(node_def, var op = new Operation(node_def,
this, this,
inputs: inputs, inputs: inputs,
output_types: dtypes, output_types: dtypes,
control_inputs: new object[] { },
control_inputs: control_inputs,
input_types: input_types, input_types: input_types,
original_op: null, original_op: null,
op_def: op_def); op_def: op_def);
@@ -131,6 +137,16 @@ namespace Tensorflow
return op; return op;
} }


/// <summary>
/// For an op that takes `input_ops` as inputs, compute control inputs.
/// </summary>
/// <param name="input_ops">The data input ops for an op to be created.</param>
/// <returns>A list of control inputs for the op to be created.</returns>
private Operation[] _control_dependencies_for_inputs(Operation[] input_ops)
{
return new Operation[0];
}

private void _create_op_helper(Operation op, bool compute_device = true) private void _create_op_helper(Operation op, bool compute_device = true)
{ {




+ 36
- 8
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -85,8 +85,42 @@ namespace Tensorflow
} }


private Tensor[] _outputs; private Tensor[] _outputs;
public Tensor[] outputs => _outputs;
public Tensor[] inputs;
public Tensor[] outputs
{
get
{
if(_outputs == null)
{
_outputs = new Tensor[NumOutputs];

for (int i = 0; i < NumOutputs; i++)
_outputs[i] = new Tensor(this, i, OutputType(i));
}

return _outputs;
}
}

private Tensor[] _inputs;
public Tensor[] inputs
{
get
{
if(_inputs == null)
{
_inputs = new Tensor[NumInputs];

for (int i = 0; i < NumInputs; i++)
{
var tf_outpus = Input(i);
var op = new Operation(tf_outpus.oper);
_inputs[i] = op.outputs[tf_outpus.index];
}
}

return _inputs;
}
}


public Operation(IntPtr handle) public Operation(IntPtr handle)
{ {
@@ -115,14 +149,10 @@ namespace Tensorflow


_handle = ops._create_c_op(g, node_def, inputs); _handle = ops._create_c_op(g, node_def, inputs);


_outputs = new Tensor[NumOutputs];
output_types = new TF_DataType[NumOutputs]; output_types = new TF_DataType[NumOutputs];


for (int i = 0; i < NumOutputs; i++) for (int i = 0; i < NumOutputs; i++)
{
output_types[i] = OutputType(i); output_types[i] = OutputType(i);
_outputs[i] = new Tensor(this, i, output_types[i]);
}


Graph._add_op(this); Graph._add_op(this);
} }
@@ -131,8 +161,6 @@ namespace Tensorflow
{ {
AttrValue x = null; AttrValue x = null;


var fields = new string[] { "s", "i", "f", "b", "Type", "Shape", "Tensor", "func" };

using (var buf = new Buffer()) using (var buf = new Buffer())
{ {
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status);


+ 2
- 2
src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs View File

@@ -40,7 +40,7 @@ namespace Tensorflow
_attrs["container"] = _op.get_attr<string>("container"); _attrs["container"] = _op.get_attr<string>("container");
_attrs["shared_name"] = _op.get_attr<string>("shared_name"); _attrs["shared_name"] = _op.get_attr<string>("shared_name");


_execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name);
_execute.record_gradient("VariableV2", _inputs_flat, _attrs, _result, name);


return new Tensor(_op, 0, dtype); return new Tensor(_op, 0, dtype);
} }
@@ -74,7 +74,7 @@ namespace Tensorflow
_attrs["validate_shape"] = _op.get_attr<bool>("validate_shape"); _attrs["validate_shape"] = _op.get_attr<bool>("validate_shape");
_attrs["use_locking"] = _op.get_attr<bool>("use_locking"); _attrs["use_locking"] = _op.get_attr<bool>("use_locking");


_execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name);
_execute.record_gradient("Assign", _inputs_flat, _attrs, _result, name);


return _result[0]; return _result[0];
} }


Loading…
Cancel
Save