Browse Source

Merge pull request #259 from arnavdas88/master

Master
tags/v0.9
Haiping GitHub 6 years ago
parent
commit
bf894e1a18
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 20 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Layer.cs
  2. +19
    -4
      src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs
  3. +15
    -15
      test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs

+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Layer.cs View File

@@ -190,7 +190,7 @@ namespace Tensorflow.Keras.Layers
var variable = _add_variable_with_custom_getter(name, var variable = _add_variable_with_custom_getter(name,
shape, shape,
dtype: dtype, dtype: dtype,
getter: getter, // getter == null ? base_layer_utils.make_variable : getter,
getter: (getter == null) ? base_layer_utils.make_variable : getter,
overwrite: true, overwrite: true,
initializer: initializer, initializer: initializer,
trainable: trainable.Value); trainable: trainable.Value);


+ 19
- 4
src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs View File

@@ -8,6 +8,21 @@ namespace Tensorflow.Keras.Utils
{ {
public class base_layer_utils public class base_layer_utils
{ {
/// <summary>
/// Adds a new variable to the layer.
/// </summary>
/// <param name="name"></param>
/// <param name="shape"></param>
/// <param name="dtype"></param>
/// <param name="initializer"></param>
/// <param name="trainable"></param>
/// <returns></returns>
public static RefVariable make_variable(string name,
int[] shape,
TF_DataType dtype = TF_DataType.TF_FLOAT,
IInitializer initializer = null,
bool trainable = true) => make_variable(name, shape, dtype, initializer, trainable, true);

/// <summary> /// <summary>
/// Adds a new variable to the layer. /// Adds a new variable to the layer.
/// </summary> /// </summary>
@@ -28,7 +43,7 @@ namespace Tensorflow.Keras.Utils


ops.init_scope(); ops.init_scope();


Func<Tensor> init_val = ()=> initializer.call(new TensorShape(shape), dtype: dtype);
Func<Tensor> init_val = () => initializer.call(new TensorShape(shape), dtype: dtype);


var variable_dtype = dtype.as_base_dtype(); var variable_dtype = dtype.as_base_dtype();
var v = tf.Variable(init_val); var v = tf.Variable(init_val);
@@ -44,13 +59,13 @@ namespace Tensorflow.Keras.Utils
public static string unique_layer_name(string name, Dictionary<(string, string), int> name_uid_map = null, public static string unique_layer_name(string name, Dictionary<(string, string), int> name_uid_map = null,
string[] avoid_names = null, string @namespace = "", bool zero_based = false) string[] avoid_names = null, string @namespace = "", bool zero_based = false)
{ {
if(name_uid_map == null)
if (name_uid_map == null)
name_uid_map = get_default_graph_uid_map(); name_uid_map = get_default_graph_uid_map();
if (avoid_names == null) if (avoid_names == null)
avoid_names = new string[0]; avoid_names = new string[0];


string proposed_name = null; string proposed_name = null;
while(proposed_name == null || avoid_names.Contains(proposed_name))
while (proposed_name == null || avoid_names.Contains(proposed_name))
{ {
var name_key = (@namespace, name); var name_key = (@namespace, name);
if (!name_uid_map.ContainsKey(name_key)) if (!name_uid_map.ContainsKey(name_key))
@@ -58,7 +73,7 @@ namespace Tensorflow.Keras.Utils


if (zero_based) if (zero_based)
{ {
int number = name_uid_map[name_key];
int number = name_uid_map[name_key];
if (number > 0) if (number > 0)
proposed_name = $"{name}_{number}"; proposed_name = $"{name}_{number}";
else else


+ 15
- 15
test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs View File

@@ -14,21 +14,21 @@ namespace TensorFlowNET.ExamplesTests
public void BasicOperations() public void BasicOperations()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
new BasicOperations() { Enabled = true }.Train();
new BasicOperations() { Enabled = true }.Run();
} }
[TestMethod] [TestMethod]
public void HelloWorld() public void HelloWorld()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
new HelloWorld() { Enabled = true }.Train();
new HelloWorld() { Enabled = true }.Run();
} }
[TestMethod] [TestMethod]
public void ImageRecognition() public void ImageRecognition()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
new HelloWorld() { Enabled = true }.Train();
new HelloWorld() { Enabled = true }.Run();
} }
[Ignore] [Ignore]
@@ -36,28 +36,28 @@ namespace TensorFlowNET.ExamplesTests
public void InceptionArchGoogLeNet() public void InceptionArchGoogLeNet()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
new InceptionArchGoogLeNet() { Enabled = true }.Train();
new InceptionArchGoogLeNet() { Enabled = true }.Run();
} }
[TestMethod] [TestMethod]
public void KMeansClustering() public void KMeansClustering()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
new KMeansClustering() { Enabled = true, IsImportingGraph = true, train_size = 500, validation_size = 100, test_size = 100, batch_size =100 }.Train();
new KMeansClustering() { Enabled = true, IsImportingGraph = true, train_size = 500, validation_size = 100, test_size = 100, batch_size =100 }.Run();
} }
[TestMethod] [TestMethod]
public void LinearRegression() public void LinearRegression()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
new LinearRegression() { Enabled = true }.Train();
new LinearRegression() { Enabled = true }.Run();
} }
[TestMethod] [TestMethod]
public void LogisticRegression() public void LogisticRegression()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
new LogisticRegression() { Enabled = true, training_epochs=10, train_size = 500, validation_size = 100, test_size = 100 }.Train();
new LogisticRegression() { Enabled = true, training_epochs=10, train_size = 500, validation_size = 100, test_size = 100 }.Run();
} }
[Ignore] [Ignore]
@@ -65,7 +65,7 @@ namespace TensorFlowNET.ExamplesTests
public void NaiveBayesClassifier() public void NaiveBayesClassifier()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
new NaiveBayesClassifier() { Enabled = false }.Train();
new NaiveBayesClassifier() { Enabled = false }.Run();
} }
[Ignore] [Ignore]
@@ -73,14 +73,14 @@ namespace TensorFlowNET.ExamplesTests
public void NamedEntityRecognition() public void NamedEntityRecognition()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
new NamedEntityRecognition() { Enabled = true }.Train();
new NamedEntityRecognition() { Enabled = true }.Run();
} }
[TestMethod] [TestMethod]
public void NearestNeighbor() public void NearestNeighbor()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
new NearestNeighbor() { Enabled = true, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Train();
new NearestNeighbor() { Enabled = true, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Run();
} }
[Ignore] [Ignore]
@@ -88,7 +88,7 @@ namespace TensorFlowNET.ExamplesTests
public void TextClassificationTrain() public void TextClassificationTrain()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
new TextClassificationTrain() { Enabled = true, DataLimit=100 }.Train();
new TextClassificationTrain() { Enabled = true, DataLimit=100 }.Run();
} }
[Ignore] [Ignore]
@@ -96,21 +96,21 @@ namespace TensorFlowNET.ExamplesTests
public void TextClassificationWithMovieReviews() public void TextClassificationWithMovieReviews()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
new BinaryTextClassification() { Enabled = true }.Train();
new BinaryTextClassification() { Enabled = true }.Run();
} }
[TestMethod] [TestMethod]
public void NeuralNetXor() public void NeuralNetXor()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = false }.Train());
Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = false }.Run());
} }
[TestMethod] [TestMethod]
public void NeuralNetXor_ImportedGraph() public void NeuralNetXor_ImportedGraph()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = true }.Train());
Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = true }.Run());
} }
@@ -118,7 +118,7 @@ namespace TensorFlowNET.ExamplesTests
public void ObjectDetection() public void ObjectDetection()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
Assert.IsTrue(new ObjectDetection() { Enabled = true, IsImportingGraph = true }.Train());
Assert.IsTrue(new ObjectDetection() { Enabled = true, IsImportingGraph = true }.Run());
} }
} }
} }

Loading…
Cancel
Save