Browse Source

fix keras sequential.

tags/TimeSeries
Oceania2018 3 years ago
parent
commit
18d2512ee5
6 changed files with 87 additions and 40 deletions
  1. +10
    -22
      src/TensorFlowNET.Core/Binding.Util.cs
  2. +1
    -1
      src/TensorFlowNET.Core/DisposableObject.cs
  3. +4
    -13
      src/TensorFlowNET.Keras/Engine/Layer.Apply.cs
  4. +4
    -1
      src/TensorFlowNET.Keras/Engine/Layer.cs
  5. +67
    -2
      src/TensorFlowNET.Keras/Engine/Sequential.cs
  6. +1
    -1
      src/TensorFlowNET.Keras/Layers/Core/Dense.cs

+ 10
- 22
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -98,35 +98,23 @@ namespace Tensorflow
default: default:
return obj?.ToString() ?? "null"; return obj?.ToString() ?? "null";
} }

object[] toObjectArray(Array arr)
{
var len = arr.LongLength;
var ret = new object[len];
for (long i = 0; i < len; i++)
{
ret[i] = arr.GetValue(i);
}

return ret;
}
} }


private static TextWriter writer = null;
private static TextWriter _writer = Console.Out;


public static TextWriter tf_output_redirect { public static TextWriter tf_output_redirect {
set set
{ {
var originWriter = writer ?? Console.Out;
originWriter.Flush();
if (originWriter is StringWriter)
(originWriter as StringWriter).GetStringBuilder().Clear();
writer = value;
}
get
{
return writer ?? Console.Out;
if(_writer != null)
{
_writer.Flush();
if (_writer is StringWriter sw)
sw.GetStringBuilder().Clear();
}

_writer = value;
} }
get => _writer ?? Console.Out;
} }


public static void print(object obj) public static void print(object obj)


+ 1
- 1
src/TensorFlowNET.Core/DisposableObject.cs View File

@@ -48,7 +48,7 @@ namespace Tensorflow
} }


