diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln
index 25a97e6d..4f866b47 100644
--- a/TensorFlow.NET.sln
+++ b/TensorFlow.NET.sln
@@ -11,9 +11,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\T
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Visualization", "TensorFlowNET.Visualization\TensorFlowNET.Visualization.csproj", "{4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}"
EndProject
-Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{E8340C61-12C1-4BEE-A340-403E7C1ACD82}"
-EndProject
-Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "scikit-learn", "..\scikit-learn.net\src\scikit-learn\scikit-learn.csproj", "{199DDAD8-4A6F-43B3-A560-C0393619E304}"
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{1F4C5683-48B4-4328-8171-E9F7ABAA2E72}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
@@ -37,14 +35,10 @@ Global
{4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}.Debug|Any CPU.Build.0 = Debug|Any CPU
{4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}.Release|Any CPU.ActiveCfg = Release|Any CPU
{4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}.Release|Any CPU.Build.0 = Release|Any CPU
- {E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
- {E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Debug|Any CPU.Build.0 = Debug|Any CPU
- {E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Release|Any CPU.ActiveCfg = Release|Any CPU
- {E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Release|Any CPU.Build.0 = Release|Any CPU
- {199DDAD8-4A6F-43B3-A560-C0393619E304}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
- {199DDAD8-4A6F-43B3-A560-C0393619E304}.Debug|Any CPU.Build.0 = Debug|Any CPU
- {199DDAD8-4A6F-43B3-A560-C0393619E304}.Release|Any CPU.ActiveCfg = Release|Any CPU
- {199DDAD8-4A6F-43B3-A560-C0393619E304}.Release|Any CPU.Build.0 = Release|Any CPU
+ {1F4C5683-48B4-4328-8171-E9F7ABAA2E72}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {1F4C5683-48B4-4328-8171-E9F7ABAA2E72}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {1F4C5683-48B4-4328-8171-E9F7ABAA2E72}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {1F4C5683-48B4-4328-8171-E9F7ABAA2E72}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
diff --git a/src/TensorFlowNET.Core/APIs/tf.init.cs b/src/TensorFlowNET.Core/APIs/tf.init.cs
index 19876d62..4863b510 100644
--- a/src/TensorFlowNET.Core/APIs/tf.init.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.init.cs
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
+using Tensorflow.Operations.Initializers;
namespace Tensorflow
{
@@ -24,128 +25,5 @@ namespace Tensorflow
default_name,
values,
auxiliary_name_scope);
-
- public class Zeros : IInitializer
- {
- private TF_DataType dtype;
-
- public Zeros(TF_DataType dtype = TF_DataType.TF_FLOAT)
- {
- this.dtype = dtype;
- }
-
- public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid)
- {
- if (dtype == TF_DataType.DtInvalid)
- dtype = this.dtype;
-
- return array_ops.zeros(shape, dtype);
- }
-
- public object get_config()
- {
- return new { dtype = dtype.name() };
- }
- }
-
- ///
- /// Initializer capable of adapting its scale to the shape of weights tensors.
- ///
- public class VarianceScaling : IInitializer
- {
- protected float _scale;
- protected string _mode;
- protected string _distribution;
- protected int? _seed;
- protected TF_DataType _dtype;
-
- public VarianceScaling(float scale = 1.0f,
- string mode = "fan_in",
- string distribution= "truncated_normal",
- int? seed = null,
- TF_DataType dtype = TF_DataType.TF_FLOAT)
- {
- if (scale < 0)
- throw new ValueError("`scale` must be positive float.");
- _scale = scale;
- _mode = mode;
- _distribution = distribution;
- _seed = seed;
- _dtype = dtype;
- }
-
- public Tensor call(TensorShape shape, TF_DataType dtype)
- {
- var (fan_in, fan_out) = _compute_fans(shape);
- if (_mode == "fan_in")
- _scale /= Math.Max(1, fan_in);
- else if (_mode == "fan_out")
- _scale /= Math.Max(1, fan_out);
- else
- _scale /= Math.Max(1, (fan_in + fan_out) / 2);
-
- if (_distribution == "normal" || _distribution == "truncated_normal")
- {
- throw new NotImplementedException("truncated_normal");
- }
- else if(_distribution == "untruncated_normal")
- {
- throw new NotImplementedException("truncated_normal");
- }
- else
- {
- var limit = Math.Sqrt(3.0f * _scale);
- return random_ops.random_uniform(shape, (float)-limit, (float)limit, dtype, seed: _seed);
- }
- }
-
- private (int, int) _compute_fans(int[] shape)
- {
- if (shape.Length < 1)
- return (1, 1);
- if (shape.Length == 1)
- return (shape[0], shape[0]);
- if (shape.Length == 2)
- return (shape[0], shape[1]);
- else
- throw new NotImplementedException("VarianceScaling._compute_fans");
- }
-
- public virtual object get_config()
- {
- return new
- {
- scale = _scale,
- mode = _mode,
- distribution = _distribution,
- seed = _seed,
- dtype = _dtype
- };
- }
- }
-
- public class GlorotUniform : VarianceScaling
- {
- public GlorotUniform(float scale = 1.0f,
- string mode = "fan_avg",
- string distribution = "uniform",
- int? seed = null,
- TF_DataType dtype = TF_DataType.TF_FLOAT) : base(scale, mode, distribution, seed, dtype)
- {
-
- }
-
- public object get_config()
- {
- return new
- {
- scale = _scale,
- mode = _mode,
- distribution = _distribution,
- seed = _seed,
- dtype = _dtype
- };
- }
- }
}
}
diff --git a/src/TensorFlowNET.Core/APIs/tf.random.cs b/src/TensorFlowNET.Core/APIs/tf.random.cs
index f8eff7e7..3e273e61 100644
--- a/src/TensorFlowNET.Core/APIs/tf.random.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.random.cs
@@ -22,5 +22,12 @@ namespace Tensorflow
TF_DataType dtype = TF_DataType.TF_FLOAT,
int? seed = null,
string name = null) => random_ops.random_normal(shape, mean, stddev, dtype, seed, name);
+
+ public static Tensor random_uniform(int[] shape,
+ float minval = 0,
+ float? maxval = null,
+ TF_DataType dtype = TF_DataType.TF_FLOAT,
+ int? seed = null,
+ string name = null) => random_ops.random_uniform(shape, minval, maxval, dtype, seed, name);
}
}
diff --git a/src/TensorFlowNET.Core/Keras/Initializers.cs b/src/TensorFlowNET.Core/Keras/Initializers.cs
new file mode 100644
index 00000000..cea77ae9
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Initializers.cs
@@ -0,0 +1,20 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Operations.Initializers;
+
+namespace Tensorflow.Keras
+{
+ public class Initializers
+ {
+ ///
+ /// He normal initializer.
+ ///
+ ///
+ ///
+ public IInitializer he_normal(int? seed = null)
+ {
+ return new VarianceScaling(scale: 20f, mode: "fan_in", distribution: "truncated_normal", seed: seed);
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/tf.keras.cs b/src/TensorFlowNET.Core/Keras/tf.keras.cs
new file mode 100644
index 00000000..73b8e0a0
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/tf.keras.cs
@@ -0,0 +1,15 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras;
+
+namespace Tensorflow
+{
+ public static partial class tf
+ {
+ public static class keras
+ {
+ public static Initializers initializers => new Initializers();
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs b/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs
new file mode 100644
index 00000000..5d905583
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs
@@ -0,0 +1,30 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Operations.Initializers
+{
+ public class GlorotUniform : VarianceScaling
+ {
+ public GlorotUniform(float scale = 1.0f,
+ string mode = "fan_avg",
+ string distribution = "uniform",
+ int? seed = null,
+ TF_DataType dtype = TF_DataType.TF_FLOAT) : base(scale, mode, distribution, seed, dtype)
+ {
+
+ }
+
+ public object get_config()
+ {
+ return new
+ {
+ scale = _scale,
+ mode = _mode,
+ distribution = _distribution,
+ seed = _seed,
+ dtype = _dtype
+ };
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/IInitializer.cs b/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs
similarity index 67%
rename from src/TensorFlowNET.Core/Operations/IInitializer.cs
rename to src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs
index 6382e3e0..422bf95d 100644
--- a/src/TensorFlowNET.Core/Operations/IInitializer.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs
@@ -6,7 +6,7 @@ namespace Tensorflow
{
public interface IInitializer
{
- Tensor call(TensorShape shape, TF_DataType dtype);
+ Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid);
object get_config();
}
}
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs
new file mode 100644
index 00000000..4c0a7cee
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs
@@ -0,0 +1,41 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Operations.Initializers
+{
+ public class TruncatedNormal : IInitializer
+ {
+ private float mean;
+ private float stddev;
+ private int? seed;
+ private TF_DataType dtype;
+
+ public TruncatedNormal(float mean = 0.0f,
+ float stddev = 1.0f,
+ int? seed = null,
+ TF_DataType dtype = TF_DataType.TF_FLOAT)
+ {
+ this.mean = mean;
+ this.stddev = stddev;
+ this.seed = seed;
+ this.dtype = dtype;
+ }
+
+ public Tensor call(TensorShape shape, TF_DataType dtype)
+ {
+ throw new NotImplementedException("");
+ }
+
+ public object get_config()
+ {
+ return new
+ {
+ mean = mean,
+ stddev = stddev,
+ seed = seed,
+ dtype = dtype.name()
+ };
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs
new file mode 100644
index 00000000..0fcaf392
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs
@@ -0,0 +1,82 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Operations.Initializers
+{
+ ///
+ /// Initializer capable of adapting its scale to the shape of weights tensors.
+ ///
+ public class VarianceScaling : IInitializer
+ {
+ protected float _scale;
+ protected string _mode;
+ protected string _distribution;
+ protected int? _seed;
+ protected TF_DataType _dtype;
+
+ public VarianceScaling(float scale = 1.0f,
+ string mode = "fan_in",
+ string distribution = "truncated_normal",
+ int? seed = null,
+ TF_DataType dtype = TF_DataType.TF_FLOAT)
+ {
+ if (scale < 0)
+ throw new ValueError("`scale` must be positive float.");
+ _scale = scale;
+ _mode = mode;
+ _distribution = distribution;
+ _seed = seed;
+ _dtype = dtype;
+ }
+
+ public Tensor call(TensorShape shape, TF_DataType dtype)
+ {
+ var (fan_in, fan_out) = _compute_fans(shape);
+ if (_mode == "fan_in")
+ _scale /= Math.Max(1, fan_in);
+ else if (_mode == "fan_out")
+ _scale /= Math.Max(1, fan_out);
+ else
+ _scale /= Math.Max(1, (fan_in + fan_out) / 2);
+
+ if (_distribution == "normal" || _distribution == "truncated_normal")
+ {
+ throw new NotImplementedException("truncated_normal");
+ }
+ else if (_distribution == "untruncated_normal")
+ {
+ throw new NotImplementedException("truncated_normal");
+ }
+ else
+ {
+ var limit = Math.Sqrt(3.0f * _scale);
+ return random_ops.random_uniform(shape, (float)-limit, (float)limit, dtype, seed: _seed);
+ }
+ }
+
+ private (int, int) _compute_fans(int[] shape)
+ {
+ if (shape.Length < 1)
+ return (1, 1);
+ if (shape.Length == 1)
+ return (shape[0], shape[0]);
+ if (shape.Length == 2)
+ return (shape[0], shape[1]);
+ else
+ throw new NotImplementedException("VarianceScaling._compute_fans");
+ }
+
+ public virtual object get_config()
+ {
+ return new
+ {
+ scale = _scale,
+ mode = _mode,
+ distribution = _distribution,
+ seed = _seed,
+ dtype = _dtype
+ };
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs b/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs
new file mode 100644
index 00000000..ca1f42df
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs
@@ -0,0 +1,29 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Operations.Initializers
+{
+ public class Zeros : IInitializer
+ {
+ private TF_DataType dtype;
+
+ public Zeros(TF_DataType dtype = TF_DataType.TF_FLOAT)
+ {
+ this.dtype = dtype;
+ }
+
+ public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid)
+ {
+ if (dtype == TF_DataType.DtInvalid)
+ dtype = this.dtype;
+
+ return array_ops.zeros(shape, dtype);
+ }
+
+ public object get_config()
+ {
+ return new { dtype = dtype.name() };
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/embedding_ops.cs b/src/TensorFlowNET.Core/Operations/embedding_ops.cs
new file mode 100644
index 00000000..27fba04b
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/embedding_ops.cs
@@ -0,0 +1,14 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public class embedding_ops
+ {
+ public Tensor _embedding_lookup_and_transform()
+ {
+ throw new NotImplementedException("");
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
index 949166c1..1e4f2065 100644
--- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
+++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
@@ -43,7 +43,7 @@ Fixed import name scope issue.
-
+
diff --git a/src/TensorFlowNET.Core/Variables/VariableScope.cs b/src/TensorFlowNET.Core/Variables/VariableScope.cs
index 29c03c19..c3453744 100644
--- a/src/TensorFlowNET.Core/Variables/VariableScope.cs
+++ b/src/TensorFlowNET.Core/Variables/VariableScope.cs
@@ -33,7 +33,7 @@ namespace Tensorflow
string name,
TensorShape shape = null,
TF_DataType dtype = TF_DataType.DtInvalid,
- IInitializer initializer = null,
+ object initializer = null, // IInitializer or Tensor
bool? trainable = null,
VariableSynchronization synchronization = VariableSynchronization.AUTO,
VariableAggregation aggregation= VariableAggregation.NONE)
diff --git a/src/TensorFlowNET.Core/Variables/_VariableStore.cs b/src/TensorFlowNET.Core/Variables/_VariableStore.cs
index 5bd8c86d..9f067bf8 100644
--- a/src/TensorFlowNET.Core/Variables/_VariableStore.cs
+++ b/src/TensorFlowNET.Core/Variables/_VariableStore.cs
@@ -23,7 +23,7 @@ namespace Tensorflow
public RefVariable get_variable(string name,
TensorShape shape = null,
TF_DataType dtype = TF_DataType.TF_FLOAT,
- IInitializer initializer = null,
+ object initializer = null, // IInitializer or Tensor
bool? trainable = null,
bool validate_shape = true,
VariableSynchronization synchronization = VariableSynchronization.AUTO,
@@ -45,7 +45,7 @@ namespace Tensorflow
private RefVariable _true_getter(string name,
TensorShape shape = null,
TF_DataType dtype = TF_DataType.TF_FLOAT,
- IInitializer initializer = null,
+ object initializer = null,
bool? trainable = null,
bool validate_shape = true,
VariableSynchronization synchronization = VariableSynchronization.AUTO,
@@ -53,14 +53,32 @@ namespace Tensorflow
{
bool is_scalar = shape.NDim == 0;
- return _get_single_variable(name: name,
- shape: shape,
+ if (initializer is IInitializer init)
+ {
+ return _get_single_variable(name: name,
+ shape: shape,
dtype: dtype,
- initializer: initializer,
+ initializer: init,
trainable: trainable,
validate_shape: validate_shape,
synchronization: synchronization,
aggregation: aggregation);
+ }
+ else if (initializer is Tensor tensor)
+ {
+ return _get_single_variable(name: name,
+ shape: shape,
+ dtype: dtype,
+ initializer: tensor,
+ trainable: trainable,
+ validate_shape: validate_shape,
+ synchronization: synchronization,
+ aggregation: aggregation);
+ }
+ else
+ {
+ throw new NotImplementedException("_true_getter");
+ }
}
private RefVariable _get_single_variable(string name,
@@ -125,5 +143,45 @@ namespace Tensorflow
return v;
}
+
+ private RefVariable _get_single_variable(string name,
+ TensorShape shape = null,
+ TF_DataType dtype = TF_DataType.DtInvalid,
+ Tensor initializer = null,
+ bool reuse = false,
+ bool? trainable = null,
+ bool validate_shape = false,
+ bool? use_resource = null,
+ VariableSynchronization synchronization = VariableSynchronization.AUTO,
+ VariableAggregation aggregation = VariableAggregation.NONE)
+ {
+ if (use_resource == null)
+ use_resource = false;
+
+ if (_vars.ContainsKey(name))
+ {
+ if (!reuse)
+ {
+ var var = _vars[name];
+
+ }
+ throw new NotImplementedException("_get_single_variable");
+ }
+
+ RefVariable v = null;
+ // Create the variable.
+ ops.init_scope();
+ {
+ var init_val = initializer;
+ v = new RefVariable(init_val,
+ name: name,
+ validate_shape: validate_shape,
+ trainable: trainable.Value);
+ }
+
+ _vars[name] = v;
+
+ return v;
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Variables/tf.variable.cs b/src/TensorFlowNET.Core/Variables/tf.variable.cs
index b61b558e..aac58a9a 100644
--- a/src/TensorFlowNET.Core/Variables/tf.variable.cs
+++ b/src/TensorFlowNET.Core/Variables/tf.variable.cs
@@ -15,7 +15,7 @@ namespace Tensorflow
public static RefVariable get_variable(string name,
TensorShape shape = null,
TF_DataType dtype = TF_DataType.DtInvalid,
- IInitializer initializer = null,
+ object initializer = null, // IInitializer or Tensor
bool? trainable = null,
VariableSynchronization synchronization = VariableSynchronization.AUTO,
VariableAggregation aggregation = VariableAggregation.NONE)
diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs
index eeb415f5..e1e1331d 100644
--- a/src/TensorFlowNET.Core/tf.cs
+++ b/src/TensorFlowNET.Core/tf.cs
@@ -12,20 +12,27 @@ namespace Tensorflow
public static TF_DataType float16 = TF_DataType.TF_HALF;
public static TF_DataType float32 = TF_DataType.TF_FLOAT;
public static TF_DataType float64 = TF_DataType.TF_DOUBLE;
+ public static TF_DataType boolean = TF_DataType.TF_BOOL;
public static TF_DataType chars = TF_DataType.TF_STRING;
public static Context context = new Context(new ContextOptions(), new Status());
public static Session defaultSession;
- public static RefVariable Variable(T data, string name = null, TF_DataType dtype = TF_DataType.DtInvalid)
+ public static RefVariable Variable(T data,
+ bool trainable = true,
+ string name = null,
+ TF_DataType dtype = TF_DataType.DtInvalid)
{
- return Tensorflow.variable_scope.default_variable_creator(data, name: name, dtype: TF_DataType.DtInvalid);
+ return Tensorflow.variable_scope.default_variable_creator(data,
+ trainable: trainable,
+ name: name,
+ dtype: TF_DataType.DtInvalid);
}
- public static unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null)
+ public static unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null)
{
- return gen_array_ops.placeholder(dtype, shape);
+ return gen_array_ops.placeholder(dtype, shape, name);
}
public static void enable_eager_execution()
diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
index 12a92226..c6c052fb 100644
--- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
+++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
@@ -13,7 +13,6 @@
-
diff --git a/test/TensorFlowNET.Examples/TextClassification/DataHelpers.cs b/test/TensorFlowNET.Examples/TextClassification/DataHelpers.cs
index f705b98c..586a978a 100644
--- a/test/TensorFlowNET.Examples/TextClassification/DataHelpers.cs
+++ b/test/TensorFlowNET.Examples/TextClassification/DataHelpers.cs
@@ -77,7 +77,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification
var positive_labels = positive_examples.Select(x => new int[2] { 0, 1 }).ToArray();
var negative_labels = negative_examples.Select(x => new int[2] { 1, 0 }).ToArray();
- var y = np.array(1);// np.concatenate(new int[][][] { positive_labels, negative_labels });
+ var y = np.concatenate(new int[][][] { positive_labels, negative_labels });
return (x_text.ToArray(), y);
}
diff --git a/test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs b/test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs
index 4ea583d4..e6e26ba6 100644
--- a/test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs
+++ b/test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs
@@ -5,6 +5,7 @@ using System.IO;
using System.Linq;
using System.Text;
using Tensorflow;
+using TensorFlowNET.Examples.TextClassification;
using TensorFlowNET.Examples.Utility;
namespace TensorFlowNET.Examples.CnnTextClassification
@@ -18,13 +19,21 @@ namespace TensorFlowNET.Examples.CnnTextClassification
private string dataFileName = "dbpedia_csv.tar.gz";
private const int CHAR_MAX_LEN = 1014;
+ private const int NUM_CLASS = 2;
public void Run()
{
download_dbpedia();
Console.WriteLine("Building dataset...");
var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", "vdcnn", CHAR_MAX_LEN);
- //var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15);
+
+ var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
+
+ with(tf.Session(), sess =>
+ {
+ new VdCnn(alphabet_size, CHAR_MAX_LEN, NUM_CLASS);
+
+ });
}
public void download_dbpedia()
@@ -33,5 +42,38 @@ namespace TensorFlowNET.Examples.CnnTextClassification
Web.Download(url, dataDir, dataFileName);
Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir);
}
+
+ private (int[][], int[][], int[], int[]) train_test_split(int[][] x, int[] y, float test_size = 0.3f)
+ {
+ int len = x.Length;
+ int classes = y.Distinct().Count();
+ int samples = len / classes;
+ int train_size = int.Parse((samples * (1 - test_size)).ToString());
+
+ var train_x = new List();
+ var valid_x = new List();
+ var train_y = new List();
+ var valid_y = new List();
+
+ for (int i = 0; i< classes; i++)
+ {
+ for (int j = 0; j < samples; j++)
+ {
+ int idx = i * samples + j;
+ if (idx < train_size + samples * i)
+ {
+ train_x.Add(x[idx]);
+ train_y.Add(y[idx]);
+ }
+ else
+ {
+ valid_x.Add(x[idx]);
+ valid_y.Add(y[idx]);
+ }
+ }
+ }
+
+ return (train_x.ToArray(), valid_x.ToArray(), train_y.ToArray(), valid_y.ToArray());
+ }
}
}
diff --git a/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs b/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs
new file mode 100644
index 00000000..cbdcecee
--- /dev/null
+++ b/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs
@@ -0,0 +1,44 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow;
+
+namespace TensorFlowNET.Examples.TextClassification
+{
+ public class VdCnn : Python
+ {
+ private int embedding_size;
+ private int[] filter_sizes;
+ private int[] num_filters;
+ private int[] num_blocks;
+ private float learning_rate;
+ private IInitializer cnn_initializer;
+ private Tensor x;
+ private Tensor y;
+ private Tensor is_training;
+ private RefVariable global_step;
+ private RefVariable embeddings;
+ private Tensor x_emb;
+
+ public VdCnn(int alphabet_size, int document_max_len, int num_class)
+ {
+ embedding_size = 16;
+ filter_sizes = new int[] { 3, 3, 3, 3, 3 };
+ num_filters = new int[] { 64, 64, 128, 256, 512 };
+ num_blocks = new int[] { 2, 2, 2, 2 };
+ learning_rate = 0.001f;
+ cnn_initializer = tf.keras.initializers.he_normal();
+ x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x");
+ y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y");
+ is_training = tf.placeholder(tf.boolean, new TensorShape(), name: "is_training");
+ global_step = tf.Variable(0, trainable: false);
+
+ with(tf.name_scope("embedding"), delegate
+ {
+ var init_embeddings = tf.random_uniform(new int[] { alphabet_size, embedding_size }, -1.0f, 1.0f);
+ embeddings = tf.get_variable("embeddings", initializer: init_embeddings);
+ // x_emb = tf.nn.embedding_lookup(embeddings, x);
+ });
+ }
+ }
+}
diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj
index f495711c..eaf16ef8 100644
--- a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj
+++ b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj
@@ -16,14 +16,15 @@
-
+
-
+
+
diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs
index 7713e774..8254b774 100644
--- a/test/TensorFlowNET.UnitTest/VariableTest.cs
+++ b/test/TensorFlowNET.UnitTest/VariableTest.cs
@@ -26,7 +26,7 @@ namespace TensorFlowNET.UnitTest
[TestMethod]
public void StringVar()
{
- var mammal1 = tf.Variable("Elephant", "var1", tf.chars);
+ var mammal1 = tf.Variable("Elephant", name: "var1", dtype: tf.chars);
var mammal2 = tf.Variable("Tiger");
}