| @@ -98,35 +98,23 @@ namespace Tensorflow | |||
| default: | |||
| return obj?.ToString() ?? "null"; | |||
| } | |||
| object[] toObjectArray(Array arr) | |||
| { | |||
| var len = arr.LongLength; | |||
| var ret = new object[len]; | |||
| for (long i = 0; i < len; i++) | |||
| { | |||
| ret[i] = arr.GetValue(i); | |||
| } | |||
| return ret; | |||
| } | |||
| } | |||
| private static TextWriter writer = null; | |||
| private static TextWriter _writer = Console.Out; | |||
| public static TextWriter tf_output_redirect { | |||
| set | |||
| { | |||
| var originWriter = writer ?? Console.Out; | |||
| originWriter.Flush(); | |||
| if (originWriter is StringWriter) | |||
| (originWriter as StringWriter).GetStringBuilder().Clear(); | |||
| writer = value; | |||
| } | |||
| get | |||
| { | |||
| return writer ?? Console.Out; | |||
| if(_writer != null) | |||
| { | |||
| _writer.Flush(); | |||
| if (_writer is StringWriter sw) | |||
| sw.GetStringBuilder().Clear(); | |||
| } | |||
| _writer = value; | |||
| } | |||
| get => _writer ?? Console.Out; | |||
| } | |||
| public static void print(object obj) | |||
| @@ -48,7 +48,7 @@ namespace Tensorflow | |||
| } | |||
| // free unmanaged memory | |||
| // if (_handle != IntPtr.Zero) | |||
| if (_handle != IntPtr.Zero) | |||
| { | |||
| // Call the appropriate methods to clean up | |||
| // unmanaged resources here. | |||
| @@ -14,32 +14,23 @@ namespace Tensorflow.Keras.Engine | |||
| /// <returns></returns> | |||
| public Tensors Apply(Tensors inputs, Tensor state = null, bool training = false) | |||
| { | |||
| callContext = callContext?.Value != null ? callContext : new ThreadLocal<CallContext>() | |||
| { | |||
| Value = new CallContext() | |||
| }; | |||
| if (callContext.Value == null) | |||
| callContext.Value = new CallContext(); | |||
| if (_in_functional_construction_mode(inputs)) | |||
| return FunctionalConstructionCall(inputs); | |||
| Tensors outputs = null; | |||
| var eager = tf.executing_eagerly(); | |||
| using var ctxManager = CallContext.enter(build_graph: false); | |||
| string nameScope = ""; | |||
| if (eager) | |||
| nameScope = Name; | |||
| else | |||
| nameScope = _name_scope(); | |||
| string nameScope = eager ? name : _name_scope(); | |||
| var scope = ops.name_scope(nameScope); | |||
| scope.__enter__(); | |||
| if (!built) | |||
| MaybeBuild(inputs); | |||
| outputs = Call(inputs, state: state, training: training); | |||
| var outputs = Call(inputs, state: state, training: training); | |||
| // memory leak | |||
| // _set_connectivity_metadata_(inputs, outputs); | |||
| @@ -84,11 +84,13 @@ namespace Tensorflow.Keras.Engine | |||
| List<INode> outboundNodes; | |||
| public List<INode> OutboundNodes => outboundNodes; | |||
| ThreadLocal<CallContext> callContext; | |||
| ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>(); | |||
| public CallContext CallContext => callContext.Value; | |||
| public Tensor[] input => inboundNodes[0].input_tensors; | |||
| public Dictionary<int, List<INode>> NodesByDepth { get; set; } | |||
| public Shape output_shape => inboundNodes[0].Outputs.shape; | |||
| protected List<ILayer> _self_tracked_trackables; | |||
| public Layer(LayerArgs args) | |||
| { | |||
| this.args = args; | |||
| @@ -106,6 +108,7 @@ namespace Tensorflow.Keras.Engine | |||
| non_trainable_weights = new List<IVariableV1>(); | |||
| computePreviousMask = false; | |||
| updates = new List<Operation>(); | |||
| _self_tracked_trackables = new List<ILayer>(); | |||
| inboundNodes = new List<INode>(); | |||
| outboundNodes = new List<INode>(); | |||
| @@ -14,6 +14,7 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Linq; | |||
| using System.Collections.Generic; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| @@ -35,8 +36,9 @@ namespace Tensorflow.Keras.Engine | |||
| bool _auto_track_sub_layers; | |||
| Shape _inferred_input_shape; | |||
| bool _has_explicit_input_shape; | |||
| bool _graph_initialized; | |||
| public Shape output_shape => outputs[0].shape; | |||
| List<INode> _created_nodes; | |||
| public Sequential(SequentialArgs args) | |||
| : base(args.Inputs, args.Outputs, name: args.Name) | |||
| @@ -49,12 +51,13 @@ namespace Tensorflow.Keras.Engine | |||
| _auto_track_sub_layers = false; | |||
| _has_explicit_input_shape = false; | |||
| _is_graph_network = false; | |||
| _created_nodes = new List<INode>(); | |||
| // Add to the model any layers passed to the constructor. | |||
| if (args.Layers != null) | |||
| { | |||
| foreach (var layer in args.Layers) | |||
| add(layer as Layer); | |||
| add(layer); | |||
| } | |||
| } | |||
| @@ -118,7 +121,69 @@ namespace Tensorflow.Keras.Engine | |||
| } | |||
| else | |||
| { | |||
| _self_tracked_trackables.add(layer); | |||
| _handle_deferred_layer_dependencies(layer); | |||
| } | |||
| } | |||
| void _handle_deferred_layer_dependencies(params ILayer[] layers) | |||
| { | |||
| _layers.AddRange(layers); | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| { | |||
| if (!_has_explicit_input_shape) | |||
| { | |||
| _build_graph_network_for_inferred_shape(inputs.shape, inputs.dtype); | |||
| } | |||
| if(_graph_initialized) | |||
| { | |||
| if (!built) | |||
| _init_graph_network(this.inputs, outputs); | |||
| return base.Call(inputs, state, training); | |||
| } | |||
| return base.Call(inputs, state, training); | |||
| } | |||
| void _build_graph_network_for_inferred_shape(Shape input_shape, TF_DataType input_dtype) | |||
| { | |||
| ops.init_scope(); | |||
| var inputs = keras.Input(batch_input_shape: input_shape, | |||
| dtype: input_dtype, | |||
| name: $"{_layers[0].Name}_input"); | |||
| Tensors layer_input = inputs; | |||
| Tensors layer_output = null; | |||
| Tensors outputs = null; | |||
| foreach (var layer in _layers) | |||
| { | |||
| clear_previously_created_nodes(layer, _created_nodes); | |||
| layer_output = layer.Apply(layer_input); | |||
| // Keep track of nodes just created above | |||
| track_nodes_created_by_last_call(layer, _created_nodes); | |||
| layer_input = layer_output; | |||
| outputs = layer_output; | |||
| } | |||
| _init_graph_network(inputs, outputs); | |||
| _graph_initialized = true; | |||
| _inferred_input_shape = input_shape; | |||
| } | |||
| void clear_previously_created_nodes(ILayer layer, List<INode> created_nodes) | |||
| { | |||
| } | |||
| void track_nodes_created_by_last_call(ILayer layer, List<INode> created_nodes) | |||
| { | |||
| var node = layer.InboundNodes.Last(); | |||
| created_nodes.Add(node); | |||
| foreach(var prev_layer in node.iterate_inbound()) | |||
| { | |||
| created_nodes.add(prev_layer.Item1.OutboundNodes.Last()); | |||
| } | |||
| } | |||
| } | |||
| @@ -71,7 +71,7 @@ namespace Tensorflow.Keras.Layers | |||
| var rank = inputs.rank; | |||
| if (rank > 2) | |||
| { | |||
| throw new NotImplementedException("call rank > 2"); | |||
| outputs = tf.linalg.tensordot(inputs, kernel.AsTensor(), new[,] { { rank - 1 }, { 0 } }); | |||
| } | |||
| else | |||
| { | |||