Browse Source

fix Operation.inputs value is null. #117

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
1989988290
3 changed files with 55 additions and 11 deletions
  1. +17
    -1
      src/TensorFlowNET.Core/Graphs/Graph.cs
  2. +36
    -8
      src/TensorFlowNET.Core/Operations/Operation.cs
  3. +2
    -2
      src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs

+ 17
- 1
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -118,11 +118,17 @@ namespace Tensorflow
name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name);
var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs);

if (inputs == null)
inputs = new List<Tensor>();

var input_ops = inputs.Select(x => x.op).ToArray();
var control_inputs = _control_dependencies_for_inputs(input_ops);

var op = new Operation(node_def,
this,
inputs: inputs,
output_types: dtypes,
control_inputs: new object[] { },
control_inputs: control_inputs,
input_types: input_types,
original_op: null,
op_def: op_def);
@@ -131,6 +137,16 @@ namespace Tensorflow
return op;
}

/// <summary>
/// For an op that takes `input_ops` as inputs, compute control inputs.
/// </summary>
/// <param name="input_ops">The data input ops for an op to be created.</param>
/// <returns>A list of control inputs for the op to be created.</returns>
private Operation[] _control_dependencies_for_inputs(Operation[] input_ops)
{
return new Operation[0];
}

private void _create_op_helper(Operation op, bool compute_device = true)
{



+ 36
- 8
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -85,8 +85,42 @@ namespace Tensorflow
}

private Tensor[] _outputs;
public Tensor[] outputs => _outputs;
public Tensor[] inputs;
public Tensor[] outputs
{
get
{
if(_outputs == null)
{
_outputs = new Tensor[NumOutputs];

for (int i = 0; i < NumOutputs; i++)
_outputs[i] = new Tensor(this, i, OutputType(i));
}

return _outputs;
}
}

private Tensor[] _inputs;
public Tensor[] inputs
{
get
{
if(_inputs == null)
{
_inputs = new Tensor[NumInputs];

for (int i = 0; i < NumInputs; i++)
{
var tf_outpus = Input(i);
var op = new Operation(tf_outpus.oper);
_inputs[i] = op.outputs[tf_outpus.index];
}
}

return _inputs;
}
}

public Operation(IntPtr handle)
{
@@ -115,14 +149,10 @@ namespace Tensorflow

_handle = ops._create_c_op(g, node_def, inputs);

_outputs = new Tensor[NumOutputs];
output_types = new TF_DataType[NumOutputs];

for (int i = 0; i < NumOutputs; i++)
{
output_types[i] = OutputType(i);
_outputs[i] = new Tensor(this, i, output_types[i]);
}

Graph._add_op(this);
}
@@ -131,8 +161,6 @@ namespace Tensorflow
{
AttrValue x = null;

var fields = new string[] { "s", "i", "f", "b", "Type", "Shape", "Tensor", "func" };

using (var buf = new Buffer())
{
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status);


+ 2
- 2
src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs View File

@@ -40,7 +40,7 @@ namespace Tensorflow
_attrs["container"] = _op.get_attr<string>("container");
_attrs["shared_name"] = _op.get_attr<string>("shared_name");

_execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name);
_execute.record_gradient("VariableV2", _inputs_flat, _attrs, _result, name);

return new Tensor(_op, 0, dtype);
}
@@ -74,7 +74,7 @@ namespace Tensorflow
_attrs["validate_shape"] = _op.get_attr<bool>("validate_shape");
_attrs["use_locking"] = _op.get_attr<bool>("use_locking");

_execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name);
_execute.record_gradient("Assign", _inputs_flat, _attrs, _result, name);

return _result[0];
}


Loading…
Cancel
Save