Browse Source

add Keras model.summary().

tags/v0.30
Oceania2018 5 years ago
parent
commit
9f2adcfc77
5 changed files with 63 additions and 5 deletions
  1. +20
    -0
      src/TensorFlowNET.Core/Keras/Engine/Model.Summary.cs
  2. +18
    -0
      src/TensorFlowNET.Core/Keras/Engine/Node.IterateInbound.cs
  3. +6
    -0
      test/TensorFlowNET.UnitTest/EagerModeTestBase.cs
  4. +17
    -0
      test/TensorFlowNET.UnitTest/Keras/LayersTest.cs
  5. +2
    -5
      test/TensorFlowNET.UnitTest/PythonTest.cs

+ 20
- 0
src/TensorFlowNET.Core/Keras/Engine/Model.Summary.cs View File

@@ -0,0 +1,20 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Engine
{
public partial class Model
{
/// <summary>
/// Prints a string summary of the network.
/// </summary>
public void summary(int line_length = -1, float[] positions = null)
{
layer_utils.print_summary(this,
line_length: line_length,
positions: positions);
}
}
}

+ 18
- 0
src/TensorFlowNET.Core/Keras/Engine/Node.IterateInbound.cs View File

@@ -0,0 +1,18 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Engine
{
public partial class Node
{
public IEnumerable<(Layer, int, int, Tensor)> iterate_inbound()
{
foreach(var kt in KerasInputs)
{
var (layer, node_index, tensor_index) = kt.KerasHistory;
yield return (layer, node_index, tensor_index, kt);
}
}
}
}

+ 6
- 0
test/TensorFlowNET.UnitTest/EagerModeTestBase.cs View File

@@ -2,12 +2,18 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using Tensorflow;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace TensorFlowNET.UnitTest namespace TensorFlowNET.UnitTest
{ {
public class EagerModeTestBase : PythonTest public class EagerModeTestBase : PythonTest
{ {
protected KerasApi keras = tf.keras;
protected LayersApi layers = tf.keras.layers;

[TestInitialize] [TestInitialize]
public void TestInit() public void TestInit()
{ {


+ 17
- 0
test/TensorFlowNET.UnitTest/Keras/LayersTest.cs View File

@@ -16,6 +16,7 @@ namespace TensorFlowNET.UnitTest.Keras
[TestClass] [TestClass]
public class LayersTest : EagerModeTestBase public class LayersTest : EagerModeTestBase
{ {
[TestMethod] [TestMethod]
public void Sequential() public void Sequential()
{ {
@@ -23,6 +24,22 @@ namespace TensorFlowNET.UnitTest.Keras
model.add(tf.keras.Input(shape: 16)); model.add(tf.keras.Input(shape: 16));
} }


[TestMethod]
public void Functional()
{
var inputs = keras.Input(shape: 784);
Assert.AreEqual((None, 784), inputs.TensorShape);

var dense = layers.Dense(64, activation: "relu");
var x = dense.Apply(inputs);

x = layers.Dense(64, activation: "relu").Apply(x);
var outputs = layers.Dense(10).Apply(x);

var model = keras.Model(inputs, outputs, name: "mnist_model");
model.summary();
}

/// <summary> /// <summary>
/// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding /// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding
/// </summary> /// </summary>


+ 2
- 5
test/TensorFlowNET.UnitTest/PythonTest.cs View File

@@ -16,10 +16,7 @@ namespace TensorFlowNET.UnitTest
{ {
#region python compatibility layer #region python compatibility layer
protected PythonTest self { get => this; } protected PythonTest self { get => this; }
protected object None
{
get { return null; }
}
protected int None => -1;
#endregion #endregion


#region pytest assertions #region pytest assertions
@@ -150,7 +147,7 @@ namespace TensorFlowNET.UnitTest


protected object _eval_tensor(object tensor) protected object _eval_tensor(object tensor)
{ {
if (tensor == None)
if (tensor == null)
return None; return None;
//else if (callable(tensor)) //else if (callable(tensor))
// return self._eval_helper(tensor()) // return self._eval_helper(tensor())


Loading…
Cancel
Save