| @@ -78,7 +78,7 @@ namespace Tensorflow.Contexts | |||||
| if (args.GetGradientAttrs == null) | if (args.GetGradientAttrs == null) | ||||
| { | { | ||||
| attrs = new Dictionary<string, object>(); | attrs = new Dictionary<string, object>(); | ||||
| attrs["T"] = op.get_attr<TF_DataType>("T"); | |||||
| attrs["T"] = op.dtype; | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| @@ -11,6 +11,7 @@ namespace Tensorflow.Keras.Engine | |||||
| ILayer Layer { get; } | ILayer Layer { get; } | ||||
| List<Tensor> KerasInputs { get; set; } | List<Tensor> KerasInputs { get; set; } | ||||
| INode[] ParentNodes { get; } | INode[] ParentNodes { get; } | ||||
| ILayer[] InboundLayers { get; } | |||||
| IEnumerable<(ILayer, int, int, Tensor)> iterate_inbound(); | IEnumerable<(ILayer, int, int, Tensor)> iterate_inbound(); | ||||
| bool is_input { get; } | bool is_input { get; } | ||||
| List<NodeConfig> serialize(Func<string, int, string> make_node_key, Dictionary<string, int> node_conversion_map); | List<NodeConfig> serialize(Func<string, int, string> make_node_key, Dictionary<string, int> node_conversion_map); | ||||
| @@ -88,6 +88,13 @@ namespace Tensorflow.NumPy | |||||
| public static bool Equals(Shape shape, object target) | public static bool Equals(Shape shape, object target) | ||||
| { | { | ||||
| if (shape is null && target is null) | |||||
| return true; | |||||
| else if (shape is null && target is not null) | |||||
| return false; | |||||
| else if (shape is not null && target is null) | |||||
| return false; | |||||
| switch (target) | switch (target) | ||||
| { | { | ||||
| case Shape shape1: | case Shape shape1: | ||||
| @@ -253,5 +253,11 @@ namespace Tensorflow | |||||
| public override bool Equals(object obj) => ShapeHelper.Equals(this, obj); | public override bool Equals(object obj) => ShapeHelper.Equals(this, obj); | ||||
| public override string ToString() => ShapeHelper.ToString(this); | public override string ToString() => ShapeHelper.ToString(this); | ||||
| public static bool operator ==(Shape a, Shape b) | |||||
| => ShapeHelper.Equals(a, b); | |||||
| public static bool operator !=(Shape a, Shape b) | |||||
| => !ShapeHelper.Equals(a, b); | |||||
| } | } | ||||
| } | } | ||||
| @@ -55,7 +55,7 @@ namespace Tensorflow | |||||
| public int _id_value { get; set; } | public int _id_value { get; set; } | ||||
| public Operation op => this; | public Operation op => this; | ||||
| public TF_DataType dtype => TF_DataType.DtInvalid; | |||||
| public TF_DataType dtype => output.dtype; | |||||
| public virtual string name => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationName(_handle)); | public virtual string name => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationName(_handle)); | ||||
| public string OpType => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | public string OpType => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | ||||
| @@ -69,7 +69,8 @@ namespace Tensorflow.Keras.Engine | |||||
| NetworkNodes = nodes; | NetworkNodes = nodes; | ||||
| NodesByDepth = nodes_by_depth; | NodesByDepth = nodes_by_depth; | ||||
| _layers = layers; | |||||
| if (_layers.Count == 0) | |||||
| _layers = layers; | |||||
| // Build self.input_names and self.output_names. | // Build self.input_names and self.output_names. | ||||
| _set_output_names(); | _set_output_names(); | ||||
| @@ -1,9 +1,13 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | |||||
| namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
| { | { | ||||
| public partial class Node | public partial class Node | ||||
| { | { | ||||
| public ILayer[] InboundLayers | |||||
| => iterate_inbound().Select(x => x.Item1).ToArray(); | |||||
| public IEnumerable<(ILayer, int, int, Tensor)> iterate_inbound() | public IEnumerable<(ILayer, int, int, Tensor)> iterate_inbound() | ||||
| { | { | ||||
| foreach (var kt in KerasInputs) | foreach (var kt in KerasInputs) | ||||
| @@ -150,6 +150,9 @@ namespace Tensorflow.Keras.Engine | |||||
| void _build_graph_network_for_inferred_shape(Shape input_shape, TF_DataType input_dtype) | void _build_graph_network_for_inferred_shape(Shape input_shape, TF_DataType input_dtype) | ||||
| { | { | ||||
| if (_inferred_input_shape == input_shape) | |||||
| return; | |||||
| ops.init_scope(); | ops.init_scope(); | ||||
| var inputs = keras.Input(batch_input_shape: input_shape, | var inputs = keras.Input(batch_input_shape: input_shape, | ||||
| dtype: input_dtype, | dtype: input_dtype, | ||||
| @@ -157,16 +160,17 @@ namespace Tensorflow.Keras.Engine | |||||
| Tensors layer_input = inputs; | Tensors layer_input = inputs; | ||||
| Tensors layer_output = null; | Tensors layer_output = null; | ||||
| Tensors outputs = null; | Tensors outputs = null; | ||||
| List<INode> created_nodes = new List<INode>(); | |||||
| foreach (var layer in _layers) | foreach (var layer in _layers) | ||||
| { | { | ||||
| clear_previously_created_nodes(layer, _created_nodes); | clear_previously_created_nodes(layer, _created_nodes); | ||||
| layer_output = layer.Apply(layer_input); | layer_output = layer.Apply(layer_input); | ||||
| // Keep track of nodes just created above | // Keep track of nodes just created above | ||||
| track_nodes_created_by_last_call(layer, _created_nodes); | |||||
| track_nodes_created_by_last_call(layer, created_nodes); | |||||
| layer_input = layer_output; | layer_input = layer_output; | ||||
| outputs = layer_output; | outputs = layer_output; | ||||
| } | } | ||||
| _created_nodes = created_nodes; | |||||
| _init_graph_network(inputs, outputs); | _init_graph_network(inputs, outputs); | ||||
| _graph_initialized = true; | _graph_initialized = true; | ||||
| _inferred_input_shape = input_shape; | _inferred_input_shape = input_shape; | ||||
| @@ -174,16 +178,28 @@ namespace Tensorflow.Keras.Engine | |||||
| void clear_previously_created_nodes(ILayer layer, List<INode> created_nodes) | void clear_previously_created_nodes(ILayer layer, List<INode> created_nodes) | ||||
| { | { | ||||
| foreach(var node in layer.InboundNodes) | |||||
| { | |||||
| foreach(var prev_layer in node.InboundLayers) | |||||
| { | |||||
| var outNodes = prev_layer.OutboundNodes.Where(x => !created_nodes.Contains(x)).ToArray(); | |||||
| prev_layer.OutboundNodes.Clear(); | |||||
| prev_layer.OutboundNodes.AddRange(outNodes); | |||||
| } | |||||
| } | |||||
| var inNodes = layer.InboundNodes.Where(x => !created_nodes.Contains(x)).ToArray(); | |||||
| layer.InboundNodes.Clear(); | |||||
| layer.InboundNodes.AddRange(inNodes); | |||||
| } | } | ||||
| void track_nodes_created_by_last_call(ILayer layer, List<INode> created_nodes) | void track_nodes_created_by_last_call(ILayer layer, List<INode> created_nodes) | ||||
| { | { | ||||
| var node = layer.InboundNodes.Last(); | var node = layer.InboundNodes.Last(); | ||||
| created_nodes.Add(node); | created_nodes.Add(node); | ||||
| foreach(var prev_layer in node.iterate_inbound()) | |||||
| foreach(var prev_layer in node.InboundLayers) | |||||
| { | { | ||||
| created_nodes.add(prev_layer.Item1.OutboundNodes.Last()); | |||||
| created_nodes.add(prev_layer.OutboundNodes.Last()); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||