| @@ -98,35 +98,23 @@ namespace Tensorflow | |||||
| default: | default: | ||||
| return obj?.ToString() ?? "null"; | return obj?.ToString() ?? "null"; | ||||
| } | } | ||||
| object[] toObjectArray(Array arr) | |||||
| { | |||||
| var len = arr.LongLength; | |||||
| var ret = new object[len]; | |||||
| for (long i = 0; i < len; i++) | |||||
| { | |||||
| ret[i] = arr.GetValue(i); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| } | } | ||||
| private static TextWriter writer = null; | |||||
| private static TextWriter _writer = Console.Out; | |||||
| public static TextWriter tf_output_redirect { | public static TextWriter tf_output_redirect { | ||||
| set | set | ||||
| { | { | ||||
| var originWriter = writer ?? Console.Out; | |||||
| originWriter.Flush(); | |||||
| if (originWriter is StringWriter) | |||||
| (originWriter as StringWriter).GetStringBuilder().Clear(); | |||||
| writer = value; | |||||
| } | |||||
| get | |||||
| { | |||||
| return writer ?? Console.Out; | |||||
| if(_writer != null) | |||||
| { | |||||
| _writer.Flush(); | |||||
| if (_writer is StringWriter sw) | |||||
| sw.GetStringBuilder().Clear(); | |||||
| } | |||||
| _writer = value; | |||||
| } | } | ||||
| get => _writer ?? Console.Out; | |||||
| } | } | ||||
| public static void print(object obj) | public static void print(object obj) | ||||
| @@ -48,7 +48,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| // free unmanaged memory | // free unmanaged memory | ||||
| // if (_handle != IntPtr.Zero) | |||||
| if (_handle != IntPtr.Zero) | |||||
| { | { | ||||
| // Call the appropriate methods to clean up | // Call the appropriate methods to clean up | ||||
| // unmanaged resources here. | // unmanaged resources here. | ||||
| @@ -14,32 +14,23 @@ namespace Tensorflow.Keras.Engine | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public Tensors Apply(Tensors inputs, Tensor state = null, bool training = false) | public Tensors Apply(Tensors inputs, Tensor state = null, bool training = false) | ||||
| { | { | ||||
| callContext = callContext?.Value != null ? callContext : new ThreadLocal<CallContext>() | |||||
| { | |||||
| Value = new CallContext() | |||||
| }; | |||||
| if (callContext.Value == null) | |||||
| callContext.Value = new CallContext(); | |||||
| if (_in_functional_construction_mode(inputs)) | if (_in_functional_construction_mode(inputs)) | ||||
| return FunctionalConstructionCall(inputs); | return FunctionalConstructionCall(inputs); | ||||
| Tensors outputs = null; | |||||
| var eager = tf.executing_eagerly(); | var eager = tf.executing_eagerly(); | ||||
| using var ctxManager = CallContext.enter(build_graph: false); | using var ctxManager = CallContext.enter(build_graph: false); | ||||
| string nameScope = ""; | |||||
| if (eager) | |||||
| nameScope = Name; | |||||
| else | |||||
| nameScope = _name_scope(); | |||||
| string nameScope = eager ? name : _name_scope(); | |||||
| var scope = ops.name_scope(nameScope); | var scope = ops.name_scope(nameScope); | ||||
| scope.__enter__(); | scope.__enter__(); | ||||
| if (!built) | if (!built) | ||||
| MaybeBuild(inputs); | MaybeBuild(inputs); | ||||
| outputs = Call(inputs, state: state, training: training); | |||||
| var outputs = Call(inputs, state: state, training: training); | |||||
| // memory leak | // memory leak | ||||
| // _set_connectivity_metadata_(inputs, outputs); | // _set_connectivity_metadata_(inputs, outputs); | ||||
| @@ -84,11 +84,13 @@ namespace Tensorflow.Keras.Engine | |||||
| List<INode> outboundNodes; | List<INode> outboundNodes; | ||||
| public List<INode> OutboundNodes => outboundNodes; | public List<INode> OutboundNodes => outboundNodes; | ||||
| ThreadLocal<CallContext> callContext; | |||||
| ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>(); | |||||
| public CallContext CallContext => callContext.Value; | public CallContext CallContext => callContext.Value; | ||||
| public Tensor[] input => inboundNodes[0].input_tensors; | public Tensor[] input => inboundNodes[0].input_tensors; | ||||
| public Dictionary<int, List<INode>> NodesByDepth { get; set; } | public Dictionary<int, List<INode>> NodesByDepth { get; set; } | ||||
| public Shape output_shape => inboundNodes[0].Outputs.shape; | public Shape output_shape => inboundNodes[0].Outputs.shape; | ||||
| protected List<ILayer> _self_tracked_trackables; | |||||
| public Layer(LayerArgs args) | public Layer(LayerArgs args) | ||||
| { | { | ||||
| this.args = args; | this.args = args; | ||||
| @@ -106,6 +108,7 @@ namespace Tensorflow.Keras.Engine | |||||
| non_trainable_weights = new List<IVariableV1>(); | non_trainable_weights = new List<IVariableV1>(); | ||||
| computePreviousMask = false; | computePreviousMask = false; | ||||
| updates = new List<Operation>(); | updates = new List<Operation>(); | ||||
| _self_tracked_trackables = new List<ILayer>(); | |||||
| inboundNodes = new List<INode>(); | inboundNodes = new List<INode>(); | ||||
| outboundNodes = new List<INode>(); | outboundNodes = new List<INode>(); | ||||
| @@ -14,6 +14,7 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| @@ -35,8 +36,9 @@ namespace Tensorflow.Keras.Engine | |||||
| bool _auto_track_sub_layers; | bool _auto_track_sub_layers; | ||||
| Shape _inferred_input_shape; | Shape _inferred_input_shape; | ||||
| bool _has_explicit_input_shape; | bool _has_explicit_input_shape; | ||||
| bool _graph_initialized; | |||||
| public Shape output_shape => outputs[0].shape; | public Shape output_shape => outputs[0].shape; | ||||
| List<INode> _created_nodes; | |||||
| public Sequential(SequentialArgs args) | public Sequential(SequentialArgs args) | ||||
| : base(args.Inputs, args.Outputs, name: args.Name) | : base(args.Inputs, args.Outputs, name: args.Name) | ||||
| @@ -49,12 +51,13 @@ namespace Tensorflow.Keras.Engine | |||||
| _auto_track_sub_layers = false; | _auto_track_sub_layers = false; | ||||
| _has_explicit_input_shape = false; | _has_explicit_input_shape = false; | ||||
| _is_graph_network = false; | _is_graph_network = false; | ||||
| _created_nodes = new List<INode>(); | |||||
| // Add to the model any layers passed to the constructor. | // Add to the model any layers passed to the constructor. | ||||
| if (args.Layers != null) | if (args.Layers != null) | ||||
| { | { | ||||
| foreach (var layer in args.Layers) | foreach (var layer in args.Layers) | ||||
| add(layer as Layer); | |||||
| add(layer); | |||||
| } | } | ||||
| } | } | ||||
| @@ -118,7 +121,69 @@ namespace Tensorflow.Keras.Engine | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| _self_tracked_trackables.add(layer); | |||||
| _handle_deferred_layer_dependencies(layer); | |||||
| } | |||||
| } | |||||
| void _handle_deferred_layer_dependencies(params ILayer[] layers) | |||||
| { | |||||
| _layers.AddRange(layers); | |||||
| } | |||||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
| { | |||||
| if (!_has_explicit_input_shape) | |||||
| { | |||||
| _build_graph_network_for_inferred_shape(inputs.shape, inputs.dtype); | |||||
| } | |||||
| if(_graph_initialized) | |||||
| { | |||||
| if (!built) | |||||
| _init_graph_network(this.inputs, outputs); | |||||
| return base.Call(inputs, state, training); | |||||
| } | |||||
| return base.Call(inputs, state, training); | |||||
| } | |||||
| void _build_graph_network_for_inferred_shape(Shape input_shape, TF_DataType input_dtype) | |||||
| { | |||||
| ops.init_scope(); | |||||
| var inputs = keras.Input(batch_input_shape: input_shape, | |||||
| dtype: input_dtype, | |||||
| name: $"{_layers[0].Name}_input"); | |||||
| Tensors layer_input = inputs; | |||||
| Tensors layer_output = null; | |||||
| Tensors outputs = null; | |||||
| foreach (var layer in _layers) | |||||
| { | |||||
| clear_previously_created_nodes(layer, _created_nodes); | |||||
| layer_output = layer.Apply(layer_input); | |||||
| // Keep track of nodes just created above | |||||
| track_nodes_created_by_last_call(layer, _created_nodes); | |||||
| layer_input = layer_output; | |||||
| outputs = layer_output; | |||||
| } | |||||
| _init_graph_network(inputs, outputs); | |||||
| _graph_initialized = true; | |||||
| _inferred_input_shape = input_shape; | |||||
| } | |||||
| void clear_previously_created_nodes(ILayer layer, List<INode> created_nodes) | |||||
| { | |||||
| } | |||||
| void track_nodes_created_by_last_call(ILayer layer, List<INode> created_nodes) | |||||
| { | |||||
| var node = layer.InboundNodes.Last(); | |||||
| created_nodes.Add(node); | |||||
| foreach(var prev_layer in node.iterate_inbound()) | |||||
| { | |||||
| created_nodes.add(prev_layer.Item1.OutboundNodes.Last()); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -71,7 +71,7 @@ namespace Tensorflow.Keras.Layers | |||||
| var rank = inputs.rank; | var rank = inputs.rank; | ||||
| if (rank > 2) | if (rank > 2) | ||||
| { | { | ||||
| throw new NotImplementedException("call rank > 2"); | |||||
| outputs = tf.linalg.tensordot(inputs, kernel.AsTensor(), new[,] { { rank - 1 }, { 0 } }); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||