Browse Source

#173

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
596afe2059
7 changed files with 38 additions and 21 deletions
  1. +16
    -3
      src/TensorFlowNET.Core/Graphs/Graph.cs
  2. +2
    -0
      src/TensorFlowNET.Core/ITensorOrOperation.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Operations/Operation.cs
  4. +16
    -1
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  5. +1
    -15
      src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs
  6. +1
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  7. +1
    -1
      test/TensorFlowNET.Examples/HelloWorld.cs

+ 16
- 3
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -41,7 +41,7 @@ namespace Tensorflow
_graph_key = $"grap-key-{ops.uid()}/";
}

public ITensorOrOperation as_graph_element(ITensorOrOperation obj, bool allow_tensor = true, bool allow_operation = true)
public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true)
{
return _as_graph_element_locked(obj, allow_tensor, allow_operation);
}
@@ -54,7 +54,7 @@ namespace Tensorflow
return null;
}

private ITensorOrOperation _as_graph_element_locked(ITensorOrOperation obj, bool allow_tensor = true, bool allow_operation = true)
private ITensorOrOperation _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true)
{
string types_str = "";

@@ -75,6 +75,19 @@ namespace Tensorflow
if (temp_obj != null)
obj = temp_obj;

// If obj appears to be a name...
if (obj is String str)
{
if(str.Contains(":") && allow_tensor)
{
string op_name = str.Split(':')[0];
int out_n = int.Parse(str.Split(':')[1]);

if (_nodes_by_name.ContainsKey(op_name))
return _nodes_by_name[op_name].outputs[out_n];
}
}

if (obj is Tensor tensor && allow_tensor)
{
if (tensor.Graph.Equals(this))
@@ -166,7 +179,7 @@ namespace Tensorflow
public void _add_op(Operation op)
{
_nodes_by_id[op._id] = op;
//_nodes_by_name[op.name] = op;
_nodes_by_name[op.name] = op;
_version = Math.Max(_version, op._id);
}



+ 2
- 0
src/TensorFlowNET.Core/ITensorOrOperation.cs View File

@@ -13,5 +13,7 @@ namespace Tensorflow
string Device { get; }
Operation op { get; }
string name { get; }
TF_DataType dtype { get; }
Tensor[] outputs { get; }
}
}

+ 1
- 1
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -17,7 +17,7 @@ namespace Tensorflow

public string type => OpType;
public Operation op => this;
public TF_DataType dtype => TF_DataType.DtInvalid;
private Status status = new Status();

public string name => c_api.StringPiece(c_api.TF_OperationName(_handle));


+ 16
- 1
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -44,9 +44,24 @@ namespace Tensorflow
private NDArray _run(object fetches, FeedItem[] feed_dict = null)
{
var feed_dict_tensor = new Dictionary<object, object>();
var feed_map = new Dictionary<object, object>();

// Validate and process feed_dict.
if (feed_dict != null)
feed_dict.ToList().ForEach(x => feed_dict_tensor.Add(x.Key, x.Value));
{
foreach(var subfeed in feed_dict)
{
var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false);
var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype();
switch(subfeed.Value)
{
case string str:
feed_dict_tensor[subfeed_t] = np.array(str);
feed_map[subfeed_t.name] = new Tuple<object, object>(subfeed_t, subfeed.Value);
break;
}
}
}

// Create a fetch handler to take care of the structure of fetches.
var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor);


+ 1
- 15
src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs View File

@@ -20,21 +20,7 @@ namespace Tensorflow

foreach(var fetch in fetches)
{
switch(fetch)
{
case Tensor tensor:
el = g.as_graph_element(tensor, allow_tensor: true, allow_operation: true);
break;
case Operation op:
el = g.as_graph_element(op, allow_tensor: true, allow_operation: true);
break;
case String str:
// Looks like a Tensor name and can be a Tensor.
el = g._nodes_by_name[str];
break;
default:
throw new NotImplementedException("_ElementFetchMapper");
}
el = g.as_graph_element(fetch, allow_tensor: true, allow_operation: true);
}

_unique_fetches.Add(el);


+ 1
- 0
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -21,6 +21,7 @@ namespace Tensorflow

public Graph Graph => op?.Graph;
public Operation op { get; }
public Tensor[] outputs => op.outputs;

/// <summary>
/// The string name of this tensor.


+ 1
- 1
test/TensorFlowNET.Examples/HelloWorld.cs View File

@@ -18,7 +18,7 @@ namespace TensorFlowNET.Examples
The value returned by the constructor represents the output
of the Constant op. */
var str = "Hello, TensorFlow!";
var str = "Hello, TensorFlow.NET!";
var hello = tf.constant(str);

// Start tf session


Loading…
Cancel
Save