diff --git a/README.md b/README.md
index 42cd3ec4..7a861d04 100644
--- a/README.md
+++ b/README.md
@@ -66,13 +66,14 @@ using(var sess = tf.Session())
Read the docs & book [The Definitive Guide to Tensorflow.NET](https://tensorflownet.readthedocs.io/en/latest/FrontCover.html).
-More examples:
+### More examples:
* [Hello World](test/TensorFlowNET.Examples/HelloWorld.cs)
* [Basic Operations](test/TensorFlowNET.Examples/BasicOperations.cs)
* [Image Recognition](test/TensorFlowNET.Examples/ImageRecognition.cs)
* [Linear Regression](test/TensorFlowNET.Examples/LinearRegression.cs)
* [Text Classification](test/TensorFlowNET.Examples/TextClassificationWithMovieReviews.cs)
+* [CNN Text Classification](test/TensorFlowNET.Examples/CnnTextClassification.cs)
* [Naive Bayes Classification](test/TensorFlowNET.Examples/NaiveBayesClassifier.cs)
* [Named Entity Recognition](test/TensorFlowNET.Examples/NamedEntityRecognition.cs)
diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln
index 24afc77e..25a97e6d 100644
--- a/TensorFlow.NET.sln
+++ b/TensorFlow.NET.sln
@@ -9,11 +9,11 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "t
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\TensorFlowNET.Core\TensorFlowNET.Core.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}"
EndProject
-Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Utility", "src\TensorFlowNET.Utility\TensorFlowNET.Utility.csproj", "{00D9085C-0FC7-453C-A0CC-BAD98F44FEA0}"
-EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Visualization", "TensorFlowNET.Visualization\TensorFlowNET.Visualization.csproj", "{4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}"
EndProject
-Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{268BF0B6-0AA9-4FD3-A245-7AF336F1E3E9}"
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{E8340C61-12C1-4BEE-A340-403E7C1ACD82}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "scikit-learn", "..\scikit-learn.net\src\scikit-learn\scikit-learn.csproj", "{199DDAD8-4A6F-43B3-A560-C0393619E304}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
@@ -33,18 +33,18 @@ Global
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.Build.0 = Debug|Any CPU
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.ActiveCfg = Release|Any CPU
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.Build.0 = Release|Any CPU
- {00D9085C-0FC7-453C-A0CC-BAD98F44FEA0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
- {00D9085C-0FC7-453C-A0CC-BAD98F44FEA0}.Debug|Any CPU.Build.0 = Debug|Any CPU
- {00D9085C-0FC7-453C-A0CC-BAD98F44FEA0}.Release|Any CPU.ActiveCfg = Release|Any CPU
- {00D9085C-0FC7-453C-A0CC-BAD98F44FEA0}.Release|Any CPU.Build.0 = Release|Any CPU
{4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}.Debug|Any CPU.Build.0 = Debug|Any CPU
{4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}.Release|Any CPU.ActiveCfg = Release|Any CPU
{4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}.Release|Any CPU.Build.0 = Release|Any CPU
- {268BF0B6-0AA9-4FD3-A245-7AF336F1E3E9}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
- {268BF0B6-0AA9-4FD3-A245-7AF336F1E3E9}.Debug|Any CPU.Build.0 = Debug|Any CPU
- {268BF0B6-0AA9-4FD3-A245-7AF336F1E3E9}.Release|Any CPU.ActiveCfg = Release|Any CPU
- {268BF0B6-0AA9-4FD3-A245-7AF336F1E3E9}.Release|Any CPU.Build.0 = Release|Any CPU
+ {E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Release|Any CPU.Build.0 = Release|Any CPU
+ {199DDAD8-4A6F-43B3-A560-C0393619E304}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {199DDAD8-4A6F-43B3-A560-C0393619E304}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {199DDAD8-4A6F-43B3-A560-C0393619E304}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {199DDAD8-4A6F-43B3-A560-C0393619E304}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
diff --git a/src/TensorFlowNET.Core/APIs/tf.init.cs b/src/TensorFlowNET.Core/APIs/tf.init.cs
index 584f3b01..19876d62 100644
--- a/src/TensorFlowNET.Core/APIs/tf.init.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.init.cs
@@ -7,7 +7,24 @@ namespace Tensorflow
public static partial class tf
{
public static IInitializer zeros_initializer => new Zeros();
+ public static IInitializer glorot_uniform_initializer => new GlorotUniform();
+ public static variable_scope variable_scope(string name,
+ string default_name = null,
+ object values = null,
+ bool auxiliary_name_scope = true) => new variable_scope(name,
+ default_name,
+ values,
+ auxiliary_name_scope);
+
+ public static variable_scope variable_scope(VariableScope scope,
+ string default_name = null,
+ object values = null,
+ bool auxiliary_name_scope = true) => new variable_scope(scope,
+ default_name,
+ values,
+ auxiliary_name_scope);
+
public class Zeros : IInitializer
{
private TF_DataType dtype;
@@ -30,5 +47,105 @@ namespace Tensorflow
return new { dtype = dtype.name() };
}
}
+
+ ///
+ /// Initializer capable of adapting its scale to the shape of weights tensors.
+ ///
+ public class VarianceScaling : IInitializer
+ {
+ protected float _scale;
+ protected string _mode;
+ protected string _distribution;
+ protected int? _seed;
+ protected TF_DataType _dtype;
+
+ public VarianceScaling(float scale = 1.0f,
+ string mode = "fan_in",
+ string distribution= "truncated_normal",
+ int? seed = null,
+ TF_DataType dtype = TF_DataType.TF_FLOAT)
+ {
+ if (scale < 0)
+ throw new ValueError("`scale` must be positive float.");
+ _scale = scale;
+ _mode = mode;
+ _distribution = distribution;
+ _seed = seed;
+ _dtype = dtype;
+ }
+
+ public Tensor call(TensorShape shape, TF_DataType dtype)
+ {
+ var (fan_in, fan_out) = _compute_fans(shape);
+ if (_mode == "fan_in")
+ _scale /= Math.Max(1, fan_in);
+ else if (_mode == "fan_out")
+ _scale /= Math.Max(1, fan_out);
+ else
+ _scale /= Math.Max(1, (fan_in + fan_out) / 2);
+
+ if (_distribution == "normal" || _distribution == "truncated_normal")
+ {
+ throw new NotImplementedException("truncated_normal");
+ }
+ else if(_distribution == "untruncated_normal")
+ {
+ throw new NotImplementedException("truncated_normal");
+ }
+ else
+ {
+ var limit = Math.Sqrt(3.0f * _scale);
+ return random_ops.random_uniform(shape, (float)-limit, (float)limit, dtype, seed: _seed);
+ }
+ }
+
+ private (int, int) _compute_fans(int[] shape)
+ {
+ if (shape.Length < 1)
+ return (1, 1);
+ if (shape.Length == 1)
+ return (shape[0], shape[0]);
+ if (shape.Length == 2)
+ return (shape[0], shape[1]);
+ else
+ throw new NotImplementedException("VarianceScaling._compute_fans");
+ }
+
+ public virtual object get_config()
+ {
+ return new
+ {
+ scale = _scale,
+ mode = _mode,
+ distribution = _distribution,
+ seed = _seed,
+ dtype = _dtype
+ };
+ }
+ }
+
+ public class GlorotUniform : VarianceScaling
+ {
+ public GlorotUniform(float scale = 1.0f,
+ string mode = "fan_avg",
+ string distribution = "uniform",
+ int? seed = null,
+ TF_DataType dtype = TF_DataType.TF_FLOAT) : base(scale, mode, distribution, seed, dtype)
+ {
+
+ }
+
+ public object get_config()
+ {
+ return new
+ {
+ scale = _scale,
+ mode = _mode,
+ distribution = _distribution,
+ seed = _seed,
+ dtype = _dtype
+ };
+ }
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Contrib/Learn/Preprocessing/VocabularyProcessor.cs b/src/TensorFlowNET.Core/Contrib/Learn/Preprocessing/VocabularyProcessor.cs
new file mode 100644
index 00000000..00f2c9b0
--- /dev/null
+++ b/src/TensorFlowNET.Core/Contrib/Learn/Preprocessing/VocabularyProcessor.cs
@@ -0,0 +1,19 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Contrib.Learn.Preprocessing
+{
+ public class VocabularyProcessor
+ {
+ private int _max_document_length;
+ private int _min_frequency;
+
+ public VocabularyProcessor(int max_document_length,
+ int min_frequency)
+ {
+ _max_document_length = max_document_length;
+ _min_frequency = min_frequency;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Framework/importer.py.cs b/src/TensorFlowNET.Core/Framework/importer.py.cs
index a0a45ca0..9070d27b 100644
--- a/src/TensorFlowNET.Core/Framework/importer.py.cs
+++ b/src/TensorFlowNET.Core/Framework/importer.py.cs
@@ -7,7 +7,7 @@ using static Tensorflow.OpDef.Types;
namespace Tensorflow
{
- public class importer
+ public class importer : Python
{
public static ITensorOrOperation[] import_graph_def(GraphDef graph_def,
Dictionary input_map = null,
@@ -26,7 +26,7 @@ namespace Tensorflow
string prefix = "";
var graph = ops.get_default_graph();
- Python.with(new ops.name_scope(name, "import", input_map.Values), scope =>
+ with(new ops.name_scope(name, "import", input_map.Values), scope =>
{
prefix = scope;
/*if (!string.IsNullOrEmpty(prefix))
diff --git a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs
index e4993b5e..a4af9caa 100644
--- a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs
+++ b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs
@@ -7,7 +7,7 @@ using System.Threading;
namespace Tensorflow
{
- public class gradients_impl
+ public class gradients_impl : Python
{
public static Tensor[] gradients(Tensor[] ys,
Tensor[] xs,
@@ -58,7 +58,7 @@ namespace Tensorflow
**/
var grads = new Dictionary();
- Python.with(new ops.name_scope(name, "gradients", values: all), scope =>
+ with(new ops.name_scope(name, "gradients", values: all), scope =>
{
string grad_scope = scope;
// Get a uid for this call to gradients that can be used to help
@@ -131,7 +131,7 @@ namespace Tensorflow
// for ops that do not have gradients.
var grad_fn = ops.get_gradient_function(op);
- Python.with(new ops.name_scope(op.name + "_grad"), scope1 =>
+ with(new ops.name_scope(op.name + "_grad"), scope1 =>
{
string name1 = scope1;
if (grad_fn != null)
diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs
index a025bfc6..5caaf6b0 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.cs
@@ -196,11 +196,11 @@ namespace Tensorflow
_create_op_helper(op, true);
- Console.Write($"create_op: {op_type} '{node_def.Name}'");
+ /*Console.Write($"create_op: {op_type} '{node_def.Name}'");
Console.Write($", inputs: {(inputs.Length == 0 ? "empty" : String.Join(", ", inputs.Select(x => x.name)))}");
Console.Write($", control_inputs: {(control_inputs.Length == 0 ? "empty" : String.Join(", ", control_inputs.Select(x => x.name)))}");
Console.Write($", outputs: {(op.outputs.Length == 0 ? "empty" : String.Join(", ", op.outputs.Select(x => x.name)))}");
- Console.WriteLine();
+ Console.WriteLine();*/
return op;
}
diff --git a/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs b/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs
index e45074b0..02d7f180 100644
--- a/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs
+++ b/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs
@@ -12,7 +12,7 @@ namespace Tensorflow
string scope = "",
string loss_collection= "losses")
{
- with(new ops.name_scope(scope,
+ with(new ops.name_scope(scope,
"sparse_softmax_cross_entropy_loss",
(logits, labels, weights)),
namescope =>
diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
index 0c19dea0..8e6bfba0 100644
--- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
+++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
@@ -10,7 +10,7 @@ using static Tensorflow.OpDef.Types;
namespace Tensorflow
{
- public class OpDefLibrary
+ public class OpDefLibrary : Python
{
public Operation _apply_op_helper(string op_type_name, string name = null, dynamic args = null)
{
@@ -44,7 +44,7 @@ namespace Tensorflow
var input_types = new List();
dynamic values = null;
- return Python.with(new ops.name_scope(name), scope =>
+ return with(new ops.name_scope(name), scope =>
{
var inferred_from = new Dictionary();
var base_types = new List();
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
index 66574dae..18de83dd 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
@@ -5,14 +5,14 @@ using System.Text;
namespace Tensorflow
{
- public class array_ops
+ public class array_ops : Python
{
public static Tensor placeholder_with_default(T input, int[] shape, string name = null) => gen_array_ops.placeholder_with_default(input, shape, name);
public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
{
dtype = dtype.as_base_dtype();
- return Python.with(new ops.name_scope(name, "zeros", shape), scope =>
+ return with(new ops.name_scope(name, "zeros", shape), scope =>
{
name = scope;
switch (dtype)
@@ -68,7 +68,7 @@ namespace Tensorflow
private static Tensor ones_like_impl(T tensor, TF_DataType dtype, string name, bool optimize = true)
{
- return Python.with(new ops.name_scope(name, "ones_like", new { tensor }), scope =>
+ return with(new ops.name_scope(name, "ones_like", new { tensor }), scope =>
{
name = scope;
var tensor1 = ops.convert_to_tensor(tensor, name: "tensor");
@@ -84,7 +84,7 @@ namespace Tensorflow
public static Tensor ones(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
{
dtype = dtype.as_base_dtype();
- return Python.with(new ops.name_scope(name, "ones", new { shape }), scope =>
+ return with(new ops.name_scope(name, "ones", new { shape }), scope =>
{
name = scope;
var output = gen_array_ops.fill(shape, constant_op.constant(1.0f, dtype: dtype), name: name);
@@ -130,7 +130,7 @@ namespace Tensorflow
private static Tensor shape_internal(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32)
{
- return Python.with(new ops.name_scope(name, "Shape", new { input }), scope =>
+ return with(new ops.name_scope(name, "Shape", new { input }), scope =>
{
name = scope;
@@ -151,7 +151,7 @@ namespace Tensorflow
private static Tensor size_internal(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32)
{
- return Python.with(new ops.name_scope(name, "Size", new Tensor[] { input }), scope =>
+ return with(new ops.name_scope(name, "Size", new Tensor[] { input }), scope =>
{
name = scope;
@@ -182,7 +182,7 @@ namespace Tensorflow
public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
{
- return Python.with(new ops.name_scope(name, "zeros_like", new Tensor[] { tensor }), scope =>
+ return with(new ops.name_scope(name, "zeros_like", new Tensor[] { tensor }), scope =>
{
name = scope;
tensor = ops.convert_to_tensor(tensor, name: "tensor");
diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
index c9220868..303104b5 100644
--- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
@@ -9,7 +9,7 @@ namespace Tensorflow
{
public static Operation group(T[] inputs, string name = null) where T : ITensorOrOperation
{
- return with(new ops.name_scope(name, "group_deps", inputs), scope =>
+ return with(new ops.name_scope(name, "group_deps", inputs), scope =>
{
name = scope;
@@ -39,7 +39,7 @@ namespace Tensorflow
private static Operation _GroupControlDeps(string dev, Operation[] deps, string name = null)
{
- return Python.with<_ControlDependenciesController, Operation>(ops.control_dependencies(deps), ctl =>
+ return with(ops.control_dependencies(deps), ctl =>
{
if (dev == null)
{
@@ -83,7 +83,7 @@ namespace Tensorflow
public static Tensor[] tuple(Tensor[] tensors, string name = null, Operation[] control_inputs = null)
{
- return Python.with(new ops.name_scope(name, "tuple", tensors), scope =>
+ return with(new ops.name_scope(name, "tuple", tensors), scope =>
{
name = scope;
var gating_ops = tensors.Select(x => x.op).ToList();
@@ -115,11 +115,11 @@ namespace Tensorflow
values.AddRange(dependencies);
values.Add(output_tensor);
- return Python.with(new ops.name_scope(name, "control_dependency", values), scope =>
+ return with(new ops.name_scope(name, "control_dependency", values), scope =>
{
name = scope;
- return Python.with<_ControlDependenciesController, Tensor>(ops.control_dependencies(dependencies), ctl =>
+ return with(ops.control_dependencies(dependencies), ctl =>
{
output_tensor = ops.convert_to_tensor_or_composite(output_tensor);
return _Identity(output_tensor, name: name);
diff --git a/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs
index 8cf40fd1..354177f9 100644
--- a/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs
@@ -24,10 +24,34 @@ namespace Tensorflow
if (!seed2.HasValue)
seed2 = 0;
- var _op = _op_def_lib._apply_op_helper("RandomStandardNormal", name: name,
+ var _op = _op_def_lib._apply_op_helper("RandomStandardNormal",
+ name: name,
args: new { shape, dtype, seed, seed2 });
return _op.outputs[0];
}
+
+ ///
+ /// Outputs random values from a uniform distribution.
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static Tensor random_uniform(Tensor shape, TF_DataType dtype, int? seed = 0, int? seed2 = 0, string name = null)
+ {
+ if (!seed.HasValue)
+ seed = 0;
+ if (!seed2.HasValue)
+ seed2 = 0;
+
+ var _op = _op_def_lib._apply_op_helper("RandomUniform",
+ name: name,
+ args: new { shape, dtype, seed, seed2});
+
+ return _op.outputs[0];
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Operations/math_ops.py.cs b/src/TensorFlowNET.Core/Operations/math_ops.py.cs
index b0015983..fb9d097a 100644
--- a/src/TensorFlowNET.Core/Operations/math_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/math_ops.py.cs
@@ -14,7 +14,7 @@ namespace Tensorflow
if(base_type == x.dtype)
return x;
- return with(new ops.name_scope(name, "Cast", new { x }), scope =>
+ return with(new ops.name_scope(name, "Cast", new { x }), scope =>
{
x = ops.convert_to_tensor(x, name: "x");
if (x.dtype.as_base_dtype() != base_type)
@@ -165,7 +165,7 @@ namespace Tensorflow
if (delta == null)
delta = 1;
- return with(new ops.name_scope(name, "Range", new object[] { start, limit, delta }), scope =>
+ return with(new ops.name_scope(name, "Range", new object[] { start, limit, delta }), scope =>
{
name = scope;
var start1 = ops.convert_to_tensor(start, name: "start");
@@ -178,7 +178,7 @@ namespace Tensorflow
public static Tensor floordiv(Tensor x, Tensor y, string name = null)
{
- return with(new ops.name_scope(name, "floordiv", new { x, y }), scope =>
+ return with(new ops.name_scope(name, "floordiv", new { x, y }), scope =>
{
return gen_math_ops.floor_div(x, y, scope);
});
@@ -186,7 +186,7 @@ namespace Tensorflow
public static Tensor rank_internal(Tensor input, string name = null, bool optimize = true)
{
- return with(new ops.name_scope(name, "Rank", new List { input }), scope =>
+ return with(new ops.name_scope(name, "Rank", new List { input }), scope =>
{
name = scope;
var input_tensor = ops.convert_to_tensor(input);
@@ -206,7 +206,7 @@ namespace Tensorflow
{
Tensor result = null;
- Python.with(new ops.name_scope(name, "MatMul", new Tensor[] { a, b }), scope =>
+ with(new ops.name_scope(name, "MatMul", new Tensor[] { a, b }), scope =>
{
name = scope;
@@ -236,7 +236,7 @@ namespace Tensorflow
if (dt.is_floating() || dt.is_integer())
return x;
- return Python.with(new ops.name_scope(name, "Conj", new List { x }), scope =>
+ return with(new ops.name_scope(name, "Conj", new List { x }), scope =>
{
return x;
diff --git a/src/TensorFlowNET.Core/Operations/random_ops.py.cs b/src/TensorFlowNET.Core/Operations/random_ops.py.cs
index 299ea3f7..c6ea3aa9 100644
--- a/src/TensorFlowNET.Core/Operations/random_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/random_ops.py.cs
@@ -4,8 +4,18 @@ using System.Text;
namespace Tensorflow
{
- public class random_ops
+ public class random_ops : Python
{
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
public static Tensor random_normal(int[] shape,
float mean = 0.0f,
float stddev = 1.0f,
@@ -13,7 +23,7 @@ namespace Tensorflow
int? seed = null,
string name = null)
{
- return Python.with(new ops.name_scope(name, "random_normal", new object[] { shape, mean, stddev }), scope =>
+ return with(new ops.name_scope(name, "random_normal", new { shape, mean, stddev }), scope =>
{
var shape_tensor = _ShapeTensor(shape);
var mean_tensor = ops.convert_to_tensor(mean, dtype: dtype, name: "mean");
@@ -26,6 +36,34 @@ namespace Tensorflow
});
}
+ ///
+ /// Outputs random values from a uniform distribution.
+ ///
+ ///
+ ///
+ ///
+ /// The type of the output
+ /// Used to create a random seed for the distribution.
+ /// A name for the operation
+ /// A tensor of the specified shape filled with random uniform values.
+ public static Tensor random_uniform(int[] shape,
+ float minval = 0,
+ float? maxval = null,
+ TF_DataType dtype = TF_DataType.TF_FLOAT,
+ int? seed = null,
+ string name = null)
+ {
+ return with(new ops.name_scope(name, "random_uniform", new { shape, minval, maxval }), scope =>
+ {
+ name = scope;
+ var tensorShape = _ShapeTensor(shape);
+ var minTensor = ops.convert_to_tensor(minval, dtype: dtype, name: "min");
+ var maxTensor = ops.convert_to_tensor(maxval, dtype: dtype, name: "max");
+ var rnd = gen_random_ops.random_uniform(tensorShape, dtype);
+ return math_ops.add(rnd * (maxTensor - minTensor), minTensor, name: name);
+ });
+ }
+
private static Tensor _ShapeTensor(int[] shape)
{
return ops.convert_to_tensor(shape, name: "shape");
diff --git a/src/TensorFlowNET.Core/Python.cs b/src/TensorFlowNET.Core/Python.cs
index 8aa61222..b11f2889 100644
--- a/src/TensorFlowNET.Core/Python.cs
+++ b/src/TensorFlowNET.Core/Python.cs
@@ -43,12 +43,12 @@ namespace Tensorflow
}
}
- public static void with(IPython py, Action action) where T : IPython
+ public static void with(T py, Action action) where T : IPython
{
try
{
py.__enter__();
- action((T)py);
+ action(py);
}
catch (Exception ex)
{
@@ -62,12 +62,12 @@ namespace Tensorflow
}
}
- public static TOut with(IPython py, Func action) where TIn : IPython
+ public static TOut with(TIn py, Func action) where TIn : IPython
{
try
{
py.__enter__();
- return action((TIn)py);
+ return action(py);
}
catch (Exception ex)
{
diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs
index 198cd509..5c84f34a 100644
--- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs
+++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs
@@ -62,26 +62,29 @@ namespace Tensorflow
switch (subfeed_val)
{
- case IntPtr pointer:
- feed_dict_tensor[subfeed_t] = pointer;
+ case IntPtr val:
+ feed_dict_tensor[subfeed_t] = val;
break;
- case NDArray nd:
- feed_dict_tensor[subfeed_t] = nd;
+ case NDArray val:
+ feed_dict_tensor[subfeed_t] = val;
break;
- case float floatVal:
- feed_dict_tensor[subfeed_t] = (NDArray)floatVal;
+ case float val:
+ feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
- case double doubleVal:
- feed_dict_tensor[subfeed_t] = (NDArray)doubleVal;
+ case double val:
+ feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
- case int intVal:
- feed_dict_tensor[subfeed_t] = (NDArray)intVal;
+ case short val:
+ feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
- case string str:
- feed_dict_tensor[subfeed_t] = (NDArray)str;
+ case int val:
+ feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
- case byte[] bytes:
- feed_dict_tensor[subfeed_t] = (NDArray)bytes;
+ case string val:
+ feed_dict_tensor[subfeed_t] = (NDArray)val;
+ break;
+ case byte[] val:
+ feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
default:
Console.WriteLine($"can't handle data type of subfeed_val");
diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
index 01448769..949166c1 100644
--- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
+++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
@@ -4,7 +4,7 @@
netstandard2.0
TensorFlow.NET
Tensorflow
- 0.4.0
+ 0.4.2
Haiping Chen
SciSharp STACK
true
@@ -16,11 +16,11 @@
TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET
Google's TensorFlow binding in .NET Standard.
Docs: https://tensorflownet.readthedocs.io
- 0.4.0.0
- Added Linear Regression example.
-
+ 0.4.2.0
+ Added ConfigProto to control CPU and GPU resource.
+Fixed import name scope issue.
7.2
- 0.4.0.0
+ 0.4.2.0
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
index 89a8aa17..a357e5d5 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
@@ -42,7 +42,7 @@ namespace Tensorflow
dtype = tr.dtype.as_base_dtype();
var namescope = new ops.name_scope(null, name, new { x, y });
- return Python.with(namescope, scope =>
+ return with(namescope, scope =>
{
Tensor result = null;
var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x");
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs
index e2dd55a5..2b3db534 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs
@@ -12,7 +12,7 @@ namespace Tensorflow
/// A tensor is a generalization of vectors and matrices to potentially higher dimensions.
/// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes.
///
- public partial class Tensor : IDisposable, ITensorOrOperation
+ public partial class Tensor : Python, IDisposable, ITensorOrOperation
{
private readonly IntPtr _handle;
@@ -77,6 +77,11 @@ namespace Tensorflow
return null;
}
+ public TensorShape getShape()
+ {
+ return tensor_util.to_shape(shape);
+ }
+
///
/// number of dimensions
/// 0 Scalar (magnitude only)
diff --git a/src/TensorFlowNET.Core/Train/Optimizer.cs b/src/TensorFlowNET.Core/Train/Optimizer.cs
index 25c443f3..407315e2 100644
--- a/src/TensorFlowNET.Core/Train/Optimizer.cs
+++ b/src/TensorFlowNET.Core/Train/Optimizer.cs
@@ -12,7 +12,7 @@ namespace Tensorflow
/// class directly, but instead instantiate one of its subclasses such as
/// `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.
///
- public abstract class Optimizer
+ public abstract class Optimizer : Python
{
// Values for gate_gradients.
public static int GATE_NONE = 0;
@@ -87,7 +87,7 @@ namespace Tensorflow
_create_slots(var_list);
var update_ops = new List();
- return Python.with(new ops.name_scope(name, Name), scope =>
+ return with(new ops.name_scope(name, Name), scope =>
{
name = scope;
_prepare();
@@ -98,7 +98,7 @@ namespace Tensorflow
continue;
var scope_name = var.op.name;
- Python.with(new ops.name_scope("update_" + scope_name), scope2 =>
+ with(new ops.name_scope("update_" + scope_name), scope2 =>
{
update_ops.Add(processor.update_op(this, grad));
});
diff --git a/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs
index 06a1fb49..876bf856 100644
--- a/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs
+++ b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs
@@ -5,7 +5,7 @@ using System.Text;
namespace Tensorflow
{
- public class BaseSaverBuilder
+ public class BaseSaverBuilder : Python
{
protected SaverDef.Types.CheckpointFormatVersion _write_version;
@@ -79,7 +79,7 @@ namespace Tensorflow
Tensor save_tensor = null;
Operation restore_op = null;
- return Python.with(new ops.name_scope(name, "save", saveables.Select(x => x.op).ToArray()), scope =>
+ return with(new ops.name_scope(name, "save", saveables.Select(x => x.op).ToArray()), scope =>
{
name = scope;
diff --git a/src/TensorFlowNET.Core/Variables/PureVariableScope.cs b/src/TensorFlowNET.Core/Variables/PureVariableScope.cs
new file mode 100644
index 00000000..6f97b19d
--- /dev/null
+++ b/src/TensorFlowNET.Core/Variables/PureVariableScope.cs
@@ -0,0 +1,86 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public class PureVariableScope : IPython
+ {
+ private string _name;
+ private VariableScope _scope;
+ private string _new_name;
+ private string _old_name_scope;
+ private bool _reuse;
+ private _VariableStore _var_store;
+ private VariableScope _old;
+ private _VariableScopeStore _var_scope_store;
+ private VariableScope variable_scope_object;
+ private VariableScope _cached_variable_scope_object;
+
+ public PureVariableScope(string name,
+ string old_name_scope = null,
+ TF_DataType dtype = TF_DataType.DtInvalid)
+ {
+ _name = name;
+ _old_name_scope = old_name_scope;
+ _var_store = variable_scope._get_default_variable_store();
+ _var_scope_store = variable_scope.get_variable_scope_store();
+ }
+
+ public PureVariableScope(VariableScope scope,
+ string old_name_scope = null,
+ TF_DataType dtype = TF_DataType.DtInvalid)
+ {
+ _scope = scope;
+ _old_name_scope = old_name_scope;
+ _var_store = variable_scope._get_default_variable_store();
+ _var_scope_store = variable_scope.get_variable_scope_store();
+ _new_name = _scope._name;
+
+ string name_scope = _scope._name_scope;
+ variable_scope_object = new VariableScope(_reuse,
+ name: _new_name,
+ name_scope: name_scope);
+
+ _cached_variable_scope_object = variable_scope_object;
+ }
+
+ public void __enter__()
+ {
+ _old = _var_scope_store.current_scope;
+ if(_scope != null)
+ {
+ _var_scope_store.open_variable_scope(_new_name);
+ variable_scope_object = _cached_variable_scope_object;
+ }
+ else
+ {
+ _new_name = string.IsNullOrEmpty(_old._name) ? _name : _old._name + "/" + _name;
+ _reuse = _reuse || _old.resue;
+ string name_scope = _old_name_scope == null ? _name : _old_name_scope;
+
+ variable_scope_object = new VariableScope(_reuse,
+ name: _new_name,
+ name_scope: name_scope);
+
+ _var_scope_store.open_variable_scope(_new_name);
+ }
+ _var_scope_store.current_scope = variable_scope_object;
+ }
+
+ public void Dispose()
+ {
+
+ }
+
+ public void __exit__()
+ {
+
+ }
+
+ public static implicit operator VariableScope(PureVariableScope scope)
+ {
+ return scope.variable_scope_object;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs b/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs
index 5bd4120d..3b2f67c4 100644
--- a/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs
+++ b/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs
@@ -17,7 +17,7 @@ namespace Tensorflow
private static Tensor op_helper(string default_name, RefVariable x, T y)
{
var tensor1 = x.value();
- return with(new ops.name_scope(null, default_name, new { tensor1, y }), scope => {
+ return with(new ops.name_scope(null, default_name, new { tensor1, y }), scope => {
var tensor2 = ops.convert_to_tensor(y, tensor1.dtype.as_base_dtype(), "y");
return gen_math_ops.add(tensor1, tensor2, scope);
});
diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs
index f0e4e721..5eb03d3f 100644
--- a/src/TensorFlowNET.Core/Variables/RefVariable.cs
+++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs
@@ -1,4 +1,6 @@
-using System;
+using Google.Protobuf;
+using Google.Protobuf.Collections;
+using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
@@ -99,7 +101,7 @@ namespace Tensorflow
if (initial_value is null)
throw new ValueError("initial_value must be specified.");
- var init_from_fn = false;
+ var init_from_fn = initial_value.GetType().Name == "Func`1";
if(collections == null)
{
@@ -115,12 +117,27 @@ namespace Tensorflow
collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES);
ops.init_scope();
- var values = init_from_fn ? new List
+
+
+
+
+
+
+
-
diff --git a/test/TensorFlowNET.Examples/TextClassification/DataHelpers.cs b/test/TensorFlowNET.Examples/TextClassification/DataHelpers.cs
new file mode 100644
index 00000000..f705b98c
--- /dev/null
+++ b/test/TensorFlowNET.Examples/TextClassification/DataHelpers.cs
@@ -0,0 +1,91 @@
+using NumSharp.Core;
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Text;
+using System.Text.RegularExpressions;
+
+namespace TensorFlowNET.Examples.CnnTextClassification
+{
+ public class DataHelpers
+ {
+ private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv";
+ private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv";
+
+ public static (int[][], int[], int) build_char_dataset(string step, string model, int document_max_len)
+ {
+ string alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:’'\"/|_#$%ˆ&*˜‘+=<>()[]{} ";
+ /*if (step == "train")
+ df = pd.read_csv(TRAIN_PATH, names =["class", "title", "content"]);*/
+ var char_dict = new Dictionary();
+ char_dict[""] = 0;
+ char_dict[""] = 1;
+ foreach (char c in alphabet)
+ char_dict[c.ToString()] = char_dict.Count;
+
+ var contents = File.ReadAllLines(TRAIN_PATH);
+
+ var x = new int[contents.Length][];
+ var y = new int[contents.Length];
+ for (int i = 0; i < contents.Length; i++)
+ {
+ string[] parts = contents[i].ToLower().Split(",\"").ToArray();
+ string content = parts[2];
+ content = content.Substring(0, content.Length - 1);
+ x[i] = new int[document_max_len];
+ for (int j = 0; j < document_max_len; j++)
+ {
+ if (j >= content.Length)
+ x[i][j] = char_dict[""];
+ else
+ x[i][j] = char_dict.ContainsKey(content[j].ToString()) ? char_dict[content[j].ToString()] : char_dict[""];
+ }
+
+ y[i] = int.Parse(parts[0]);
+ }
+
+ return (x, y, alphabet.Length + 2);
+ }
+
+ ///
+ /// Loads MR polarity data from files, splits the data into words and generates labels.
+ /// Returns split sentences and labels.
+ ///
+ ///
+ ///
+ ///
+ public static (string[], NDArray) load_data_and_labels(string positive_data_file, string negative_data_file)
+ {
+ Directory.CreateDirectory("CnnTextClassification");
+ Utility.Web.Download(positive_data_file, "CnnTextClassification", "rt -polarity.pos");
+ Utility.Web.Download(negative_data_file, "CnnTextClassification", "rt-polarity.neg");
+
+ // Load data from files
+ var positive_examples = File.ReadAllLines("CnnTextClassification/rt-polarity.pos")
+ .Select(x => x.Trim())
+ .ToArray();
+
+ var negative_examples = File.ReadAllLines("CnnTextClassification/rt-polarity.neg")
+ .Select(x => x.Trim())
+ .ToArray();
+
+ var x_text = new List();
+ x_text.AddRange(positive_examples);
+ x_text.AddRange(negative_examples);
+ x_text = x_text.Select(x => clean_str(x)).ToList();
+
+ var positive_labels = positive_examples.Select(x => new int[2] { 0, 1 }).ToArray();
+ var negative_labels = negative_examples.Select(x => new int[2] { 1, 0 }).ToArray();
+ var y = np.array(1);// np.concatenate(new int[][][] { positive_labels, negative_labels });
+ return (x_text.ToArray(), y);
+ }
+
+ private static string clean_str(string str)
+ {
+ str = Regex.Replace(str, @"[^A-Za-z0-9(),!?\'\`]", " ");
+ str = Regex.Replace(str, @"\'s", " \'s");
+ return str;
+ }
+ }
+}
diff --git a/test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs b/test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs
new file mode 100644
index 00000000..4ea583d4
--- /dev/null
+++ b/test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs
@@ -0,0 +1,37 @@
+using NumSharp.Core;
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Text;
+using Tensorflow;
+using TensorFlowNET.Examples.Utility;
+
+namespace TensorFlowNET.Examples.CnnTextClassification
+{
+ ///
+ /// https://github.com/dongjun-Lee/text-classification-models-tf
+ ///
+ public class TextClassificationTrain : Python, IExample
+ {
+ private string dataDir = "text_classification";
+ private string dataFileName = "dbpedia_csv.tar.gz";
+
+ private const int CHAR_MAX_LEN = 1014;
+
+ public void Run()
+ {
+ download_dbpedia();
+ Console.WriteLine("Building dataset...");
+ var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", "vdcnn", CHAR_MAX_LEN);
+ //var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15);
+ }
+
+ public void download_dbpedia()
+ {
+ string url = "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz";
+ Web.Download(url, dataDir, dataFileName);
+ Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir);
+ }
+ }
+}
diff --git a/test/TensorFlowNET.Examples/TextClassificationWithMovieReviews.cs b/test/TensorFlowNET.Examples/TextClassificationWithMovieReviews.cs
index b57da319..cd59287b 100644
--- a/test/TensorFlowNET.Examples/TextClassificationWithMovieReviews.cs
+++ b/test/TensorFlowNET.Examples/TextClassificationWithMovieReviews.cs
@@ -46,9 +46,8 @@ namespace TensorFlowNET.Examples
// get model file
string url = $"https://github.com/SciSharp/TensorFlow.NET/raw/master/data/{dataFile}";
- string zipFile = Path.Join(dir, $"imdb.zip");
- Utility.Web.Download(url, zipFile);
- Utility.Compress.UnZip(zipFile, dir);
+ Utility.Web.Download(url, dir, "imdb.zip");
+ Utility.Compress.UnZip(Path.Join(dir, $"imdb.zip"), dir);
// prepare training dataset
var x_train = ReadData(Path.Join(dir, "x_train.txt"));
diff --git a/src/TensorFlowNET.Utility/Compress.cs b/test/TensorFlowNET.Examples/Utility/Compress.cs
similarity index 98%
rename from src/TensorFlowNET.Utility/Compress.cs
rename to test/TensorFlowNET.Examples/Utility/Compress.cs
index 0efe59bd..cf40e2c4 100644
--- a/src/TensorFlowNET.Utility/Compress.cs
+++ b/test/TensorFlowNET.Examples/Utility/Compress.cs
@@ -7,7 +7,7 @@ using System.Linq;
using System.Threading;
using System.Threading.Tasks;
-namespace TensorFlowNET.Utility
+namespace TensorFlowNET.Examples.Utility
{
public class Compress
{
diff --git a/test/TensorFlowNET.Examples/Utility/Web.cs b/test/TensorFlowNET.Examples/Utility/Web.cs
new file mode 100644
index 00000000..e2155e93
--- /dev/null
+++ b/test/TensorFlowNET.Examples/Utility/Web.cs
@@ -0,0 +1,43 @@
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Net;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace TensorFlowNET.Examples.Utility
+{
+ public class Web
+ {
+ public static bool Download(string url, string destDir, string destFileName)
+ {
+ if (destFileName == null)
+ destFileName = url.Split(Path.DirectorySeparatorChar).Last();
+
+ Directory.CreateDirectory(destDir);
+
+ string relativeFilePath = Path.Combine(destDir, destFileName);
+
+ if (File.Exists(relativeFilePath))
+ {
+ Console.WriteLine($"{relativeFilePath} already exists.");
+ return false;
+ }
+
+ var wc = new WebClient();
+ Console.WriteLine($"Downloading {relativeFilePath}");
+ var download = Task.Run(() => wc.DownloadFile(url, relativeFilePath));
+ while (!download.IsCompleted)
+ {
+ Thread.Sleep(1000);
+ Console.Write(".");
+ }
+ Console.WriteLine("");
+ Console.WriteLine($"Downloaded {relativeFilePath}");
+
+ return true;
+ }
+ }
+}
diff --git a/test/TensorFlowNET.UnitTest/CApiTest.cs b/test/TensorFlowNET.UnitTest/CApiTest.cs
index dedb88b3..4f0be55d 100644
--- a/test/TensorFlowNET.UnitTest/CApiTest.cs
+++ b/test/TensorFlowNET.UnitTest/CApiTest.cs
@@ -6,7 +6,7 @@ using Tensorflow;
namespace TensorFlowNET.UnitTest
{
- public class CApiTest
+ public class CApiTest : Python
{
protected TF_Code TF_OK = TF_Code.TF_OK;
protected TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT;
diff --git a/test/TensorFlowNET.UnitTest/ConstantTest.cs b/test/TensorFlowNET.UnitTest/ConstantTest.cs
index 7b0768e0..ee496e7c 100644
--- a/test/TensorFlowNET.UnitTest/ConstantTest.cs
+++ b/test/TensorFlowNET.UnitTest/ConstantTest.cs
@@ -10,7 +10,7 @@ using Tensorflow;
namespace TensorFlowNET.UnitTest
{
[TestClass]
- public class ConstantTest
+ public class ConstantTest : Python
{
Status status = new Status();
@@ -27,7 +27,7 @@ namespace TensorFlowNET.UnitTest
{
string str = "Hello, TensorFlow.NET!";
var tensor = tf.constant(str);
- Python.with(tf.Session(), sess =>
+ with(tf.Session(), sess =>
{
var result = sess.run(tensor);
Assert.IsTrue(result.Data()[0] == str);
@@ -39,7 +39,7 @@ namespace TensorFlowNET.UnitTest
{
// small size
var tensor = tf.zeros(new Shape(3, 2), TF_DataType.TF_INT32, "small");
- Python.with(tf.Session(), sess =>
+ with(tf.Session(), sess =>
{
var result = sess.run(tensor);
@@ -50,7 +50,7 @@ namespace TensorFlowNET.UnitTest
// big size
tensor = tf.zeros(new Shape(200, 100), TF_DataType.TF_INT32, "big");
- Python.with(tf.Session(), sess =>
+ with(tf.Session(), sess =>
{
var result = sess.run(tensor);
@@ -74,7 +74,7 @@ namespace TensorFlowNET.UnitTest
});
var tensor = tf.constant(nd);
- Python.with(tf.Session(), sess =>
+ with(tf.Session(), sess =>
{
var result = sess.run(tensor);
var data = result.Data();
diff --git a/test/TensorFlowNET.UnitTest/NameScopeTest.cs b/test/TensorFlowNET.UnitTest/NameScopeTest.cs
index 64b2bd1f..01a5f031 100644
--- a/test/TensorFlowNET.UnitTest/NameScopeTest.cs
+++ b/test/TensorFlowNET.UnitTest/NameScopeTest.cs
@@ -15,7 +15,7 @@ namespace TensorFlowNET.UnitTest
[TestMethod]
public void NestedNameScope()
{
- with(new ops.name_scope("scope1"), scope1 =>
+ with(new ops.name_scope("scope1"), scope1 =>
{
name = scope1;
Assert.AreEqual("scope1", g._name_stack);
@@ -24,7 +24,7 @@ namespace TensorFlowNET.UnitTest
var const1 = tf.constant(1.0);
Assert.AreEqual("scope1/Const:0", const1.name);
- with(new ops.name_scope("scope2"), scope2 =>
+ with(new ops.name_scope("scope2"), scope2 =>
{
name = scope2;
Assert.AreEqual("scope1/scope2", g._name_stack);
diff --git a/test/TensorFlowNET.UnitTest/PlaceholderTest.cs b/test/TensorFlowNET.UnitTest/PlaceholderTest.cs
index 3bc6a892..189a9c69 100644
--- a/test/TensorFlowNET.UnitTest/PlaceholderTest.cs
+++ b/test/TensorFlowNET.UnitTest/PlaceholderTest.cs
@@ -7,7 +7,7 @@ using Tensorflow;
namespace TensorFlowNET.UnitTest
{
[TestClass]
- public class PlaceholderTest
+ public class PlaceholderTest : Python
{
[TestMethod]
public void placeholder()
@@ -15,7 +15,7 @@ namespace TensorFlowNET.UnitTest
var x = tf.placeholder(tf.int32);
var y = x * 3;
- Python.with(tf.Session(), sess =>
+ with(tf.Session(), sess =>
{
var result = sess.run(y,
new FeedItem(x, 2));
diff --git a/test/TensorFlowNET.UnitTest/SessionTest.cs b/test/TensorFlowNET.UnitTest/SessionTest.cs
index e93e4444..37206426 100644
--- a/test/TensorFlowNET.UnitTest/SessionTest.cs
+++ b/test/TensorFlowNET.UnitTest/SessionTest.cs
@@ -82,7 +82,7 @@ namespace TensorFlowNET.UnitTest
var a = constant_op.constant(np.array(3.0).reshape(1, 1));
var b = constant_op.constant(np.array(2.0).reshape(1, 1));
var c = math_ops.matmul(a, b, name: "matmul");
- Python.with(tf.Session(), delegate
+ with(tf.Session(), delegate
{
var result = c.eval();
Assert.AreEqual(6, result.Data()[0]);
diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj
index 2ff58949..f495711c 100644
--- a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj
+++ b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj
@@ -1,7 +1,7 @@
- netcoreapp2.1
+ netcoreapp2.2
false
diff --git a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs
index f5aec32b..3ea4dfc8 100644
--- a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs
+++ b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs
@@ -19,7 +19,7 @@ namespace TensorFlowNET.UnitTest
public void ImportGraph()
{
- with(tf.Session(), sess =>
+ with(tf.Session(), sess =>
{
var new_saver = tf.train.import_meta_graph("C:/tmp/my-model.meta");
});
@@ -44,7 +44,7 @@ namespace TensorFlowNET.UnitTest
public void ImportSavedModel()
{
- with(Session.LoadFromSavedModel("mobilenet"), sess =>
+ with(Session.LoadFromSavedModel("mobilenet"), sess =>
{
});
@@ -65,7 +65,7 @@ namespace TensorFlowNET.UnitTest
// Add ops to save and restore all the variables.
var saver = tf.train.Saver();
- with(tf.Session(), sess =>
+ with(tf.Session(), sess =>
{
sess.run(init_op);
diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs
index 88c65bd3..7713e774 100644
--- a/test/TensorFlowNET.UnitTest/VariableTest.cs
+++ b/test/TensorFlowNET.UnitTest/VariableTest.cs
@@ -30,6 +30,47 @@ namespace TensorFlowNET.UnitTest
var mammal2 = tf.Variable("Tiger");
}
+ ///
+ /// https://www.tensorflow.org/api_docs/python/tf/variable_scope
+ /// how to create a new variable
+ ///
+ [TestMethod]
+ public void VarCreation()
+ {
+ with(tf.variable_scope("foo"), delegate
+ {
+ with(tf.variable_scope("bar"), delegate
+ {
+ var v = tf.get_variable("v", new TensorShape(1));
+ Assert.AreEqual(v.name, "foo/bar/v:0");
+ });
+ });
+ }
+
+ ///
+ /// how to reenter a premade variable scope safely
+ ///
+ [TestMethod]
+ public void ReenterVariableScope()
+ {
+ variable_scope vs = null;
+ with(tf.variable_scope("foo"), v => vs = v);
+
+ // Re-enter the variable scope.
+ with(tf.variable_scope(vs, auxiliary_name_scope: false), v =>
+ {
+ var vs1 = (VariableScope)v;
+ // Restore the original name_scope.
+ with(tf.name_scope(vs1.original_name_scope), delegate
+ {
+ var v1 = tf.get_variable("v", new TensorShape(1));
+ Assert.AreEqual(v1.name, "foo/v:0");
+ var c1 = tf.constant(new int[] { 1 }, name: "c");
+ Assert.AreEqual(c1.name, "foo/c:0");
+ });
+ });
+ }
+
[TestMethod]
public void ScalarVar()
{
@@ -49,7 +90,7 @@ namespace TensorFlowNET.UnitTest
[TestMethod]
public void Assign1()
{
- with(tf.Graph().as_default(), graph =>
+ with(tf.Graph().as_default(), graph =>
{
var variable = tf.Variable(31, name: "tree");
var init = tf.global_variables_initializer();
@@ -75,7 +116,7 @@ namespace TensorFlowNET.UnitTest
// Add an op to initialize the variables.
var init_op = tf.global_variables_initializer();
- with(tf.Session(), sess =>
+ with(tf.Session(), sess =>
{
sess.run(init_op);
// o some work with the model.