// free unmanaged memory // free unmanaged memory
// if (_handle != IntPtr.Zero)
if (_handle != IntPtr.Zero)
{ {
// Call the appropriate methods to clean up // Call the appropriate methods to clean up
// unmanaged resources here. // unmanaged resources here.


+ 4
- 13
src/TensorFlowNET.Keras/Engine/Layer.Apply.cs View File

@@ -14,32 +14,23 @@ namespace Tensorflow.Keras.Engine
/// <returns></returns> /// <returns></returns>
public Tensors Apply(Tensors inputs, Tensor state = null, bool training = false) public Tensors Apply(Tensors inputs, Tensor state = null, bool training = false)
{ {
callContext = callContext?.Value != null ? callContext : new ThreadLocal<CallContext>()
{
Value = new CallContext()
};
if (callContext.Value == null)
callContext.Value = new CallContext();


if (_in_functional_construction_mode(inputs)) if (_in_functional_construction_mode(inputs))
return FunctionalConstructionCall(inputs); return FunctionalConstructionCall(inputs);


Tensors outputs = null;

var eager = tf.executing_eagerly(); var eager = tf.executing_eagerly();
using var ctxManager = CallContext.enter(build_graph: false); using var ctxManager = CallContext.enter(build_graph: false);


string nameScope = "";
if (eager)
nameScope = Name;
else
nameScope = _name_scope();

string nameScope = eager ? name : _name_scope();
var scope = ops.name_scope(nameScope); var scope = ops.name_scope(nameScope);
scope.__enter__(); scope.__enter__();


if (!built) if (!built)
MaybeBuild(inputs); MaybeBuild(inputs);


outputs = Call(inputs, state: state, training: training);
var outputs = Call(inputs, state: state, training: training);


// memory leak // memory leak
// _set_connectivity_metadata_(inputs, outputs); // _set_connectivity_metadata_(inputs, outputs);


+ 4
- 1
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -84,11 +84,13 @@ namespace Tensorflow.Keras.Engine
List<INode> outboundNodes; List<INode> outboundNodes;
public List<INode> OutboundNodes => outboundNodes; public List<INode> OutboundNodes => outboundNodes;


ThreadLocal<CallContext> callContext;
ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>();
public CallContext CallContext => callContext.Value; public CallContext CallContext => callContext.Value;
public Tensor[] input => inboundNodes[0].input_tensors; public Tensor[] input => inboundNodes[0].input_tensors;
public Dictionary<int, List<INode>> NodesByDepth { get; set; } public Dictionary<int, List<INode>> NodesByDepth { get; set; }
public Shape output_shape => inboundNodes[0].Outputs.shape; public Shape output_shape => inboundNodes[0].Outputs.shape;
protected List<ILayer> _self_tracked_trackables;

public Layer(LayerArgs args) public Layer(LayerArgs args)
{ {
this.args = args; this.args = args;
@@ -106,6 +108,7 @@ namespace Tensorflow.Keras.Engine
non_trainable_weights = new List<IVariableV1>(); non_trainable_weights = new List<IVariableV1>();
computePreviousMask = false; computePreviousMask = false;
updates = new List<Operation>(); updates = new List<Operation>();
_self_tracked_trackables = new List<ILayer>();


inboundNodes = new List<INode>(); inboundNodes = new List<INode>();
outboundNodes = new List<INode>(); outboundNodes = new List<INode>();


+ 67
- 2
src/TensorFlowNET.Keras/Engine/Sequential.cs View File

@@ -14,6 +14,7 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System;
using System.Linq; using System.Linq;
using System.Collections.Generic; using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition;
@@ -35,8 +36,9 @@ namespace Tensorflow.Keras.Engine
bool _auto_track_sub_layers; bool _auto_track_sub_layers;
Shape _inferred_input_shape; Shape _inferred_input_shape;
bool _has_explicit_input_shape; bool _has_explicit_input_shape;
bool _graph_initialized;
public Shape output_shape => outputs[0].shape; public Shape output_shape => outputs[0].shape;
List<INode> _created_nodes;


public Sequential(SequentialArgs args) public Sequential(SequentialArgs args)
: base(args.Inputs, args.Outputs, name: args.Name) : base(args.Inputs, args.Outputs, name: args.Name)
@@ -49,12 +51,13 @@ namespace Tensorflow.Keras.Engine
_auto_track_sub_layers = false; _auto_track_sub_layers = false;
_has_explicit_input_shape = false; _has_explicit_input_shape = false;
_is_graph_network = false; _is_graph_network = false;
_created_nodes = new List<INode>();


// Add to the model any layers passed to the constructor. // Add to the model any layers passed to the constructor.
if (args.Layers != null) if (args.Layers != null)
{ {
foreach (var layer in args.Layers) foreach (var layer in args.Layers)
add(layer as Layer);
add(layer);
} }
} }


@@ -118,7 +121,69 @@ namespace Tensorflow.Keras.Engine
} }
else else
{ {
_self_tracked_trackables.add(layer);
_handle_deferred_layer_dependencies(layer);
}
}


void _handle_deferred_layer_dependencies(params ILayer[] layers)
{
_layers.AddRange(layers);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
if (!_has_explicit_input_shape)
{
_build_graph_network_for_inferred_shape(inputs.shape, inputs.dtype);
}

if(_graph_initialized)
{
if (!built)
_init_graph_network(this.inputs, outputs);
return base.Call(inputs, state, training);
}

return base.Call(inputs, state, training);
}

void _build_graph_network_for_inferred_shape(Shape input_shape, TF_DataType input_dtype)
{
ops.init_scope();
var inputs = keras.Input(batch_input_shape: input_shape,
dtype: input_dtype,
name: $"{_layers[0].Name}_input");
Tensors layer_input = inputs;
Tensors layer_output = null;
Tensors outputs = null;
foreach (var layer in _layers)
{
clear_previously_created_nodes(layer, _created_nodes);
layer_output = layer.Apply(layer_input);
// Keep track of nodes just created above
track_nodes_created_by_last_call(layer, _created_nodes);
layer_input = layer_output;
outputs = layer_output;
}
_init_graph_network(inputs, outputs);
_graph_initialized = true;
_inferred_input_shape = input_shape;
}

void clear_previously_created_nodes(ILayer layer, List<INode> created_nodes)
{

}

void track_nodes_created_by_last_call(ILayer layer, List<INode> created_nodes)
{
var node = layer.InboundNodes.Last();
created_nodes.Add(node);
foreach(var prev_layer in node.iterate_inbound())
{
created_nodes.add(prev_layer.Item1.OutboundNodes.Last());
} }
} }
} }


+ 1
- 1
src/TensorFlowNET.Keras/Layers/Core/Dense.cs View File

@@ -71,7 +71,7 @@ namespace Tensorflow.Keras.Layers
var rank = inputs.rank; var rank = inputs.rank;
if (rank > 2) if (rank > 2)
{ {
throw new NotImplementedException("call rank > 2");
outputs = tf.linalg.tensordot(inputs, kernel.AsTensor(), new[,] { { rank - 1 }, { 0 } });
} }
else else
{ {


Loading…
Cancel
Save