Browse Source

Reduce the time of keras unittest.

pull/999/head
Yaohui Liu 2 years ago
parent
commit
f22876aca0
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
3 changed files with 10 additions and 10 deletions
  1. +0
    -0
      test/TensorFlowNET.Keras.UnitTest/GradientTest.cs
  2. +1
    -1
      test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs
  3. +9
    -9
      test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs

test/TensorFlowNET.Keras.UnitTest/Gradient.cs → test/TensorFlowNET.Keras.UnitTest/GradientTest.cs View File


+ 1
- 1
test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs View File

@@ -43,7 +43,7 @@ public class SequentialModelLoad
{
TrainDir = "mnist",
OneHot = false,
ValidationSize = 50000,
ValidationSize = 58000,
}).Result;

model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);


+ 9
- 9
test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs View File

@@ -18,15 +18,15 @@ public class SequentialModelSave
[TestMethod]
public void SimpleModelFromAutoCompile()
{
var inputs = tf.keras.layers.Input((28, 28, 1));
var x = tf.keras.layers.Flatten().Apply(inputs);
x = tf.keras.layers.Dense(100, activation: tf.nn.relu).Apply(x);
x = tf.keras.layers.Dense(units: 10).Apply(x);
var outputs = tf.keras.layers.Softmax(axis: 1).Apply(x);
var model = tf.keras.Model(inputs, outputs);
var inputs = keras.layers.Input((28, 28, 1));
var x = keras.layers.Flatten().Apply(inputs);
x = keras.layers.Dense(100, activation: tf.nn.relu).Apply(x);
x = keras.layers.Dense(units: 10).Apply(x);
var outputs = keras.layers.Softmax(axis: 1).Apply(x);
var model = keras.Model(inputs, outputs);

model.compile(new Adam(0.001f),
tf.keras.losses.SparseCategoricalCrossentropy(),
keras.losses.SparseCategoricalCrossentropy(),
new string[] { "accuracy" });

var data_loader = new MnistModelLoader();
@@ -37,7 +37,7 @@ public class SequentialModelSave
{
TrainDir = "mnist",
OneHot = false,
ValidationSize = 10000,
ValidationSize = 58000,
}).Result;

model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
@@ -69,7 +69,7 @@ public class SequentialModelSave
{
TrainDir = "mnist",
OneHot = false,
ValidationSize = 50000,
ValidationSize = 58000,
}).Result;

model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);


Loading…
Cancel
Save