Browse Source

Add InboundLayers to Node

tags/TimeSeries
Oceania2018 3 years ago
parent
commit
d1fc44dcef
8 changed files with 42 additions and 7 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs
  2. +1
    -0
      src/TensorFlowNET.Core/Keras/Engine/INode.cs
  3. +7
    -0
      src/TensorFlowNET.Core/NumPy/ShapeHelper.cs
  4. +6
    -0
      src/TensorFlowNET.Core/Numpy/Shape.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Operations/Operation.cs
  6. +2
    -1
      src/TensorFlowNET.Keras/Engine/Functional.cs
  7. +4
    -0
      src/TensorFlowNET.Keras/Engine/Node.IterateInbound.cs
  8. +20
    -4
      src/TensorFlowNET.Keras/Engine/Sequential.cs

+ 1
- 1
src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs View File

@@ -78,7 +78,7 @@ namespace Tensorflow.Contexts
if (args.GetGradientAttrs == null) if (args.GetGradientAttrs == null)
{ {
attrs = new Dictionary<string, object>(); attrs = new Dictionary<string, object>();
attrs["T"] = op.get_attr<TF_DataType>("T");
attrs["T"] = op.dtype;
} }
else else
{ {


+ 1
- 0
src/TensorFlowNET.Core/Keras/Engine/INode.cs View File

@@ -11,6 +11,7 @@ namespace Tensorflow.Keras.Engine
ILayer Layer { get; } ILayer Layer { get; }
List<Tensor> KerasInputs { get; set; } List<Tensor> KerasInputs { get; set; }
INode[] ParentNodes { get; } INode[] ParentNodes { get; }
ILayer[] InboundLayers { get; }
IEnumerable<(ILayer, int, int, Tensor)> iterate_inbound(); IEnumerable<(ILayer, int, int, Tensor)> iterate_inbound();
bool is_input { get; } bool is_input { get; }
List<NodeConfig> serialize(Func<string, int, string> make_node_key, Dictionary<string, int> node_conversion_map); List<NodeConfig> serialize(Func<string, int, string> make_node_key, Dictionary<string, int> node_conversion_map);


+ 7
- 0
src/TensorFlowNET.Core/NumPy/ShapeHelper.cs View File

@@ -88,6 +88,13 @@ namespace Tensorflow.NumPy


public static bool Equals(Shape shape, object target) public static bool Equals(Shape shape, object target)
{ {
if (shape is null && target is null)
return true;
else if (shape is null && target is not null)
return false;
else if (shape is not null && target is null)
return false;

switch (target) switch (target)
{ {
case Shape shape1: case Shape shape1:


+ 6
- 0
src/TensorFlowNET.Core/Numpy/Shape.cs View File

@@ -253,5 +253,11 @@ namespace Tensorflow
public override bool Equals(object obj) => ShapeHelper.Equals(this, obj); public override bool Equals(object obj) => ShapeHelper.Equals(this, obj);


public override string ToString() => ShapeHelper.ToString(this); public override string ToString() => ShapeHelper.ToString(this);

public static bool operator ==(Shape a, Shape b)
=> ShapeHelper.Equals(a, b);

public static bool operator !=(Shape a, Shape b)
=> !ShapeHelper.Equals(a, b);
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -55,7 +55,7 @@ namespace Tensorflow


public int _id_value { get; set; } public int _id_value { get; set; }
public Operation op => this; public Operation op => this;
public TF_DataType dtype => TF_DataType.DtInvalid;
public TF_DataType dtype => output.dtype;
public virtual string name => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationName(_handle)); public virtual string name => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationName(_handle));
public string OpType => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationOpType(_handle)); public string OpType => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationOpType(_handle));




+ 2
- 1
src/TensorFlowNET.Keras/Engine/Functional.cs View File

@@ -69,7 +69,8 @@ namespace Tensorflow.Keras.Engine


NetworkNodes = nodes; NetworkNodes = nodes;
NodesByDepth = nodes_by_depth; NodesByDepth = nodes_by_depth;
_layers = layers;
if (_layers.Count == 0)
_layers = layers;


// Build self.input_names and self.output_names. // Build self.input_names and self.output_names.
_set_output_names(); _set_output_names();


+ 4
- 0
src/TensorFlowNET.Keras/Engine/Node.IterateInbound.cs View File

@@ -1,9 +1,13 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;


namespace Tensorflow.Keras.Engine namespace Tensorflow.Keras.Engine
{ {
public partial class Node public partial class Node
{ {
public ILayer[] InboundLayers
=> iterate_inbound().Select(x => x.Item1).ToArray();

public IEnumerable<(ILayer, int, int, Tensor)> iterate_inbound() public IEnumerable<(ILayer, int, int, Tensor)> iterate_inbound()
{ {
foreach (var kt in KerasInputs) foreach (var kt in KerasInputs)


+ 20
- 4
src/TensorFlowNET.Keras/Engine/Sequential.cs View File

@@ -150,6 +150,9 @@ namespace Tensorflow.Keras.Engine


void _build_graph_network_for_inferred_shape(Shape input_shape, TF_DataType input_dtype) void _build_graph_network_for_inferred_shape(Shape input_shape, TF_DataType input_dtype)
{ {
if (_inferred_input_shape == input_shape)
return;

ops.init_scope(); ops.init_scope();
var inputs = keras.Input(batch_input_shape: input_shape, var inputs = keras.Input(batch_input_shape: input_shape,
dtype: input_dtype, dtype: input_dtype,
@@ -157,16 +160,17 @@ namespace Tensorflow.Keras.Engine
Tensors layer_input = inputs; Tensors layer_input = inputs;
Tensors layer_output = null; Tensors layer_output = null;
Tensors outputs = null; Tensors outputs = null;
List<INode> created_nodes = new List<INode>();
foreach (var layer in _layers) foreach (var layer in _layers)
{ {
clear_previously_created_nodes(layer, _created_nodes); clear_previously_created_nodes(layer, _created_nodes);
layer_output = layer.Apply(layer_input); layer_output = layer.Apply(layer_input);
// Keep track of nodes just created above // Keep track of nodes just created above
track_nodes_created_by_last_call(layer, _created_nodes);
track_nodes_created_by_last_call(layer, created_nodes);
layer_input = layer_output; layer_input = layer_output;
outputs = layer_output; outputs = layer_output;
} }
_created_nodes = created_nodes;
_init_graph_network(inputs, outputs); _init_graph_network(inputs, outputs);
_graph_initialized = true; _graph_initialized = true;
_inferred_input_shape = input_shape; _inferred_input_shape = input_shape;
@@ -174,16 +178,28 @@ namespace Tensorflow.Keras.Engine


void clear_previously_created_nodes(ILayer layer, List<INode> created_nodes) void clear_previously_created_nodes(ILayer layer, List<INode> created_nodes)
{ {
foreach(var node in layer.InboundNodes)
{
foreach(var prev_layer in node.InboundLayers)
{
var outNodes = prev_layer.OutboundNodes.Where(x => !created_nodes.Contains(x)).ToArray();
prev_layer.OutboundNodes.Clear();
prev_layer.OutboundNodes.AddRange(outNodes);
}
}


var inNodes = layer.InboundNodes.Where(x => !created_nodes.Contains(x)).ToArray();
layer.InboundNodes.Clear();
layer.InboundNodes.AddRange(inNodes);
} }


void track_nodes_created_by_last_call(ILayer layer, List<INode> created_nodes) void track_nodes_created_by_last_call(ILayer layer, List<INode> created_nodes)
{ {
var node = layer.InboundNodes.Last(); var node = layer.InboundNodes.Last();
created_nodes.Add(node); created_nodes.Add(node);
foreach(var prev_layer in node.iterate_inbound())
foreach(var prev_layer in node.InboundLayers)
{ {
created_nodes.add(prev_layer.Item1.OutboundNodes.Last());
created_nodes.add(prev_layer.OutboundNodes.Last());
} }
} }
} }


Loading…
Cancel
Save