| @@ -0,0 +1,20 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Keras.Utils; | |||||
| namespace Tensorflow.Keras.Engine | |||||
| { | |||||
| public partial class Model | |||||
| { | |||||
| /// <summary> | |||||
| /// Prints a string summary of the network. | |||||
| /// </summary> | |||||
| public void summary(int line_length = -1, float[] positions = null) | |||||
| { | |||||
| layer_utils.print_summary(this, | |||||
| line_length: line_length, | |||||
| positions: positions); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,18 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.Engine | |||||
| { | |||||
| public partial class Node | |||||
| { | |||||
| public IEnumerable<(Layer, int, int, Tensor)> iterate_inbound() | |||||
| { | |||||
| foreach(var kt in KerasInputs) | |||||
| { | |||||
| var (layer, node_index, tensor_index) = kt.KerasHistory; | |||||
| yield return (layer, node_index, tensor_index, kt); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -2,12 +2,18 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow; | |||||
| using Tensorflow.Keras.Engine; | |||||
| using Tensorflow.Keras.Layers; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest | namespace TensorFlowNET.UnitTest | ||||
| { | { | ||||
| public class EagerModeTestBase : PythonTest | public class EagerModeTestBase : PythonTest | ||||
| { | { | ||||
| protected KerasApi keras = tf.keras; | |||||
| protected LayersApi layers = tf.keras.layers; | |||||
| [TestInitialize] | [TestInitialize] | ||||
| public void TestInit() | public void TestInit() | ||||
| { | { | ||||
| @@ -16,6 +16,7 @@ namespace TensorFlowNET.UnitTest.Keras | |||||
| [TestClass] | [TestClass] | ||||
| public class LayersTest : EagerModeTestBase | public class LayersTest : EagerModeTestBase | ||||
| { | { | ||||
| [TestMethod] | [TestMethod] | ||||
| public void Sequential() | public void Sequential() | ||||
| { | { | ||||
| @@ -23,6 +24,22 @@ namespace TensorFlowNET.UnitTest.Keras | |||||
| model.add(tf.keras.Input(shape: 16)); | model.add(tf.keras.Input(shape: 16)); | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void Functional() | |||||
| { | |||||
| var inputs = keras.Input(shape: 784); | |||||
| Assert.AreEqual((None, 784), inputs.TensorShape); | |||||
| var dense = layers.Dense(64, activation: "relu"); | |||||
| var x = dense.Apply(inputs); | |||||
| x = layers.Dense(64, activation: "relu").Apply(x); | |||||
| var outputs = layers.Dense(10).Apply(x); | |||||
| var model = keras.Model(inputs, outputs, name: "mnist_model"); | |||||
| model.summary(); | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding | /// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding | ||||
| /// </summary> | /// </summary> | ||||
| @@ -16,10 +16,7 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| #region python compatibility layer | #region python compatibility layer | ||||
| protected PythonTest self { get => this; } | protected PythonTest self { get => this; } | ||||
| protected object None | |||||
| { | |||||
| get { return null; } | |||||
| } | |||||
| protected int None => -1; | |||||
| #endregion | #endregion | ||||
| #region pytest assertions | #region pytest assertions | ||||
| @@ -150,7 +147,7 @@ namespace TensorFlowNET.UnitTest | |||||
| protected object _eval_tensor(object tensor) | protected object _eval_tensor(object tensor) | ||||
| { | { | ||||
| if (tensor == None) | |||||
| if (tensor == null) | |||||
| return None; | return None; | ||||
| //else if (callable(tensor)) | //else if (callable(tensor)) | ||||
| // return self._eval_helper(tensor()) | // return self._eval_helper(tensor()) | ||||