From 9f2adcfc77f7d164e137af47887067c934fdfc1f Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 24 Oct 2020 08:26:21 -0500 Subject: [PATCH] add Keras model.summary(). --- .../Keras/Engine/Model.Summary.cs | 20 +++++++++++++++++++ .../Keras/Engine/Node.IterateInbound.cs | 18 +++++++++++++++++ .../EagerModeTestBase.cs | 6 ++++++ .../Keras/LayersTest.cs | 17 ++++++++++++++++ test/TensorFlowNET.UnitTest/PythonTest.cs | 7 ++----- 5 files changed, 63 insertions(+), 5 deletions(-) create mode 100644 src/TensorFlowNET.Core/Keras/Engine/Model.Summary.cs create mode 100644 src/TensorFlowNET.Core/Keras/Engine/Node.IterateInbound.cs diff --git a/src/TensorFlowNET.Core/Keras/Engine/Model.Summary.cs b/src/TensorFlowNET.Core/Keras/Engine/Model.Summary.cs new file mode 100644 index 00000000..97dde0aa --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Engine/Model.Summary.cs @@ -0,0 +1,20 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.Utils; + +namespace Tensorflow.Keras.Engine +{ + public partial class Model + { + /// + /// Prints a string summary of the network. + /// + public void summary(int line_length = -1, float[] positions = null) + { + layer_utils.print_summary(this, + line_length: line_length, + positions: positions); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/Node.IterateInbound.cs b/src/TensorFlowNET.Core/Keras/Engine/Node.IterateInbound.cs new file mode 100644 index 00000000..6a2ddf22 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Engine/Node.IterateInbound.cs @@ -0,0 +1,18 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Engine +{ + public partial class Node + { + public IEnumerable<(Layer, int, int, Tensor)> iterate_inbound() + { + foreach(var kt in KerasInputs) + { + var (layer, node_index, tensor_index) = kt.KerasHistory; + yield return (layer, node_index, tensor_index, kt); + } + } + } +} diff --git a/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs b/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs index 9ca58ab0..4e837aa3 100644 --- a/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs +++ b/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs @@ -2,12 +2,18 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers; using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest { public class EagerModeTestBase : PythonTest { + protected KerasApi keras = tf.keras; + protected LayersApi layers = tf.keras.layers; + [TestInitialize] public void TestInit() { diff --git a/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs b/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs index 0b21694f..a37f0fd9 100644 --- a/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs +++ b/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs @@ -16,6 +16,7 @@ namespace TensorFlowNET.UnitTest.Keras [TestClass] public class LayersTest : EagerModeTestBase { + [TestMethod] public void Sequential() { @@ -23,6 +24,22 @@ namespace TensorFlowNET.UnitTest.Keras model.add(tf.keras.Input(shape: 16)); } + [TestMethod] + public void Functional() + { + var inputs = keras.Input(shape: 784); + Assert.AreEqual((None, 784), inputs.TensorShape); + + var dense = layers.Dense(64, activation: "relu"); + var x = dense.Apply(inputs); + + x = layers.Dense(64, activation: "relu").Apply(x); + var outputs = layers.Dense(10).Apply(x); + + var model = keras.Model(inputs, outputs, name: "mnist_model"); + model.summary(); + } + /// /// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding /// diff --git a/test/TensorFlowNET.UnitTest/PythonTest.cs b/test/TensorFlowNET.UnitTest/PythonTest.cs index cf908fa2..9ce995b4 100644 --- a/test/TensorFlowNET.UnitTest/PythonTest.cs +++ b/test/TensorFlowNET.UnitTest/PythonTest.cs @@ -16,10 +16,7 @@ namespace TensorFlowNET.UnitTest { #region python compatibility layer protected PythonTest self { get => this; } - protected object None - { - get { return null; } - } + protected int None => -1; #endregion #region pytest assertions @@ -150,7 +147,7 @@ namespace TensorFlowNET.UnitTest protected object _eval_tensor(object tensor) { - if (tensor == None) + if (tensor == null) return None; //else if (callable(tensor)) // return self._eval_helper(tensor())