diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
index ebbefb44..aba22c21 100644
--- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
+++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
@@ -5,7 +5,7 @@
TensorFlow.NET
Tensorflow
1.14.0
- 0.11.0
+ 0.11.1
Haiping Chen, Meinrad Recheis
SciSharp STACK
true
@@ -17,15 +17,16 @@
TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#
Google's TensorFlow full binding in .NET Standard.
Docs: https://tensorflownet.readthedocs.io
- 0.11.0.0
+ 0.11.1.0
Changes since v0.10.0:
1. Upgrade NumSharp to v0.20.
2. Add DisposableObject class to manage object lifetime.
3. Add tf.no_op, tf.nn.in_top_k, tf.GraphKeys and tf.trainable_variables.
4. Change tensorflow to non-static class in order to execute some initialization process.
-5. Overloade session.run(), make syntax simpler.
+5. Overload session.run(), make syntax simpler.
+6. Add Local Response Normalization.
7.3
- 0.11.0.0
+ 0.11.1.0
LICENSE
true
true
diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs
index da873722..7a5fd60a 100644
--- a/src/TensorFlowNET.Core/tensorflow.cs
+++ b/src/TensorFlowNET.Core/tensorflow.cs
@@ -68,9 +68,9 @@ namespace Tensorflow
return defaultSession;
}
- public Session Session(Graph graph)
+        /// <summary>
+        /// Creates a new session for the given graph.
+        /// </summary>
+        /// <param name="graph">Graph passed to the Session constructor.</param>
+        /// <param name="opts">Optional session options, forwarded as the constructor's opts argument; defaults to null.</param>
+        /// <returns>The newly constructed session.</returns>
+ public Session Session(Graph graph, SessionOptions opts = null)
{
- return new Session(graph);
+ return new Session(graph, opts: opts);
}
public Session Session(SessionOptions opts)
diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Dataset.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Dataset.cs
index 3277e67e..482280ca 100644
--- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Dataset.cs
+++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Dataset.cs
@@ -1,17 +1,54 @@
-using System;
+using NumSharp;
+using System;
using System.Collections.Generic;
+using System.IO;
using System.Text;
+using static Tensorflow.Binding;
namespace TensorFlowNET.Examples.ImageProcessing.YOLO
{
public class Dataset
{
string annot_path;
+        // Split-specific settings, copied from cfg.TRAIN or cfg.TEST by the constructor.
+        int[] input_sizes;
+        int batch_size;
+        bool data_aug;
+        // Always cfg.TRAIN.INPUT_SIZE, whichever split was requested.
+        int[] train_input_sizes;
+        NDArray strides;
+        NDArray anchors;
+        // Result of Utils.read_class_names; num_classes is its entry count.
+        Dictionary<int, string> classes;
+        int num_classes;
+        int anchor_per_scale;
+        int max_bbox_per_scale;
+        // One element per line of the annotation file; num_samples is the line count.
+        string[] annotations;
+        int num_samples;
+        int batch_count;
+
public int Length = 0;
+
+        /// <summary>
+        /// Copies the settings for the requested split out of <paramref name="cfg"/> and
+        /// loads the class names, anchors and annotation lines from disk.
+        /// </summary>
+        /// <param name="dataset_type">"train" selects cfg.TRAIN; any other value selects cfg.TEST.</param>
+        /// <param name="cfg">Configuration that supplies the paths and settings.</param>
public Dataset(string dataset_type, Config cfg)
{
annot_path = dataset_type == "train" ? cfg.TRAIN.ANNOT_PATH : cfg.TEST.ANNOT_PATH;
+            // Decide the split once instead of repeating the string comparison.
+            bool is_train = dataset_type == "train";
+            input_sizes = is_train ? cfg.TRAIN.INPUT_SIZE : cfg.TEST.INPUT_SIZE;
+            batch_size = is_train ? cfg.TRAIN.BATCH_SIZE : cfg.TEST.BATCH_SIZE;
+            data_aug = is_train ? cfg.TRAIN.DATA_AUG : cfg.TEST.DATA_AUG;
+            train_input_sizes = cfg.TRAIN.INPUT_SIZE;
+            strides = np.array(cfg.YOLO.STRIDES);
+
+            classes = Utils.read_class_names(cfg.YOLO.CLASSES);
+            num_classes = classes.Count;
+            // get_anchors already returns an NDArray, so no extra np.array wrap is needed.
+            anchors = Utils.get_anchors(cfg.YOLO.ANCHORS);
+            anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE;
+            max_bbox_per_scale = 150;
+
+            annotations = load_annotations();
+            num_samples = annotations.Length;
+            batch_count = 0;
+        }
+
+        /// <summary>
+        /// Reads every line of the file at annot_path.
+        /// </summary>
+        /// <returns>The file's lines, in file order.</returns>
+        string[] load_annotations()
+        {
+            return File.ReadAllLines(annot_path);
}
}
}
diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs
index 15f162b8..57770866 100644
--- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs
+++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs
@@ -13,7 +13,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
///
public class Main : IExample
{
- public bool Enabled { get; set; } = false;
+ public bool Enabled { get; set; } = true;
public bool IsImportingGraph { get; set; } = false;
public string Name => "YOLOv3";
@@ -41,7 +41,10 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
Tensor true_sbboxes;
Tensor true_mbboxes;
Tensor true_lbboxes;
- Tensor trainable;
+ Tensor trainable;
+
+ Session sess;
+ YOLOv3 model;
#endregion
public bool Run()
@@ -50,7 +53,9 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
var graph = IsImportingGraph ? ImportGraph() : BuildGraph();
- using (var sess = tf.Session(graph))
+ var options = new SessionOptions();
+ options.SetConfig(new ConfigProto { AllowSoftPlacement = true });
+ using (var sess = tf.Session(graph, opts: options))
{
Train(sess);
}
@@ -86,7 +91,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
tf_with(tf.name_scope("define_loss"), scope =>
{
- //model = new YOLOv3(input_data, trainable);
+ model = new YOLOv3(cfg, input_data, trainable);
});
return graph;
@@ -109,9 +114,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
string dataDir = Path.Combine(Name, "data");
Directory.CreateDirectory(dataDir);
- classes = new Dictionary();
- foreach (var line in File.ReadAllLines(cfg.YOLO.CLASSES))
- classes[classes.Count] = line;
+ classes = Utils.read_class_names(cfg.YOLO.CLASSES);
num_classes = classes.Count;
learn_rate_init = cfg.TRAIN.LEARN_RATE_INIT;
diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Utils.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Utils.cs
new file mode 100644
index 00000000..3a0d3089
--- /dev/null
+++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Utils.cs
@@ -0,0 +1,50 @@
+using NumSharp;
+using System;
+using System.Collections.Generic;
+using System.Globalization;
+using System.IO;
+using System.Linq;
+using System.Text;
+
+namespace TensorFlowNET.Examples.ImageProcessing.YOLO
+{
+    /// <summary>
+    /// File-reading helpers used by the YOLO example classes.
+    /// </summary>
+    class Utils
+    {
+        /// <summary>
+        /// Reads a text file and numbers its lines starting at 0.
+        /// </summary>
+        /// <param name="file">Path of the file to read.</param>
+        /// <returns>Dictionary mapping each zero-based line index to that line's text.</returns>
+        public static Dictionary<int, string> read_class_names(string file)
+        {
+            var classes = new Dictionary<int, string>();
+            foreach (var line in File.ReadAllLines(file))
+            {
+                // Count grows by one per insert, so it doubles as the next index.
+                classes[classes.Count] = line;
+            }
+
+            return classes;
+        }
+
+        /// <summary>
+        /// Reads comma-separated numbers from a file and reshapes them to 3 x 3 x 2.
+        /// </summary>
+        /// <param name="file">Path of the anchor file.</param>
+        /// <returns>The parsed values as an NDArray reshaped to (3, 3, 2).</returns>
+        /// <exception cref="FormatException">An entry is not a valid number.</exception>
+        public static NDArray get_anchors(string file)
+        {
+            // Invariant culture: with the current culture, a value such as "1.25" is
+            // rejected or misread on machines whose decimal separator is not '.'.
+            var values = File.ReadAllText(file).Split(',')
+                .Select(x => float.Parse(x, CultureInfo.InvariantCulture))
+                .ToArray();
+
+            return np.array(values).reshape(3, 3, 2);
+        }
+    }
+}
diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs
index 3036cc0d..5125c603 100644
--- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs
+++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs
@@ -1,10 +1,50 @@
-using System;
+using NumSharp;
+using System;
using System.Collections.Generic;
using System.Text;
+using Tensorflow;
+using static Tensorflow.Binding;
namespace TensorFlowNET.Examples.ImageProcessing.YOLO
{
public class YOLOv3
{
+        // Constructor arguments.
+        Config cfg;
+        Tensor trainable;
+        Tensor input_data;
+        // Values the constructor reads from cfg.YOLO, or from the files it points at.
+        Dictionary<int, string> classes;
+        int num_class;
+        NDArray strides;
+        NDArray anchors;
+        int anchor_per_scale;
+        float iou_loss_thresh;
+        string upsample_method;
+        // Assigned from the tuple returned by __build_nework.
+        Tensor conv_lbbox;
+        Tensor conv_mbbox;
+        Tensor conv_sbbox;
+
+        /// <summary>
+        /// Stores the inputs, loads the YOLO settings from <paramref name="cfg_"/> and
+        /// calls __build_nework on the input tensor.
+        /// </summary>
+        /// <param name="cfg_">Configuration whose YOLO section supplies the settings.</param>
+        /// <param name="input_data_">Tensor passed to __build_nework.</param>
+        /// <param name="trainable_">Tensor forwarded to backbone.darknet53.</param>
+        public YOLOv3(Config cfg_, Tensor input_data_, Tensor trainable_)
+        {
+            cfg = cfg_;
+            input_data = input_data_;
+            trainable = trainable_;
+
+            var yolo = cfg.YOLO;
+            classes = Utils.read_class_names(yolo.CLASSES);
+            num_class = classes.Count;
+            strides = np.array(yolo.STRIDES);
+            anchors = Utils.get_anchors(yolo.ANCHORS);
+            anchor_per_scale = yolo.ANCHOR_PER_SCALE;
+            iou_loss_thresh = yolo.IOU_LOSS_THRESH;
+            upsample_method = yolo.UPSAMPLE_METHOD;
+
+            (conv_lbbox, conv_mbbox, conv_sbbox) = __build_nework(input_data);
+        }
+
+        /// <summary>
+        /// Runs <paramref name="input_data"/> through backbone.darknet53.
+        /// </summary>
+        /// <param name="input_data">Tensor handed to the backbone.</param>
+        /// <returns>
+        /// The current conv_lbbox, conv_mbbox and conv_sbbox fields.
+        /// NOTE(review): nothing in this method assigns those fields, so the darknet53
+        /// outputs are not part of the result; looks like a work-in-progress stub - confirm.
+        /// </returns>
+        private (Tensor, Tensor, Tensor) __build_nework(Tensor input_data)
+        {
+            Tensor route_1, route_2;
+            (route_1, route_2, input_data) = backbone.darknet53(input_data, trainable);
+
+            return (conv_lbbox, conv_mbbox, conv_sbbox);
+        }
}
}
diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/backbone.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/backbone.cs
new file mode 100644
index 00000000..0e7b1446
--- /dev/null
+++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/backbone.cs
@@ -0,0 +1,38 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow;
+using static Tensorflow.Binding;
+
+namespace TensorFlowNET.Examples.ImageProcessing.YOLO
+{
+    /// <summary>
+    /// Builds the "darknet" part of the YOLO example graph.
+    /// </summary>
+    class backbone
+    {
+        /// <summary>
+        /// Chains two convolutions and one residual block inside the "darknet" variable scope.
+        /// </summary>
+        /// <param name="input_data">Tensor fed to the first convolution.</param>
+        /// <param name="trainable">Tensor passed through to every layer helper.</param>
+        /// <returns>A 3-tuple whose items are all the output of the residual block.</returns>
+        public static (Tensor, Tensor, Tensor) darknet53(Tensor input_data, Tensor trainable)
+        {
+            return tf_with(tf.variable_scope("darknet"), scope =>
+            {
+                var net = common.convolutional(input_data, new int[] { 3, 3, 3, 32 }, trainable, "conv0");
+                net = common.convolutional(net, new int[] { 3, 3, 32, 64 }, trainable, "conv1", downsample: true);
+
+                // One residual block for now; its scope name comes out as "residual0".
+                for (int i = 0; i < 1; i++)
+                {
+                    net = common.residual_block(net, 64, 32, 64, trainable, $"residual{i + 0}");
+                }
+
+                // All three slots currently carry the same tensor.
+                return (net, net, net);
+            });
+        }
+    }
+}
diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs
new file mode 100644
index 00000000..57105aa1
--- /dev/null
+++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs
@@ -0,0 +1,98 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow;
+using static Tensorflow.Binding;
+
+namespace TensorFlowNET.Examples.ImageProcessing.YOLO
+{
+    /// <summary>
+    /// Layer-building helpers for the YOLO example network.
+    /// </summary>
+    class common
+    {
+        /// <summary>
+        /// Adds a stride-1 "SAME" convolution followed by batch normalization and, optionally,
+        /// a leaky ReLU, all inside the variable scope <paramref name="name"/>.
+        /// </summary>
+        /// <param name="input_data">Tensor to convolve.</param>
+        /// <param name="filters_shape">Shape of the "weight" variable that is used as the filter.</param>
+        /// <param name="trainable">Tensor passed as the <c>training</c> argument of batch normalization.</param>
+        /// <param name="name">Name of the variable scope the layer is built in.</param>
+        /// <param name="downsample">Not supported yet; <c>true</c> throws.</param>
+        /// <param name="activate">When <c>true</c>, applies leaky ReLU with alpha 0.1 to the result.</param>
+        /// <param name="bn">Must be <c>true</c>; the variant without batch normalization throws.</param>
+        /// <returns>The output tensor of the last operation added.</returns>
+        /// <exception cref="NotImplementedException">
+        /// <paramref name="downsample"/> is <c>true</c> or <paramref name="bn"/> is <c>false</c>.
+        /// </exception>
+        public static Tensor convolutional(Tensor input_data, int[] filters_shape, Tensor trainable,
+            string name, bool downsample = false, bool activate = true,
+            bool bn = true)
+        {
+            // Reject unsupported options up front, before any scope, variable or op is
+            // created, and say which option was the problem.
+            if (downsample)
+            {
+                throw new NotImplementedException(
+                    $"{nameof(convolutional)} (scope \"{name}\"): {nameof(downsample)} = true is not implemented.");
+            }
+
+            if (!bn)
+            {
+                throw new NotImplementedException(
+                    $"{nameof(convolutional)} (scope \"{name}\"): {nameof(bn)} = false is not implemented.");
+            }
+
+            return tf_with(tf.variable_scope(name), scope =>
+            {
+                var strides = new int[] { 1, 1, 1, 1 };
+                var padding = "SAME";
+
+                // The weight variable itself is always created as trainable; the
+                // trainable tensor argument only reaches batch normalization below.
+                var weight = tf.get_variable(name: "weight", dtype: tf.float32, trainable: true,
+                    shape: filters_shape, initializer: tf.random_normal_initializer(stddev: 0.01f));
+
+                var conv = tf.nn.conv2d(input: input_data, filter: weight, strides: strides, padding: padding);
+
+                conv = tf.layers.batch_normalization(conv, beta_initializer: tf.zeros_initializer,
+                    gamma_initializer: tf.ones_initializer,
+                    moving_mean_initializer: tf.zeros_initializer,
+                    moving_variance_initializer: tf.ones_initializer, training: trainable);
+
+                if (activate)
+                {
+                    conv = tf.nn.leaky_relu(conv, alpha: 0.1f);
+                }
+
+                return conv;
+            });
+        }
+
+        /// <summary>
+        /// Adds two convolutions inside the variable scope <paramref name="name"/> and sums
+        /// their result with the block's input.
+        /// </summary>
+        /// <param name="input_data">Block input; also the second operand of the final addition.</param>
+        /// <param name="input_channel">Third entry of the first convolution's filter shape.</param>
+        /// <param name="filter_num1">Last entry of the first filter shape and third entry of the second.</param>
+        /// <param name="filter_num2">Last entry of the second convolution's filter shape.</param>
+        /// <param name="trainable">Tensor forwarded to both convolutions.</param>
+        /// <param name="name">Name of the variable scope the block is built in.</param>
+        /// <returns>The sum of the second convolution's output and <paramref name="input_data"/>.</returns>
+        public static Tensor residual_block(Tensor input_data, int input_channel, int filter_num1,
+            int filter_num2, Tensor trainable, string name)
+        {
+            return tf_with(tf.variable_scope(name), scope =>
+            {
+                var branch = convolutional(input_data, filters_shape: new int[] { 1, 1, input_channel, filter_num1 },
+                    trainable: trainable, name: "conv1");
+                branch = convolutional(branch, filters_shape: new int[] { 3, 3, filter_num1, filter_num2 },
+                    trainable: trainable, name: "conv2");
+
+                return branch + input_data;
+            });
+        }
+    }
+}
diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/config.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/config.cs
index 99bee6dc..b5c46151 100644
--- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/config.cs
+++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/config.cs
@@ -9,12 +9,13 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
{
public YoloConfig YOLO;
public TrainConfig TRAIN;
- public TrainConfig TEST;
+ public TestConfig TEST;
+        /// <summary>
+        /// Creates the YOLO, TRAIN and TEST configuration sections, each from the same root.
+        /// </summary>
+        /// <param name="root">Value handed unchanged to the YoloConfig, TrainConfig and TestConfig constructors.</param>
public Config(string root)
{
YOLO = new YoloConfig(root);
TRAIN = new TrainConfig(root);
+ TEST = new TestConfig(root);
}
public class YoloConfig
@@ -22,13 +23,22 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
string _root;
public string CLASSES;
+ public string ANCHORS;
public float MOVING_AVE_DECAY = 0.9995f;
public int[] STRIDES = new int[] { 8, 16, 32 };
+ public int ANCHOR_PER_SCALE = 3;
+ public float IOU_LOSS_THRESH = 0.5f;
+ public string UPSAMPLE_METHOD = "resize";
+ public string ORIGINAL_WEIGHT;
+ public string DEMO_WEIGHT;
public YoloConfig(string root)
{
_root = root;
CLASSES = Path.Combine(_root, "data", "classes", "coco.names");
+ ANCHORS = Path.Combine(_root, "data", "anchors", "basline_anchors.txt");
+ ORIGINAL_WEIGHT = Path.Combine(_root, "checkpoint", "yolov3_coco.ckpt");
+ DEMO_WEIGHT = Path.Combine(_root, "checkpoint", "yolov3_coco_demo.ckpt");
}
}
@@ -54,5 +64,31 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
ANNOT_PATH = Path.Combine(_root, "data", "dataset", "voc_train.txt");
}
}
+
+        public class TestConfig
+        {
+            string _root;
+            // Fixed defaults; only the three path fields depend on the root directory.
+            public int BATCH_SIZE = 2;
+            public int[] INPUT_SIZE = new[] { 544 };
+            public bool DATA_AUG = false;
+            public bool WRITE_IMAGE = true;
+            public string WRITE_IMAGE_PATH;
+            public string WEIGHT_FILE;
+            public bool WRITE_IMAGE_SHOW_LABEL = true;
+            public bool SHOW_LABEL = true;
+            public int SECOND_STAGE_EPOCHS = 30;
+            public float SCORE_THRESHOLD = 0.3f;
+            public float IOU_THRESHOLD = 0.45f;
+            public string ANNOT_PATH;
+            // Resolves ANNOT_PATH, WRITE_IMAGE_PATH and WEIGHT_FILE below the given root.
+            public TestConfig(string root)
+            {
+                this._root = root;
+                ANNOT_PATH = Path.Combine(root, "data", "dataset", "voc_test.txt");
+                WRITE_IMAGE_PATH = Path.Combine(root, "data", "detection");
+                WEIGHT_FILE = Path.Combine(root, "checkpoint", "yolov3_test_loss=9.2099.ckpt-5");
+            }
+        }
}
}