diff --git a/README.md b/README.md index 42cd3ec4..7a861d04 100644 --- a/README.md +++ b/README.md @@ -66,13 +66,14 @@ using(var sess = tf.Session()) Read the docs & book [The Definitive Guide to Tensorflow.NET](https://tensorflownet.readthedocs.io/en/latest/FrontCover.html). -More examples: +### More examples: * [Hello World](test/TensorFlowNET.Examples/HelloWorld.cs) * [Basic Operations](test/TensorFlowNET.Examples/BasicOperations.cs) * [Image Recognition](test/TensorFlowNET.Examples/ImageRecognition.cs) * [Linear Regression](test/TensorFlowNET.Examples/LinearRegression.cs) * [Text Classification](test/TensorFlowNET.Examples/TextClassificationWithMovieReviews.cs) +* [CNN Text Classification](test/TensorFlowNET.Examples/CnnTextClassification.cs) * [Naive Bayes Classification](test/TensorFlowNET.Examples/NaiveBayesClassifier.cs) * [Named Entity Recognition](test/TensorFlowNET.Examples/NamedEntityRecognition.cs) diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln index 24afc77e..25a97e6d 100644 --- a/TensorFlow.NET.sln +++ b/TensorFlow.NET.sln @@ -9,11 +9,11 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "t EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\TensorFlowNET.Core\TensorFlowNET.Core.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Utility", "src\TensorFlowNET.Utility\TensorFlowNET.Utility.csproj", "{00D9085C-0FC7-453C-A0CC-BAD98F44FEA0}" -EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Visualization", "TensorFlowNET.Visualization\TensorFlowNET.Visualization.csproj", "{4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{268BF0B6-0AA9-4FD3-A245-7AF336F1E3E9}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{E8340C61-12C1-4BEE-A340-403E7C1ACD82}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "scikit-learn", "..\scikit-learn.net\src\scikit-learn\scikit-learn.csproj", "{199DDAD8-4A6F-43B3-A560-C0393619E304}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -33,18 +33,18 @@ Global {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.Build.0 = Debug|Any CPU {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.ActiveCfg = Release|Any CPU {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.Build.0 = Release|Any CPU - {00D9085C-0FC7-453C-A0CC-BAD98F44FEA0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {00D9085C-0FC7-453C-A0CC-BAD98F44FEA0}.Debug|Any CPU.Build.0 = Debug|Any CPU - {00D9085C-0FC7-453C-A0CC-BAD98F44FEA0}.Release|Any CPU.ActiveCfg = Release|Any CPU - {00D9085C-0FC7-453C-A0CC-BAD98F44FEA0}.Release|Any CPU.Build.0 = Release|Any CPU {4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}.Debug|Any CPU.Build.0 = Debug|Any CPU {4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}.Release|Any CPU.ActiveCfg = Release|Any CPU {4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}.Release|Any CPU.Build.0 = Release|Any CPU - {268BF0B6-0AA9-4FD3-A245-7AF336F1E3E9}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {268BF0B6-0AA9-4FD3-A245-7AF336F1E3E9}.Debug|Any CPU.Build.0 = Debug|Any CPU - {268BF0B6-0AA9-4FD3-A245-7AF336F1E3E9}.Release|Any CPU.ActiveCfg = Release|Any CPU - {268BF0B6-0AA9-4FD3-A245-7AF336F1E3E9}.Release|Any CPU.Build.0 = Release|Any CPU + {E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Debug|Any CPU.Build.0 = Debug|Any CPU + {E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Release|Any CPU.ActiveCfg = Release|Any CPU + {E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Release|Any CPU.Build.0 = Release|Any CPU + {199DDAD8-4A6F-43B3-A560-C0393619E304}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {199DDAD8-4A6F-43B3-A560-C0393619E304}.Debug|Any CPU.Build.0 = Debug|Any CPU + {199DDAD8-4A6F-43B3-A560-C0393619E304}.Release|Any CPU.ActiveCfg = Release|Any CPU + {199DDAD8-4A6F-43B3-A560-C0393619E304}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/src/TensorFlowNET.Core/APIs/tf.init.cs b/src/TensorFlowNET.Core/APIs/tf.init.cs index 584f3b01..19876d62 100644 --- a/src/TensorFlowNET.Core/APIs/tf.init.cs +++ b/src/TensorFlowNET.Core/APIs/tf.init.cs @@ -7,7 +7,24 @@ namespace Tensorflow public static partial class tf { public static IInitializer zeros_initializer => new Zeros(); + public static IInitializer glorot_uniform_initializer => new GlorotUniform(); + public static variable_scope variable_scope(string name, + string default_name = null, + object values = null, + bool auxiliary_name_scope = true) => new variable_scope(name, + default_name, + values, + auxiliary_name_scope); + + public static variable_scope variable_scope(VariableScope scope, + string default_name = null, + object values = null, + bool auxiliary_name_scope = true) => new variable_scope(scope, + default_name, + values, + auxiliary_name_scope); + public class Zeros : IInitializer { private TF_DataType dtype; @@ -30,5 +47,105 @@ namespace Tensorflow return new { dtype = dtype.name() }; } } + + /// + /// Initializer capable of adapting its scale to the shape of weights tensors. + /// + public class VarianceScaling : IInitializer + { + protected float _scale; + protected string _mode; + protected string _distribution; + protected int? _seed; + protected TF_DataType _dtype; + + public VarianceScaling(float scale = 1.0f, + string mode = "fan_in", + string distribution= "truncated_normal", + int? seed = null, + TF_DataType dtype = TF_DataType.TF_FLOAT) + { + if (scale < 0) + throw new ValueError("`scale` must be positive float."); + _scale = scale; + _mode = mode; + _distribution = distribution; + _seed = seed; + _dtype = dtype; + } + + public Tensor call(TensorShape shape, TF_DataType dtype) + { + var (fan_in, fan_out) = _compute_fans(shape); + if (_mode == "fan_in") + _scale /= Math.Max(1, fan_in); + else if (_mode == "fan_out") + _scale /= Math.Max(1, fan_out); + else + _scale /= Math.Max(1, (fan_in + fan_out) / 2); + + if (_distribution == "normal" || _distribution == "truncated_normal") + { + throw new NotImplementedException("truncated_normal"); + } + else if(_distribution == "untruncated_normal") + { + throw new NotImplementedException("truncated_normal"); + } + else + { + var limit = Math.Sqrt(3.0f * _scale); + return random_ops.random_uniform(shape, (float)-limit, (float)limit, dtype, seed: _seed); + } + } + + private (int, int) _compute_fans(int[] shape) + { + if (shape.Length < 1) + return (1, 1); + if (shape.Length == 1) + return (shape[0], shape[0]); + if (shape.Length == 2) + return (shape[0], shape[1]); + else + throw new NotImplementedException("VarianceScaling._compute_fans"); + } + + public virtual object get_config() + { + return new + { + scale = _scale, + mode = _mode, + distribution = _distribution, + seed = _seed, + dtype = _dtype + }; + } + } + + public class GlorotUniform : VarianceScaling + { + public GlorotUniform(float scale = 1.0f, + string mode = "fan_avg", + string distribution = "uniform", + int? seed = null, + TF_DataType dtype = TF_DataType.TF_FLOAT) : base(scale, mode, distribution, seed, dtype) + { + + } + + public object get_config() + { + return new + { + scale = _scale, + mode = _mode, + distribution = _distribution, + seed = _seed, + dtype = _dtype + }; + } + } } } diff --git a/src/TensorFlowNET.Core/Contrib/Learn/Preprocessing/VocabularyProcessor.cs b/src/TensorFlowNET.Core/Contrib/Learn/Preprocessing/VocabularyProcessor.cs new file mode 100644 index 00000000..00f2c9b0 --- /dev/null +++ b/src/TensorFlowNET.Core/Contrib/Learn/Preprocessing/VocabularyProcessor.cs @@ -0,0 +1,19 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Contrib.Learn.Preprocessing +{ + public class VocabularyProcessor + { + private int _max_document_length; + private int _min_frequency; + + public VocabularyProcessor(int max_document_length, + int min_frequency) + { + _max_document_length = max_document_length; + _min_frequency = min_frequency; + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/importer.py.cs b/src/TensorFlowNET.Core/Framework/importer.py.cs index a0a45ca0..9070d27b 100644 --- a/src/TensorFlowNET.Core/Framework/importer.py.cs +++ b/src/TensorFlowNET.Core/Framework/importer.py.cs @@ -7,7 +7,7 @@ using static Tensorflow.OpDef.Types; namespace Tensorflow { - public class importer + public class importer : Python { public static ITensorOrOperation[] import_graph_def(GraphDef graph_def, Dictionary input_map = null, @@ -26,7 +26,7 @@ namespace Tensorflow string prefix = ""; var graph = ops.get_default_graph(); - Python.with(new ops.name_scope(name, "import", input_map.Values), scope => + with(new ops.name_scope(name, "import", input_map.Values), scope => { prefix = scope; /*if (!string.IsNullOrEmpty(prefix)) diff --git a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs index e4993b5e..a4af9caa 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs @@ -7,7 +7,7 @@ using System.Threading; namespace Tensorflow { - public class gradients_impl + public class gradients_impl : Python { public static Tensor[] gradients(Tensor[] ys, Tensor[] xs, @@ -58,7 +58,7 @@ namespace Tensorflow **/ var grads = new Dictionary(); - Python.with(new ops.name_scope(name, "gradients", values: all), scope => + with(new ops.name_scope(name, "gradients", values: all), scope => { string grad_scope = scope; // Get a uid for this call to gradients that can be used to help @@ -131,7 +131,7 @@ namespace Tensorflow // for ops that do not have gradients. var grad_fn = ops.get_gradient_function(op); - Python.with(new ops.name_scope(op.name + "_grad"), scope1 => + with(new ops.name_scope(op.name + "_grad"), scope1 => { string name1 = scope1; if (grad_fn != null) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index a025bfc6..5caaf6b0 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -196,11 +196,11 @@ namespace Tensorflow _create_op_helper(op, true); - Console.Write($"create_op: {op_type} '{node_def.Name}'"); + /*Console.Write($"create_op: {op_type} '{node_def.Name}'"); Console.Write($", inputs: {(inputs.Length == 0 ? "empty" : String.Join(", ", inputs.Select(x => x.name)))}"); Console.Write($", control_inputs: {(control_inputs.Length == 0 ? "empty" : String.Join(", ", control_inputs.Select(x => x.name)))}"); Console.Write($", outputs: {(op.outputs.Length == 0 ? "empty" : String.Join(", ", op.outputs.Select(x => x.name)))}"); - Console.WriteLine(); + Console.WriteLine();*/ return op; } diff --git a/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs b/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs index e45074b0..02d7f180 100644 --- a/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs +++ b/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs @@ -12,7 +12,7 @@ namespace Tensorflow string scope = "", string loss_collection= "losses") { - with(new ops.name_scope(scope, + with(new ops.name_scope(scope, "sparse_softmax_cross_entropy_loss", (logits, labels, weights)), namescope => diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index 0c19dea0..8e6bfba0 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -10,7 +10,7 @@ using static Tensorflow.OpDef.Types; namespace Tensorflow { - public class OpDefLibrary + public class OpDefLibrary : Python { public Operation _apply_op_helper(string op_type_name, string name = null, dynamic args = null) { @@ -44,7 +44,7 @@ namespace Tensorflow var input_types = new List(); dynamic values = null; - return Python.with(new ops.name_scope(name), scope => + return with(new ops.name_scope(name), scope => { var inferred_from = new Dictionary(); var base_types = new List(); diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs index 66574dae..18de83dd 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs @@ -5,14 +5,14 @@ using System.Text; namespace Tensorflow { - public class array_ops + public class array_ops : Python { public static Tensor placeholder_with_default(T input, int[] shape, string name = null) => gen_array_ops.placeholder_with_default(input, shape, name); public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) { dtype = dtype.as_base_dtype(); - return Python.with(new ops.name_scope(name, "zeros", shape), scope => + return with(new ops.name_scope(name, "zeros", shape), scope => { name = scope; switch (dtype) @@ -68,7 +68,7 @@ namespace Tensorflow private static Tensor ones_like_impl(T tensor, TF_DataType dtype, string name, bool optimize = true) { - return Python.with(new ops.name_scope(name, "ones_like", new { tensor }), scope => + return with(new ops.name_scope(name, "ones_like", new { tensor }), scope => { name = scope; var tensor1 = ops.convert_to_tensor(tensor, name: "tensor"); @@ -84,7 +84,7 @@ namespace Tensorflow public static Tensor ones(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) { dtype = dtype.as_base_dtype(); - return Python.with(new ops.name_scope(name, "ones", new { shape }), scope => + return with(new ops.name_scope(name, "ones", new { shape }), scope => { name = scope; var output = gen_array_ops.fill(shape, constant_op.constant(1.0f, dtype: dtype), name: name); @@ -130,7 +130,7 @@ namespace Tensorflow private static Tensor shape_internal(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32) { - return Python.with(new ops.name_scope(name, "Shape", new { input }), scope => + return with(new ops.name_scope(name, "Shape", new { input }), scope => { name = scope; @@ -151,7 +151,7 @@ namespace Tensorflow private static Tensor size_internal(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32) { - return Python.with(new ops.name_scope(name, "Size", new Tensor[] { input }), scope => + return with(new ops.name_scope(name, "Size", new Tensor[] { input }), scope => { name = scope; @@ -182,7 +182,7 @@ namespace Tensorflow public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) { - return Python.with(new ops.name_scope(name, "zeros_like", new Tensor[] { tensor }), scope => + return with(new ops.name_scope(name, "zeros_like", new Tensor[] { tensor }), scope => { name = scope; tensor = ops.convert_to_tensor(tensor, name: "tensor"); diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs index c9220868..303104b5 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs @@ -9,7 +9,7 @@ namespace Tensorflow { public static Operation group(T[] inputs, string name = null) where T : ITensorOrOperation { - return with(new ops.name_scope(name, "group_deps", inputs), scope => + return with(new ops.name_scope(name, "group_deps", inputs), scope => { name = scope; @@ -39,7 +39,7 @@ namespace Tensorflow private static Operation _GroupControlDeps(string dev, Operation[] deps, string name = null) { - return Python.with<_ControlDependenciesController, Operation>(ops.control_dependencies(deps), ctl => + return with(ops.control_dependencies(deps), ctl => { if (dev == null) { @@ -83,7 +83,7 @@ namespace Tensorflow public static Tensor[] tuple(Tensor[] tensors, string name = null, Operation[] control_inputs = null) { - return Python.with(new ops.name_scope(name, "tuple", tensors), scope => + return with(new ops.name_scope(name, "tuple", tensors), scope => { name = scope; var gating_ops = tensors.Select(x => x.op).ToList(); @@ -115,11 +115,11 @@ namespace Tensorflow values.AddRange(dependencies); values.Add(output_tensor); - return Python.with(new ops.name_scope(name, "control_dependency", values), scope => + return with(new ops.name_scope(name, "control_dependency", values), scope => { name = scope; - return Python.with<_ControlDependenciesController, Tensor>(ops.control_dependencies(dependencies), ctl => + return with(ops.control_dependencies(dependencies), ctl => { output_tensor = ops.convert_to_tensor_or_composite(output_tensor); return _Identity(output_tensor, name: name); diff --git a/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs index 8cf40fd1..354177f9 100644 --- a/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs @@ -24,10 +24,34 @@ namespace Tensorflow if (!seed2.HasValue) seed2 = 0; - var _op = _op_def_lib._apply_op_helper("RandomStandardNormal", name: name, + var _op = _op_def_lib._apply_op_helper("RandomStandardNormal", + name: name, args: new { shape, dtype, seed, seed2 }); return _op.outputs[0]; } + + /// + /// Outputs random values from a uniform distribution. + /// + /// + /// + /// + /// + /// + /// + public static Tensor random_uniform(Tensor shape, TF_DataType dtype, int? seed = 0, int? seed2 = 0, string name = null) + { + if (!seed.HasValue) + seed = 0; + if (!seed2.HasValue) + seed2 = 0; + + var _op = _op_def_lib._apply_op_helper("RandomUniform", + name: name, + args: new { shape, dtype, seed, seed2}); + + return _op.outputs[0]; + } } } diff --git a/src/TensorFlowNET.Core/Operations/math_ops.py.cs b/src/TensorFlowNET.Core/Operations/math_ops.py.cs index b0015983..fb9d097a 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.py.cs @@ -14,7 +14,7 @@ namespace Tensorflow if(base_type == x.dtype) return x; - return with(new ops.name_scope(name, "Cast", new { x }), scope => + return with(new ops.name_scope(name, "Cast", new { x }), scope => { x = ops.convert_to_tensor(x, name: "x"); if (x.dtype.as_base_dtype() != base_type) @@ -165,7 +165,7 @@ namespace Tensorflow if (delta == null) delta = 1; - return with(new ops.name_scope(name, "Range", new object[] { start, limit, delta }), scope => + return with(new ops.name_scope(name, "Range", new object[] { start, limit, delta }), scope => { name = scope; var start1 = ops.convert_to_tensor(start, name: "start"); @@ -178,7 +178,7 @@ namespace Tensorflow public static Tensor floordiv(Tensor x, Tensor y, string name = null) { - return with(new ops.name_scope(name, "floordiv", new { x, y }), scope => + return with(new ops.name_scope(name, "floordiv", new { x, y }), scope => { return gen_math_ops.floor_div(x, y, scope); }); @@ -186,7 +186,7 @@ namespace Tensorflow public static Tensor rank_internal(Tensor input, string name = null, bool optimize = true) { - return with(new ops.name_scope(name, "Rank", new List { input }), scope => + return with(new ops.name_scope(name, "Rank", new List { input }), scope => { name = scope; var input_tensor = ops.convert_to_tensor(input); @@ -206,7 +206,7 @@ namespace Tensorflow { Tensor result = null; - Python.with(new ops.name_scope(name, "MatMul", new Tensor[] { a, b }), scope => + with(new ops.name_scope(name, "MatMul", new Tensor[] { a, b }), scope => { name = scope; @@ -236,7 +236,7 @@ namespace Tensorflow if (dt.is_floating() || dt.is_integer()) return x; - return Python.with(new ops.name_scope(name, "Conj", new List { x }), scope => + return with(new ops.name_scope(name, "Conj", new List { x }), scope => { return x; diff --git a/src/TensorFlowNET.Core/Operations/random_ops.py.cs b/src/TensorFlowNET.Core/Operations/random_ops.py.cs index 299ea3f7..c6ea3aa9 100644 --- a/src/TensorFlowNET.Core/Operations/random_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/random_ops.py.cs @@ -4,8 +4,18 @@ using System.Text; namespace Tensorflow { - public class random_ops + public class random_ops : Python { + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// public static Tensor random_normal(int[] shape, float mean = 0.0f, float stddev = 1.0f, @@ -13,7 +23,7 @@ namespace Tensorflow int? seed = null, string name = null) { - return Python.with(new ops.name_scope(name, "random_normal", new object[] { shape, mean, stddev }), scope => + return with(new ops.name_scope(name, "random_normal", new { shape, mean, stddev }), scope => { var shape_tensor = _ShapeTensor(shape); var mean_tensor = ops.convert_to_tensor(mean, dtype: dtype, name: "mean"); @@ -26,6 +36,34 @@ namespace Tensorflow }); } + /// + /// Outputs random values from a uniform distribution. + /// + /// + /// + /// + /// The type of the output + /// Used to create a random seed for the distribution. + /// A name for the operation + /// A tensor of the specified shape filled with random uniform values. + public static Tensor random_uniform(int[] shape, + float minval = 0, + float? maxval = null, + TF_DataType dtype = TF_DataType.TF_FLOAT, + int? seed = null, + string name = null) + { + return with(new ops.name_scope(name, "random_uniform", new { shape, minval, maxval }), scope => + { + name = scope; + var tensorShape = _ShapeTensor(shape); + var minTensor = ops.convert_to_tensor(minval, dtype: dtype, name: "min"); + var maxTensor = ops.convert_to_tensor(maxval, dtype: dtype, name: "max"); + var rnd = gen_random_ops.random_uniform(tensorShape, dtype); + return math_ops.add(rnd * (maxTensor - minTensor), minTensor, name: name); + }); + } + private static Tensor _ShapeTensor(int[] shape) { return ops.convert_to_tensor(shape, name: "shape"); diff --git a/src/TensorFlowNET.Core/Python.cs b/src/TensorFlowNET.Core/Python.cs index 8aa61222..b11f2889 100644 --- a/src/TensorFlowNET.Core/Python.cs +++ b/src/TensorFlowNET.Core/Python.cs @@ -43,12 +43,12 @@ namespace Tensorflow } } - public static void with(IPython py, Action action) where T : IPython + public static void with(T py, Action action) where T : IPython { try { py.__enter__(); - action((T)py); + action(py); } catch (Exception ex) { @@ -62,12 +62,12 @@ namespace Tensorflow } } - public static TOut with(IPython py, Func action) where TIn : IPython + public static TOut with(TIn py, Func action) where TIn : IPython { try { py.__enter__(); - return action((TIn)py); + return action(py); } catch (Exception ex) { diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 198cd509..5c84f34a 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -62,26 +62,29 @@ namespace Tensorflow switch (subfeed_val) { - case IntPtr pointer: - feed_dict_tensor[subfeed_t] = pointer; + case IntPtr val: + feed_dict_tensor[subfeed_t] = val; break; - case NDArray nd: - feed_dict_tensor[subfeed_t] = nd; + case NDArray val: + feed_dict_tensor[subfeed_t] = val; break; - case float floatVal: - feed_dict_tensor[subfeed_t] = (NDArray)floatVal; + case float val: + feed_dict_tensor[subfeed_t] = (NDArray)val; break; - case double doubleVal: - feed_dict_tensor[subfeed_t] = (NDArray)doubleVal; + case double val: + feed_dict_tensor[subfeed_t] = (NDArray)val; break; - case int intVal: - feed_dict_tensor[subfeed_t] = (NDArray)intVal; + case short val: + feed_dict_tensor[subfeed_t] = (NDArray)val; break; - case string str: - feed_dict_tensor[subfeed_t] = (NDArray)str; + case int val: + feed_dict_tensor[subfeed_t] = (NDArray)val; break; - case byte[] bytes: - feed_dict_tensor[subfeed_t] = (NDArray)bytes; + case string val: + feed_dict_tensor[subfeed_t] = (NDArray)val; + break; + case byte[] val: + feed_dict_tensor[subfeed_t] = (NDArray)val; break; default: Console.WriteLine($"can't handle data type of subfeed_val"); diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 01448769..949166c1 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -4,7 +4,7 @@ netstandard2.0 TensorFlow.NET Tensorflow - 0.4.0 + 0.4.2 Haiping Chen SciSharp STACK true @@ -16,11 +16,11 @@ TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET Google's TensorFlow binding in .NET Standard. Docs: https://tensorflownet.readthedocs.io - 0.4.0.0 - Added Linear Regression example. - + 0.4.2.0 + Added ConfigProto to control CPU and GPU resource. +Fixed import name scope issue. 7.2 - 0.4.0.0 + 0.4.2.0 diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs index 89a8aa17..a357e5d5 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs @@ -42,7 +42,7 @@ namespace Tensorflow dtype = tr.dtype.as_base_dtype(); var namescope = new ops.name_scope(null, name, new { x, y }); - return Python.with(namescope, scope => + return with(namescope, scope => { Tensor result = null; var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x"); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index e2dd55a5..2b3db534 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -12,7 +12,7 @@ namespace Tensorflow /// A tensor is a generalization of vectors and matrices to potentially higher dimensions. /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. /// - public partial class Tensor : IDisposable, ITensorOrOperation + public partial class Tensor : Python, IDisposable, ITensorOrOperation { private readonly IntPtr _handle; @@ -77,6 +77,11 @@ namespace Tensorflow return null; } + public TensorShape getShape() + { + return tensor_util.to_shape(shape); + } + /// /// number of dimensions /// 0 Scalar (magnitude only) diff --git a/src/TensorFlowNET.Core/Train/Optimizer.cs b/src/TensorFlowNET.Core/Train/Optimizer.cs index 25c443f3..407315e2 100644 --- a/src/TensorFlowNET.Core/Train/Optimizer.cs +++ b/src/TensorFlowNET.Core/Train/Optimizer.cs @@ -12,7 +12,7 @@ namespace Tensorflow /// class directly, but instead instantiate one of its subclasses such as /// `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`. /// - public abstract class Optimizer + public abstract class Optimizer : Python { // Values for gate_gradients. public static int GATE_NONE = 0; @@ -87,7 +87,7 @@ namespace Tensorflow _create_slots(var_list); var update_ops = new List(); - return Python.with(new ops.name_scope(name, Name), scope => + return with(new ops.name_scope(name, Name), scope => { name = scope; _prepare(); @@ -98,7 +98,7 @@ namespace Tensorflow continue; var scope_name = var.op.name; - Python.with(new ops.name_scope("update_" + scope_name), scope2 => + with(new ops.name_scope("update_" + scope_name), scope2 => { update_ops.Add(processor.update_op(this, grad)); }); diff --git a/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs index 06a1fb49..876bf856 100644 --- a/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs +++ b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs @@ -5,7 +5,7 @@ using System.Text; namespace Tensorflow { - public class BaseSaverBuilder + public class BaseSaverBuilder : Python { protected SaverDef.Types.CheckpointFormatVersion _write_version; @@ -79,7 +79,7 @@ namespace Tensorflow Tensor save_tensor = null; Operation restore_op = null; - return Python.with(new ops.name_scope(name, "save", saveables.Select(x => x.op).ToArray()), scope => + return with(new ops.name_scope(name, "save", saveables.Select(x => x.op).ToArray()), scope => { name = scope; diff --git a/src/TensorFlowNET.Core/Variables/PureVariableScope.cs b/src/TensorFlowNET.Core/Variables/PureVariableScope.cs new file mode 100644 index 00000000..6f97b19d --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/PureVariableScope.cs @@ -0,0 +1,86 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class PureVariableScope : IPython + { + private string _name; + private VariableScope _scope; + private string _new_name; + private string _old_name_scope; + private bool _reuse; + private _VariableStore _var_store; + private VariableScope _old; + private _VariableScopeStore _var_scope_store; + private VariableScope variable_scope_object; + private VariableScope _cached_variable_scope_object; + + public PureVariableScope(string name, + string old_name_scope = null, + TF_DataType dtype = TF_DataType.DtInvalid) + { + _name = name; + _old_name_scope = old_name_scope; + _var_store = variable_scope._get_default_variable_store(); + _var_scope_store = variable_scope.get_variable_scope_store(); + } + + public PureVariableScope(VariableScope scope, + string old_name_scope = null, + TF_DataType dtype = TF_DataType.DtInvalid) + { + _scope = scope; + _old_name_scope = old_name_scope; + _var_store = variable_scope._get_default_variable_store(); + _var_scope_store = variable_scope.get_variable_scope_store(); + _new_name = _scope._name; + + string name_scope = _scope._name_scope; + variable_scope_object = new VariableScope(_reuse, + name: _new_name, + name_scope: name_scope); + + _cached_variable_scope_object = variable_scope_object; + } + + public void __enter__() + { + _old = _var_scope_store.current_scope; + if(_scope != null) + { + _var_scope_store.open_variable_scope(_new_name); + variable_scope_object = _cached_variable_scope_object; + } + else + { + _new_name = string.IsNullOrEmpty(_old._name) ? _name : _old._name + "/" + _name; + _reuse = _reuse || _old.resue; + string name_scope = _old_name_scope == null ? _name : _old_name_scope; + + variable_scope_object = new VariableScope(_reuse, + name: _new_name, + name_scope: name_scope); + + _var_scope_store.open_variable_scope(_new_name); + } + _var_scope_store.current_scope = variable_scope_object; + } + + public void Dispose() + { + + } + + public void __exit__() + { + + } + + public static implicit operator VariableScope(PureVariableScope scope) + { + return scope.variable_scope_object; + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs b/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs index 5bd4120d..3b2f67c4 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs @@ -17,7 +17,7 @@ namespace Tensorflow private static Tensor op_helper(string default_name, RefVariable x, T y) { var tensor1 = x.value(); - return with(new ops.name_scope(null, default_name, new { tensor1, y }), scope => { + return with(new ops.name_scope(null, default_name, new { tensor1, y }), scope => { var tensor2 = ops.convert_to_tensor(y, tensor1.dtype.as_base_dtype(), "y"); return gen_math_ops.add(tensor1, tensor2, scope); }); diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index f0e4e721..5eb03d3f 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -1,4 +1,6 @@ -using System; +using Google.Protobuf; +using Google.Protobuf.Collections; +using System; using System.Collections.Generic; using System.Linq; using System.Text; @@ -99,7 +101,7 @@ namespace Tensorflow if (initial_value is null) throw new ValueError("initial_value must be specified."); - var init_from_fn = false; + var init_from_fn = initial_value.GetType().Name == "Func`1"; if(collections == null) { @@ -115,12 +117,27 @@ namespace Tensorflow collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES); ops.init_scope(); - var values = init_from_fn ? new List() : new List { initial_value }; - Python.with(new ops.name_scope(name, "Variable", values), scope => + var values = init_from_fn ? new object[0] : new object[] { initial_value }; + with(new ops.name_scope(name, "Variable", values), scope => { + name = scope; if (init_from_fn) { - + // Use attr_scope and device(None) to simulate the behavior of + // colocate_with when the variable we want to colocate with doesn't + // yet exist. + string true_name = ops._name_from_scope_name(name); + var attr = new AttrValue + { + List = new AttrValue.Types.ListValue() + }; + attr.List.S.Add(ByteString.CopyFromUtf8($"loc:{true_name}")); + with(new ops.name_scope("Initializer"), scope2 => + { + _initial_value = (initial_value as Func)(); + _initial_value = ops.convert_to_tensor(_initial_value, name: "initial_value", dtype: dtype); + _variable = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name); + }); } // Or get the initial value from a Tensor or Python object. else @@ -135,7 +152,9 @@ namespace Tensorflow // Manually overrides the variable's shape with the initial value's. if (validate_shape) { - var initial_value_shape = _initial_value.shape; + var initial_value_shape = _initial_value.getShape(); + if (!initial_value_shape.is_fully_defined()) + throw new ValueError($"initial_value must have a shape specified: {_initial_value}"); } // If 'initial_value' makes use of other variables, make sure we don't diff --git a/src/TensorFlowNET.Core/Variables/VariableScope.cs b/src/TensorFlowNET.Core/Variables/VariableScope.cs index c8d99036..29c03c19 100644 --- a/src/TensorFlowNET.Core/Variables/VariableScope.cs +++ b/src/TensorFlowNET.Core/Variables/VariableScope.cs @@ -4,17 +4,27 @@ using System.Text; namespace Tensorflow { - public class VariableScope + /// + /// Variable scope object to carry defaults to provide to `get_variable` + /// + public class VariableScope : Python { public bool use_resource { get; set; } - private _ReuseMode _reuse { get; set; } + private _ReuseMode _reuse; + public bool resue; - private object _regularizer; private TF_DataType _dtype; - public string name { get; set; } + public string _name { get; set; } + public string _name_scope { get; set; } + public string original_name_scope => _name_scope; - public VariableScope(TF_DataType dtype = TF_DataType.TF_FLOAT) + public VariableScope(bool reuse, + string name = "", + string name_scope = "", + TF_DataType dtype = TF_DataType.TF_FLOAT) { + _name = name; + _name_scope = name_scope; _reuse = _ReuseMode.AUTO_REUSE; _dtype = dtype; } @@ -28,8 +38,8 @@ namespace Tensorflow VariableSynchronization synchronization = VariableSynchronization.AUTO, VariableAggregation aggregation= VariableAggregation.NONE) { - string full_name = !string.IsNullOrEmpty(this.name) ? this.name + "/" + name : name; - return Python.with(new ops.name_scope(""), scope => + string full_name = !string.IsNullOrEmpty(this._name) ? this._name + "/" + name : name; + return with(new ops.name_scope(null), scope => { if (dtype == TF_DataType.DtInvalid) dtype = _dtype; diff --git a/src/TensorFlowNET.Core/Variables/_VariableScopeStore.cs b/src/TensorFlowNET.Core/Variables/_VariableScopeStore.cs index a7b3e3b5..a043bfe0 100644 --- a/src/TensorFlowNET.Core/Variables/_VariableScopeStore.cs +++ b/src/TensorFlowNET.Core/Variables/_VariableScopeStore.cs @@ -7,10 +7,20 @@ namespace Tensorflow public class _VariableScopeStore { public VariableScope current_scope { get; set; } + private Dictionary variable_scopes_count; public _VariableScopeStore() { - current_scope = new VariableScope(); + current_scope = new VariableScope(false); + variable_scopes_count = new Dictionary(); + } + + public void open_variable_scope(string scope_name) + { + if (variable_scopes_count.ContainsKey(scope_name)) + variable_scopes_count[scope_name] += 1; + else + variable_scopes_count[scope_name] = 1; } } } diff --git a/src/TensorFlowNET.Core/Variables/_VariableStore.cs b/src/TensorFlowNET.Core/Variables/_VariableStore.cs index 2c22f25c..5bd8c86d 100644 --- a/src/TensorFlowNET.Core/Variables/_VariableStore.cs +++ b/src/TensorFlowNET.Core/Variables/_VariableStore.cs @@ -74,7 +74,9 @@ namespace Tensorflow VariableSynchronization synchronization = VariableSynchronization.AUTO, VariableAggregation aggregation = VariableAggregation.NONE) { - bool initializing_from_value = false; + bool initializing_from_value = true; + if (use_resource == null) + use_resource = false; if (_vars.ContainsKey(name)) { @@ -86,7 +88,18 @@ namespace Tensorflow throw new NotImplementedException("_get_single_variable"); } - Tensor init_val = null; + RefVariable v = null; + // Create the tensor to initialize the variable with default value. + if (initializer == null) + { + if (dtype.is_floating()) + { + initializer = tf.glorot_uniform_initializer; + initializing_from_value = false; + } + } + + // Create the variable. ops.init_scope(); { if (initializing_from_value) @@ -95,23 +108,19 @@ namespace Tensorflow } else { - init_val = initializer.call(shape, dtype); + Func init_val = () => initializer.call(shape, dtype); var variable_dtype = dtype.as_base_dtype(); + + v = variable_scope.default_variable_creator(init_val, + name: name, + trainable: trainable, + dtype: TF_DataType.DtInvalid, + validate_shape: validate_shape, + synchronization: synchronization, + aggregation: aggregation); } } - // Create the variable. - if (use_resource == null) - use_resource = false; - - var v = variable_scope.default_variable_creator(init_val, - name: name, - trainable: trainable, - dtype: TF_DataType.DtInvalid, - validate_shape: validate_shape, - synchronization: synchronization, - aggregation: aggregation); - _vars[name] = v; return v; diff --git a/src/TensorFlowNET.Core/Variables/state_ops.cs b/src/TensorFlowNET.Core/Variables/state_ops.cs new file mode 100644 index 00000000..0144a138 --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/state_ops.cs @@ -0,0 +1,28 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class state_ops + { + /// + /// Create a variable Operation. + /// + /// + /// + /// + /// + /// + /// + public static Tensor variable_op_v2(long[] shape, + TF_DataType dtype, + string name = "Variable", + string container = "", + string shared_name = "") => gen_state_ops.variable_v2(shape, + dtype, + name: name, + container: container, + shared_name: shared_name); + } +} diff --git a/src/TensorFlowNET.Core/Variables/tf.variable.cs b/src/TensorFlowNET.Core/Variables/tf.variable.cs index 2e7eefec..b61b558e 100644 --- a/src/TensorFlowNET.Core/Variables/tf.variable.cs +++ b/src/TensorFlowNET.Core/Variables/tf.variable.cs @@ -20,8 +20,8 @@ namespace Tensorflow VariableSynchronization synchronization = VariableSynchronization.AUTO, VariableAggregation aggregation = VariableAggregation.NONE) { - var scope = variable_scope.get_variable_scope(); - var store = variable_scope._get_default_variable_store(); + var scope = Tensorflow.variable_scope.get_variable_scope(); + var store = Tensorflow.variable_scope._get_default_variable_store(); return scope.get_variable(store, name, shape: shape, diff --git a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs index 56439db5..779e647b 100644 --- a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs +++ b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs @@ -1,15 +1,104 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; namespace Tensorflow { - public class variable_scope + public class variable_scope : IPython { public static string _VARSTORE_KEY = "__variable_store"; public static string _VARSCOPESTORE_KEY = "__varscope"; public static bool _DEFAULT_USE_RESOURCE = false; + private bool _use_resource; + public bool UseResource => _use_resource; + private string _name; + private VariableScope _scope; + private string _default_name; + private object _values; + private ops.name_scope _current_name_scope; + private bool _auxiliary_name_scope; + private PureVariableScope _cached_pure_variable_scope; + + public variable_scope(string name, + string default_name = "", + object values = null, + bool auxiliary_name_scope = true) + { + _name = name; + _default_name = default_name; + _values = values; + _current_name_scope = null; + + _use_resource = false; + if (_default_name == null && _name == null) + throw new TypeError("If default_name is None then name is required"); + + _auxiliary_name_scope = auxiliary_name_scope; + } + + public variable_scope(VariableScope scope, + string default_name = "", + object values = null, + bool auxiliary_name_scope = true) + { + _scope = scope; + _default_name = default_name; + _values = values; + _current_name_scope = null; + + _use_resource = false; + if (_default_name == null && _scope == null) + throw new TypeError("If default_name is None then scope is required"); + + _auxiliary_name_scope = auxiliary_name_scope; + } + + public void __enter__() + { + _scope = _enter_scope_uncached(); + } + + private VariableScope _enter_scope_uncached() + { + ops.name_scope current_name_scope; + if (_auxiliary_name_scope) + // Create a new name scope later + current_name_scope = null; + else + { + // Reenter the current name scope + string name_scope = ops.get_name_scope(); + if(!string.IsNullOrEmpty(name_scope)) + // Hack to reenter + name_scope += "/"; + current_name_scope = new ops.name_scope(name_scope); + } + + if (_name != null || _scope != null) + { + var name_scope = _name == null ? _scope._name.Split('/').Last() : _name; + if (name_scope != null || current_name_scope != null) + current_name_scope = new ops.name_scope(name_scope); + current_name_scope.__enter__(); + var current_name_scope_name = current_name_scope; + _current_name_scope = current_name_scope; + string old_name_scope = current_name_scope_name; + PureVariableScope pure_variable_scope = null; + if(_scope == null) + pure_variable_scope = new PureVariableScope(_name, old_name_scope: old_name_scope); + else + pure_variable_scope = new PureVariableScope(_scope, old_name_scope: old_name_scope); + pure_variable_scope.__enter__(); + VariableScope entered_pure_variable_scope = pure_variable_scope; + _cached_pure_variable_scope = pure_variable_scope; + return entered_pure_variable_scope; + } + + throw new NotImplementedException("_enter_scope_uncached"); + } + public static RefVariable default_variable_creator(object initial_value, string name = null, bool? trainable = null, @@ -101,5 +190,22 @@ namespace Tensorflow return trainable.Value; } + + public static implicit operator VariableScope(variable_scope scope) + { + return scope._scope; + } + + public void __exit__() + { + if (_current_name_scope != null) + _current_name_scope.__exit__(); + } + + public void Dispose() + { + if (_current_name_scope != null) + _current_name_scope.Dispose(); + } } } diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index 6f2408db..50fd59c9 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -12,7 +12,7 @@ using System.ComponentModel; namespace Tensorflow { - public partial class ops + public partial class ops : Python { public static void add_to_collection(string name, T value) { @@ -216,7 +216,7 @@ namespace Tensorflow // inner_device_stack = default_graph._device_function_stack // var outer_context = default_graph.as_default; - Python.with(ops.control_dependencies(null), delegate + with(ops.control_dependencies(null), delegate { var outer_graph = get_default_graph(); // outer_device_stack = None @@ -475,5 +475,11 @@ namespace Tensorflow return name; } } + + public static string get_name_scope() + { + var g = get_default_graph(); + return g.get_name_scope(); + } } } diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs index 119a884e..eeb415f5 100644 --- a/src/TensorFlowNET.Core/tf.cs +++ b/src/TensorFlowNET.Core/tf.cs @@ -20,7 +20,7 @@ namespace Tensorflow public static RefVariable Variable(T data, string name = null, TF_DataType dtype = TF_DataType.DtInvalid) { - return variable_scope.default_variable_creator(data, name: name, dtype: TF_DataType.DtInvalid); + return Tensorflow.variable_scope.default_variable_creator(data, name: name, dtype: TF_DataType.DtInvalid); } public static unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null) diff --git a/src/TensorFlowNET.Utility/TensorFlowNET.Utility.csproj b/src/TensorFlowNET.Utility/TensorFlowNET.Utility.csproj deleted file mode 100644 index efeb641a..00000000 --- a/src/TensorFlowNET.Utility/TensorFlowNET.Utility.csproj +++ /dev/null @@ -1,13 +0,0 @@ - - - - netstandard2.0 - TensorFlowNET.Utility - TensorFlowNET.Utility - - - - - - - diff --git a/src/TensorFlowNET.Utility/Web.cs b/src/TensorFlowNET.Utility/Web.cs deleted file mode 100644 index dfaf5236..00000000 --- a/src/TensorFlowNET.Utility/Web.cs +++ /dev/null @@ -1,35 +0,0 @@ -using System; -using System.Collections.Generic; -using System.IO; -using System.Net; -using System.Text; -using System.Threading; -using System.Threading.Tasks; - -namespace TensorFlowNET.Utility -{ - public class Web - { - public static bool Download(string url, string file) - { - if (File.Exists(file)) - { - Console.WriteLine($"{file} already exists."); - return false; - } - - var wc = new WebClient(); - Console.WriteLine($"Downloading {file}"); - var download = Task.Run(() => wc.DownloadFile(url, file)); - while (!download.IsCompleted) - { - Thread.Sleep(1000); - Console.Write("."); - } - Console.WriteLine(""); - Console.WriteLine($"Downloaded {file}"); - - return true; - } - } -} diff --git a/test/TensorFlowNET.Examples/ImageRecognition.cs b/test/TensorFlowNET.Examples/ImageRecognition.cs index 73bf0ab1..47d0ac07 100644 --- a/test/TensorFlowNET.Examples/ImageRecognition.cs +++ b/test/TensorFlowNET.Examples/ImageRecognition.cs @@ -39,7 +39,7 @@ namespace TensorFlowNET.Examples var idx = 0; float propability = 0; - with(tf.Session(graph), sess => + with(tf.Session(graph), sess => { var results = sess.run(output_operation.outputs[0], new FeedItem(input_operation.outputs[0], tensor)); var probabilities = results.Data(); @@ -63,7 +63,7 @@ namespace TensorFlowNET.Examples int input_mean = 117, int input_std = 1) { - return with(tf.Graph().as_default(), graph => + return with(tf.Graph().as_default(), graph => { var file_reader = tf.read_file(file_name, "file_reader"); var decodeJpeg = tf.image.decode_jpeg(file_reader, channels: 3, name: "DecodeJpeg"); @@ -74,7 +74,7 @@ namespace TensorFlowNET.Examples var sub = tf.subtract(bilinear, new float[] { input_mean }); var normalized = tf.divide(sub, new float[] { input_std }); - return with(tf.Session(graph), sess => sess.run(normalized)); + return with(tf.Session(graph), sess => sess.run(normalized)); }); } @@ -85,15 +85,14 @@ namespace TensorFlowNET.Examples // get model file string url = "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip"; - string zipFile = Path.Join(dir, "inception5h.zip"); - Utility.Web.Download(url, zipFile); + Utility.Web.Download(url, dir, "inception5h.zip"); - Utility.Compress.UnZip(zipFile, dir); + Utility.Compress.UnZip(Path.Join(dir, "inception5h.zip"), dir); // download sample picture - string pic = Path.Join(dir, "img", "grace_hopper.jpg"); Directory.CreateDirectory(Path.Join(dir, "img")); - Utility.Web.Download($"https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/label_image/data/grace_hopper.jpg", pic); + url = $"https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/label_image/data/grace_hopper.jpg"; + Utility.Web.Download(url, Path.Join(dir, "img"), "grace_hopper.jpg"); } } } diff --git a/test/TensorFlowNET.Examples/InceptionArchGoogLeNet.cs b/test/TensorFlowNET.Examples/InceptionArchGoogLeNet.cs index 3cd0dbda..75be7738 100644 --- a/test/TensorFlowNET.Examples/InceptionArchGoogLeNet.cs +++ b/test/TensorFlowNET.Examples/InceptionArchGoogLeNet.cs @@ -46,7 +46,7 @@ namespace TensorFlowNET.Examples var input_operation = graph.get_operation_by_name(input_name); var output_operation = graph.get_operation_by_name(output_name); - var results = with(tf.Session(graph), + var results = with(tf.Session(graph), sess => sess.run(output_operation.outputs[0], new FeedItem(input_operation.outputs[0], nd))); @@ -68,7 +68,7 @@ namespace TensorFlowNET.Examples int input_mean = 0, int input_std = 255) { - return with(tf.Graph().as_default(), graph => + return with(tf.Graph().as_default(), graph => { var file_reader = tf.read_file(file_name, "file_reader"); var image_reader = tf.image.decode_jpeg(file_reader, channels: 3, name: "jpeg_reader"); @@ -79,7 +79,7 @@ namespace TensorFlowNET.Examples var sub = tf.subtract(bilinear, new float[] { input_mean }); var normalized = tf.divide(sub, new float[] { input_std }); - return with(tf.Session(graph), sess => sess.run(normalized)); + return with(tf.Session(graph), sess => sess.run(normalized)); }); } @@ -90,14 +90,14 @@ namespace TensorFlowNET.Examples // get model file string url = "https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz"; - string zipFile = Path.Join(dir, $"{pbFile}.tar.gz"); - Utility.Web.Download(url, zipFile); + Utility.Web.Download(url, dir, $"{pbFile}.tar.gz"); - Utility.Compress.ExtractTGZ(zipFile, dir); + Utility.Compress.ExtractTGZ(Path.Join(dir, $"{pbFile}.tar.gz"), dir); // download sample picture string pic = "grace_hopper.jpg"; - Utility.Web.Download($"https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/label_image/data/{pic}", Path.Join(dir, pic)); + url = $"https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/label_image/data/{pic}"; + Utility.Web.Download(url, dir, pic); } } } diff --git a/test/TensorFlowNET.Examples/LinearRegression.cs b/test/TensorFlowNET.Examples/LinearRegression.cs index bac69b76..35d64e5c 100644 --- a/test/TensorFlowNET.Examples/LinearRegression.cs +++ b/test/TensorFlowNET.Examples/LinearRegression.cs @@ -16,7 +16,7 @@ namespace TensorFlowNET.Examples // Parameters float learning_rate = 0.01f; - int training_epochs = 10000; + int training_epochs = 1000; int display_step = 50; public void Run() @@ -53,7 +53,7 @@ namespace TensorFlowNET.Examples var init = tf.global_variables_initializer(); // Start training - with(tf.Session(), sess => + with(tf.Session(), sess => { // Run the initializer sess.run(init); diff --git a/test/TensorFlowNET.Examples/MetaGraph.cs b/test/TensorFlowNET.Examples/MetaGraph.cs index bb272a2c..7ce74ffc 100644 --- a/test/TensorFlowNET.Examples/MetaGraph.cs +++ b/test/TensorFlowNET.Examples/MetaGraph.cs @@ -16,7 +16,7 @@ namespace TensorFlowNET.Examples private void ImportMetaGraph(string dir) { - with(tf.Session(), sess => + with(tf.Session(), sess => { var new_saver = tf.train.import_meta_graph(dir + "my-model-10000.meta"); new_saver.restore(sess, dir + "my-model-10000"); diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj index 28952543..2ec7d20b 100644 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj +++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj @@ -5,10 +5,16 @@ netcoreapp2.2 + + + + + + + - diff --git a/test/TensorFlowNET.Examples/TextClassification/DataHelpers.cs b/test/TensorFlowNET.Examples/TextClassification/DataHelpers.cs new file mode 100644 index 00000000..f705b98c --- /dev/null +++ b/test/TensorFlowNET.Examples/TextClassification/DataHelpers.cs @@ -0,0 +1,91 @@ +using NumSharp.Core; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using System.Text.RegularExpressions; + +namespace TensorFlowNET.Examples.CnnTextClassification +{ + public class DataHelpers + { + private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv"; + private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv"; + + public static (int[][], int[], int) build_char_dataset(string step, string model, int document_max_len) + { + string alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:’'\"/|_#$%ˆ&*˜‘+=<>()[]{} "; + /*if (step == "train") + df = pd.read_csv(TRAIN_PATH, names =["class", "title", "content"]);*/ + var char_dict = new Dictionary(); + char_dict[""] = 0; + char_dict[""] = 1; + foreach (char c in alphabet) + char_dict[c.ToString()] = char_dict.Count; + + var contents = File.ReadAllLines(TRAIN_PATH); + + var x = new int[contents.Length][]; + var y = new int[contents.Length]; + for (int i = 0; i < contents.Length; i++) + { + string[] parts = contents[i].ToLower().Split(",\"").ToArray(); + string content = parts[2]; + content = content.Substring(0, content.Length - 1); + x[i] = new int[document_max_len]; + for (int j = 0; j < document_max_len; j++) + { + if (j >= content.Length) + x[i][j] = char_dict[""]; + else + x[i][j] = char_dict.ContainsKey(content[j].ToString()) ? char_dict[content[j].ToString()] : char_dict[""]; + } + + y[i] = int.Parse(parts[0]); + } + + return (x, y, alphabet.Length + 2); + } + + /// + /// Loads MR polarity data from files, splits the data into words and generates labels. + /// Returns split sentences and labels. + /// + /// + /// + /// + public static (string[], NDArray) load_data_and_labels(string positive_data_file, string negative_data_file) + { + Directory.CreateDirectory("CnnTextClassification"); + Utility.Web.Download(positive_data_file, "CnnTextClassification", "rt -polarity.pos"); + Utility.Web.Download(negative_data_file, "CnnTextClassification", "rt-polarity.neg"); + + // Load data from files + var positive_examples = File.ReadAllLines("CnnTextClassification/rt-polarity.pos") + .Select(x => x.Trim()) + .ToArray(); + + var negative_examples = File.ReadAllLines("CnnTextClassification/rt-polarity.neg") + .Select(x => x.Trim()) + .ToArray(); + + var x_text = new List(); + x_text.AddRange(positive_examples); + x_text.AddRange(negative_examples); + x_text = x_text.Select(x => clean_str(x)).ToList(); + + var positive_labels = positive_examples.Select(x => new int[2] { 0, 1 }).ToArray(); + var negative_labels = negative_examples.Select(x => new int[2] { 1, 0 }).ToArray(); + var y = np.array(1);// np.concatenate(new int[][][] { positive_labels, negative_labels }); + return (x_text.ToArray(), y); + } + + private static string clean_str(string str) + { + str = Regex.Replace(str, @"[^A-Za-z0-9(),!?\'\`]", " "); + str = Regex.Replace(str, @"\'s", " \'s"); + return str; + } + } +} diff --git a/test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs b/test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs new file mode 100644 index 00000000..4ea583d4 --- /dev/null +++ b/test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs @@ -0,0 +1,37 @@ +using NumSharp.Core; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using Tensorflow; +using TensorFlowNET.Examples.Utility; + +namespace TensorFlowNET.Examples.CnnTextClassification +{ + /// + /// https://github.com/dongjun-Lee/text-classification-models-tf + /// + public class TextClassificationTrain : Python, IExample + { + private string dataDir = "text_classification"; + private string dataFileName = "dbpedia_csv.tar.gz"; + + private const int CHAR_MAX_LEN = 1014; + + public void Run() + { + download_dbpedia(); + Console.WriteLine("Building dataset..."); + var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", "vdcnn", CHAR_MAX_LEN); + //var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15); + } + + public void download_dbpedia() + { + string url = "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz"; + Web.Download(url, dataDir, dataFileName); + Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir); + } + } +} diff --git a/test/TensorFlowNET.Examples/TextClassificationWithMovieReviews.cs b/test/TensorFlowNET.Examples/TextClassificationWithMovieReviews.cs index b57da319..cd59287b 100644 --- a/test/TensorFlowNET.Examples/TextClassificationWithMovieReviews.cs +++ b/test/TensorFlowNET.Examples/TextClassificationWithMovieReviews.cs @@ -46,9 +46,8 @@ namespace TensorFlowNET.Examples // get model file string url = $"https://github.com/SciSharp/TensorFlow.NET/raw/master/data/{dataFile}"; - string zipFile = Path.Join(dir, $"imdb.zip"); - Utility.Web.Download(url, zipFile); - Utility.Compress.UnZip(zipFile, dir); + Utility.Web.Download(url, dir, "imdb.zip"); + Utility.Compress.UnZip(Path.Join(dir, $"imdb.zip"), dir); // prepare training dataset var x_train = ReadData(Path.Join(dir, "x_train.txt")); diff --git a/src/TensorFlowNET.Utility/Compress.cs b/test/TensorFlowNET.Examples/Utility/Compress.cs similarity index 98% rename from src/TensorFlowNET.Utility/Compress.cs rename to test/TensorFlowNET.Examples/Utility/Compress.cs index 0efe59bd..cf40e2c4 100644 --- a/src/TensorFlowNET.Utility/Compress.cs +++ b/test/TensorFlowNET.Examples/Utility/Compress.cs @@ -7,7 +7,7 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; -namespace TensorFlowNET.Utility +namespace TensorFlowNET.Examples.Utility { public class Compress { diff --git a/test/TensorFlowNET.Examples/Utility/Web.cs b/test/TensorFlowNET.Examples/Utility/Web.cs new file mode 100644 index 00000000..e2155e93 --- /dev/null +++ b/test/TensorFlowNET.Examples/Utility/Web.cs @@ -0,0 +1,43 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace TensorFlowNET.Examples.Utility +{ + public class Web + { + public static bool Download(string url, string destDir, string destFileName) + { + if (destFileName == null) + destFileName = url.Split(Path.DirectorySeparatorChar).Last(); + + Directory.CreateDirectory(destDir); + + string relativeFilePath = Path.Combine(destDir, destFileName); + + if (File.Exists(relativeFilePath)) + { + Console.WriteLine($"{relativeFilePath} already exists."); + return false; + } + + var wc = new WebClient(); + Console.WriteLine($"Downloading {relativeFilePath}"); + var download = Task.Run(() => wc.DownloadFile(url, relativeFilePath)); + while (!download.IsCompleted) + { + Thread.Sleep(1000); + Console.Write("."); + } + Console.WriteLine(""); + Console.WriteLine($"Downloaded {relativeFilePath}"); + + return true; + } + } +} diff --git a/test/TensorFlowNET.UnitTest/CApiTest.cs b/test/TensorFlowNET.UnitTest/CApiTest.cs index dedb88b3..4f0be55d 100644 --- a/test/TensorFlowNET.UnitTest/CApiTest.cs +++ b/test/TensorFlowNET.UnitTest/CApiTest.cs @@ -6,7 +6,7 @@ using Tensorflow; namespace TensorFlowNET.UnitTest { - public class CApiTest + public class CApiTest : Python { protected TF_Code TF_OK = TF_Code.TF_OK; protected TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT; diff --git a/test/TensorFlowNET.UnitTest/ConstantTest.cs b/test/TensorFlowNET.UnitTest/ConstantTest.cs index 7b0768e0..ee496e7c 100644 --- a/test/TensorFlowNET.UnitTest/ConstantTest.cs +++ b/test/TensorFlowNET.UnitTest/ConstantTest.cs @@ -10,7 +10,7 @@ using Tensorflow; namespace TensorFlowNET.UnitTest { [TestClass] - public class ConstantTest + public class ConstantTest : Python { Status status = new Status(); @@ -27,7 +27,7 @@ namespace TensorFlowNET.UnitTest { string str = "Hello, TensorFlow.NET!"; var tensor = tf.constant(str); - Python.with(tf.Session(), sess => + with(tf.Session(), sess => { var result = sess.run(tensor); Assert.IsTrue(result.Data()[0] == str); @@ -39,7 +39,7 @@ namespace TensorFlowNET.UnitTest { // small size var tensor = tf.zeros(new Shape(3, 2), TF_DataType.TF_INT32, "small"); - Python.with(tf.Session(), sess => + with(tf.Session(), sess => { var result = sess.run(tensor); @@ -50,7 +50,7 @@ namespace TensorFlowNET.UnitTest // big size tensor = tf.zeros(new Shape(200, 100), TF_DataType.TF_INT32, "big"); - Python.with(tf.Session(), sess => + with(tf.Session(), sess => { var result = sess.run(tensor); @@ -74,7 +74,7 @@ namespace TensorFlowNET.UnitTest }); var tensor = tf.constant(nd); - Python.with(tf.Session(), sess => + with(tf.Session(), sess => { var result = sess.run(tensor); var data = result.Data(); diff --git a/test/TensorFlowNET.UnitTest/NameScopeTest.cs b/test/TensorFlowNET.UnitTest/NameScopeTest.cs index 64b2bd1f..01a5f031 100644 --- a/test/TensorFlowNET.UnitTest/NameScopeTest.cs +++ b/test/TensorFlowNET.UnitTest/NameScopeTest.cs @@ -15,7 +15,7 @@ namespace TensorFlowNET.UnitTest [TestMethod] public void NestedNameScope() { - with(new ops.name_scope("scope1"), scope1 => + with(new ops.name_scope("scope1"), scope1 => { name = scope1; Assert.AreEqual("scope1", g._name_stack); @@ -24,7 +24,7 @@ namespace TensorFlowNET.UnitTest var const1 = tf.constant(1.0); Assert.AreEqual("scope1/Const:0", const1.name); - with(new ops.name_scope("scope2"), scope2 => + with(new ops.name_scope("scope2"), scope2 => { name = scope2; Assert.AreEqual("scope1/scope2", g._name_stack); diff --git a/test/TensorFlowNET.UnitTest/PlaceholderTest.cs b/test/TensorFlowNET.UnitTest/PlaceholderTest.cs index 3bc6a892..189a9c69 100644 --- a/test/TensorFlowNET.UnitTest/PlaceholderTest.cs +++ b/test/TensorFlowNET.UnitTest/PlaceholderTest.cs @@ -7,7 +7,7 @@ using Tensorflow; namespace TensorFlowNET.UnitTest { [TestClass] - public class PlaceholderTest + public class PlaceholderTest : Python { [TestMethod] public void placeholder() @@ -15,7 +15,7 @@ namespace TensorFlowNET.UnitTest var x = tf.placeholder(tf.int32); var y = x * 3; - Python.with(tf.Session(), sess => + with(tf.Session(), sess => { var result = sess.run(y, new FeedItem(x, 2)); diff --git a/test/TensorFlowNET.UnitTest/SessionTest.cs b/test/TensorFlowNET.UnitTest/SessionTest.cs index e93e4444..37206426 100644 --- a/test/TensorFlowNET.UnitTest/SessionTest.cs +++ b/test/TensorFlowNET.UnitTest/SessionTest.cs @@ -82,7 +82,7 @@ namespace TensorFlowNET.UnitTest var a = constant_op.constant(np.array(3.0).reshape(1, 1)); var b = constant_op.constant(np.array(2.0).reshape(1, 1)); var c = math_ops.matmul(a, b, name: "matmul"); - Python.with(tf.Session(), delegate + with(tf.Session(), delegate { var result = c.eval(); Assert.AreEqual(6, result.Data()[0]); diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj index 2ff58949..f495711c 100644 --- a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj @@ -1,7 +1,7 @@ - netcoreapp2.1 + netcoreapp2.2 false diff --git a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs index f5aec32b..3ea4dfc8 100644 --- a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs +++ b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs @@ -19,7 +19,7 @@ namespace TensorFlowNET.UnitTest public void ImportGraph() { - with(tf.Session(), sess => + with(tf.Session(), sess => { var new_saver = tf.train.import_meta_graph("C:/tmp/my-model.meta"); }); @@ -44,7 +44,7 @@ namespace TensorFlowNET.UnitTest public void ImportSavedModel() { - with(Session.LoadFromSavedModel("mobilenet"), sess => + with(Session.LoadFromSavedModel("mobilenet"), sess => { }); @@ -65,7 +65,7 @@ namespace TensorFlowNET.UnitTest // Add ops to save and restore all the variables. var saver = tf.train.Saver(); - with(tf.Session(), sess => + with(tf.Session(), sess => { sess.run(init_op); diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs index 88c65bd3..7713e774 100644 --- a/test/TensorFlowNET.UnitTest/VariableTest.cs +++ b/test/TensorFlowNET.UnitTest/VariableTest.cs @@ -30,6 +30,47 @@ namespace TensorFlowNET.UnitTest var mammal2 = tf.Variable("Tiger"); } + /// + /// https://www.tensorflow.org/api_docs/python/tf/variable_scope + /// how to create a new variable + /// + [TestMethod] + public void VarCreation() + { + with(tf.variable_scope("foo"), delegate + { + with(tf.variable_scope("bar"), delegate + { + var v = tf.get_variable("v", new TensorShape(1)); + Assert.AreEqual(v.name, "foo/bar/v:0"); + }); + }); + } + + /// + /// how to reenter a premade variable scope safely + /// + [TestMethod] + public void ReenterVariableScope() + { + variable_scope vs = null; + with(tf.variable_scope("foo"), v => vs = v); + + // Re-enter the variable scope. + with(tf.variable_scope(vs, auxiliary_name_scope: false), v => + { + var vs1 = (VariableScope)v; + // Restore the original name_scope. + with(tf.name_scope(vs1.original_name_scope), delegate + { + var v1 = tf.get_variable("v", new TensorShape(1)); + Assert.AreEqual(v1.name, "foo/v:0"); + var c1 = tf.constant(new int[] { 1 }, name: "c"); + Assert.AreEqual(c1.name, "foo/c:0"); + }); + }); + } + [TestMethod] public void ScalarVar() { @@ -49,7 +90,7 @@ namespace TensorFlowNET.UnitTest [TestMethod] public void Assign1() { - with(tf.Graph().as_default(), graph => + with(tf.Graph().as_default(), graph => { var variable = tf.Variable(31, name: "tree"); var init = tf.global_variables_initializer(); @@ -75,7 +116,7 @@ namespace TensorFlowNET.UnitTest // Add an op to initialize the variables. var init_op = tf.global_variables_initializer(); - with(tf.Session(), sess => + with(tf.Session(), sess => { sess.run(init_op); // o some work with the model.