diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index ebbefb44..aba22c21 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -5,7 +5,7 @@ TensorFlow.NET Tensorflow 1.14.0 - 0.11.0 + 0.11.1 Haiping Chen, Meinrad Recheis SciSharp STACK true @@ -17,15 +17,16 @@ TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C# Google's TensorFlow full binding in .NET Standard. Docs: https://tensorflownet.readthedocs.io - 0.11.0.0 + 0.11.1.0 Changes since v0.10.0: 1. Upgrade NumSharp to v0.20. 2. Add DisposableObject class to manage object lifetime. 3. Add tf.no_op, tf.nn.in_top_k, tf.GraphKeys and tf.trainable_variables. 4. Change tensorflow to non-static class in order to execute some initialization process. -5. Overloade session.run(), make syntax simpler. +5. Overload session.run(), make syntax simpler. +6. Add Local Response Normalization. 7.3 - 0.11.0.0 + 0.11.1.0 LICENSE true true diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs index da873722..7a5fd60a 100644 --- a/src/TensorFlowNET.Core/tensorflow.cs +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -68,9 +68,9 @@ namespace Tensorflow return defaultSession; } - public Session Session(Graph graph) + public Session Session(Graph graph, SessionOptions opts = null) { - return new Session(graph); + return new Session(graph, opts: opts); } public Session Session(SessionOptions opts) diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Dataset.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Dataset.cs index 3277e67e..482280ca 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Dataset.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Dataset.cs @@ -1,17 +1,54 @@ -using System; +using NumSharp; +using System; using System.Collections.Generic; +using System.IO; using System.Text; +using static Tensorflow.Binding; namespace TensorFlowNET.Examples.ImageProcessing.YOLO { public class Dataset { string annot_path; + int[] input_sizes; + int batch_size; + bool data_aug; + int[] train_input_sizes; + NDArray strides; + NDArray anchors; + Dictionary classes; + int num_classes; + int anchor_per_scale; + int max_bbox_per_scale; + string[] annotations; + int num_samples; + int batch_count; + public int Length = 0; public Dataset(string dataset_type, Config cfg) { annot_path = dataset_type == "train" ? cfg.TRAIN.ANNOT_PATH : cfg.TEST.ANNOT_PATH; + input_sizes = dataset_type == "train" ? cfg.TRAIN.INPUT_SIZE : cfg.TEST.INPUT_SIZE; + batch_size = dataset_type == "train" ? cfg.TRAIN.BATCH_SIZE : cfg.TEST.BATCH_SIZE; + data_aug = dataset_type == "train" ? cfg.TRAIN.DATA_AUG : cfg.TEST.DATA_AUG; + train_input_sizes = cfg.TRAIN.INPUT_SIZE; + strides = np.array(cfg.YOLO.STRIDES); + + classes = Utils.read_class_names(cfg.YOLO.CLASSES); + num_classes = classes.Count; + anchors = np.array(Utils.get_anchors(cfg.YOLO.ANCHORS)); + anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE; + max_bbox_per_scale = 150; + + annotations = load_annotations(); + num_samples = len(annotations); + batch_count = 0; + } + + string[] load_annotations() + { + return File.ReadAllLines(annot_path); } } } diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs index 15f162b8..57770866 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs @@ -13,7 +13,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO /// public class Main : IExample { - public bool Enabled { get; set; } = false; + public bool Enabled { get; set; } = true; public bool IsImportingGraph { get; set; } = false; public string Name => "YOLOv3"; @@ -41,7 +41,10 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO Tensor true_sbboxes; Tensor true_mbboxes; Tensor true_lbboxes; - Tensor trainable; + Tensor trainable; + + Session sess; + YOLOv3 model; #endregion public bool Run() @@ -50,7 +53,9 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO var graph = IsImportingGraph ? ImportGraph() : BuildGraph(); - using (var sess = tf.Session(graph)) + var options = new SessionOptions(); + options.SetConfig(new ConfigProto { AllowSoftPlacement = true }); + using (var sess = tf.Session(graph, opts: options)) { Train(sess); } @@ -86,7 +91,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO tf_with(tf.name_scope("define_loss"), scope => { - //model = new YOLOv3(input_data, trainable); + model = new YOLOv3(cfg, input_data, trainable); }); return graph; @@ -109,9 +114,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO string dataDir = Path.Combine(Name, "data"); Directory.CreateDirectory(dataDir); - classes = new Dictionary(); - foreach (var line in File.ReadAllLines(cfg.YOLO.CLASSES)) - classes[classes.Count] = line; + classes = Utils.read_class_names(cfg.YOLO.CLASSES); num_classes = classes.Count; learn_rate_init = cfg.TRAIN.LEARN_RATE_INIT; diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Utils.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Utils.cs new file mode 100644 index 00000000..3a0d3089 --- /dev/null +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Utils.cs @@ -0,0 +1,27 @@ +using NumSharp; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; + +namespace TensorFlowNET.Examples.ImageProcessing.YOLO +{ + class Utils + { + public static Dictionary read_class_names(string file) + { + var classes = new Dictionary(); + foreach (var line in File.ReadAllLines(file)) + classes[classes.Count] = line; + return classes; + } + + public static NDArray get_anchors(string file) + { + return np.array(File.ReadAllText(file).Split(',') + .Select(x => float.Parse(x)) + .ToArray()).reshape(3, 3, 2); + } + } +} diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs index 3036cc0d..5125c603 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs @@ -1,10 +1,50 @@ -using System; +using NumSharp; +using System; using System.Collections.Generic; using System.Text; +using Tensorflow; +using static Tensorflow.Binding; namespace TensorFlowNET.Examples.ImageProcessing.YOLO { public class YOLOv3 { + Config cfg; + Tensor trainable; + Tensor input_data; + Dictionary classes; + int num_class; + NDArray strides; + NDArray anchors; + int anchor_per_scale; + float iou_loss_thresh; + string upsample_method; + Tensor conv_lbbox; + Tensor conv_mbbox; + Tensor conv_sbbox; + + public YOLOv3(Config cfg_, Tensor input_data_, Tensor trainable_) + { + cfg = cfg_; + input_data = input_data_; + trainable = trainable_; + classes = Utils.read_class_names(cfg.YOLO.CLASSES); + num_class = len(classes); + strides = np.array(cfg.YOLO.STRIDES); + anchors = Utils.get_anchors(cfg.YOLO.ANCHORS); + anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE; + iou_loss_thresh = cfg.YOLO.IOU_LOSS_THRESH; + upsample_method = cfg.YOLO.UPSAMPLE_METHOD; + + (conv_lbbox, conv_mbbox, conv_sbbox) = __build_nework(input_data); + } + + private (Tensor, Tensor, Tensor) __build_nework(Tensor input_data) + { + Tensor route_1, route_2; + (route_1, route_2, input_data) = backbone.darknet53(input_data, trainable); + + return (conv_lbbox, conv_mbbox, conv_sbbox); + } } } diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/backbone.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/backbone.cs new file mode 100644 index 00000000..0e7b1446 --- /dev/null +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/backbone.cs @@ -0,0 +1,28 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.Examples.ImageProcessing.YOLO +{ + class backbone + { + public static (Tensor, Tensor, Tensor) darknet53(Tensor input_data, Tensor trainable) + { + return tf_with(tf.variable_scope("darknet"), scope => + { + input_data = common.convolutional(input_data, filters_shape: new int[] { 3, 3, 3, 32 }, trainable: trainable, name: "conv0"); + input_data = common.convolutional(input_data, filters_shape: new int[] { 3, 3, 32, 64 }, trainable: trainable, name: "conv1", downsample: true); + + foreach (var i in range(1)) + input_data = common.residual_block(input_data, 64, 32, 64, trainable: trainable, name: $"residual{i + 0}"); + + var route_1 = input_data; + var route_2 = input_data; + + return (route_1, route_2, input_data); + }); + } + } +} diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs new file mode 100644 index 00000000..57105aa1 --- /dev/null +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs @@ -0,0 +1,72 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.Examples.ImageProcessing.YOLO +{ + class common + { + public static Tensor convolutional(Tensor input_data, int[] filters_shape, Tensor trainable, + string name, bool downsample = false, bool activate = true, + bool bn = true) + { + return tf_with(tf.variable_scope(name), scope => + { + int[] strides; + string padding; + + if (downsample) + { + throw new NotImplementedException(""); + } + else + { + strides = new int[] { 1, 1, 1, 1 }; + padding = "SAME"; + } + + var weight = tf.get_variable(name: "weight", dtype: tf.float32, trainable: true, + shape: filters_shape, initializer: tf.random_normal_initializer(stddev: 0.01f)); + + var conv = tf.nn.conv2d(input: input_data, filter: weight, strides: strides, padding: padding); + + if (bn) + { + conv = tf.layers.batch_normalization(conv, beta_initializer: tf.zeros_initializer, + gamma_initializer: tf.ones_initializer, + moving_mean_initializer: tf.zeros_initializer, + moving_variance_initializer: tf.ones_initializer, training: trainable); + } + else + { + throw new NotImplementedException(""); + } + + if (activate) + conv = tf.nn.leaky_relu(conv, alpha: 0.1f); + + return conv; + }); + } + + public static Tensor residual_block(Tensor input_data, int input_channel, int filter_num1, + int filter_num2, Tensor trainable, string name) + { + var short_cut = input_data; + + return tf_with(tf.variable_scope(name), scope => + { + input_data = convolutional(input_data, filters_shape: new int[] { 1, 1, input_channel, filter_num1 }, + trainable: trainable, name: "conv1"); + input_data = convolutional(input_data, filters_shape: new int[] { 3, 3, filter_num1, filter_num2 }, + trainable: trainable, name: "conv2"); + + var residual_output = input_data + short_cut; + + return residual_output; + }); + } + } +} diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/config.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/config.cs index 99bee6dc..b5c46151 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/config.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/config.cs @@ -9,12 +9,13 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO { public YoloConfig YOLO; public TrainConfig TRAIN; - public TrainConfig TEST; + public TestConfig TEST; public Config(string root) { YOLO = new YoloConfig(root); TRAIN = new TrainConfig(root); + TEST = new TestConfig(root); } public class YoloConfig @@ -22,13 +23,22 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO string _root; public string CLASSES; + public string ANCHORS; public float MOVING_AVE_DECAY = 0.9995f; public int[] STRIDES = new int[] { 8, 16, 32 }; + public int ANCHOR_PER_SCALE = 3; + public float IOU_LOSS_THRESH = 0.5f; + public string UPSAMPLE_METHOD = "resize"; + public string ORIGINAL_WEIGHT; + public string DEMO_WEIGHT; public YoloConfig(string root) { _root = root; CLASSES = Path.Combine(_root, "data", "classes", "coco.names"); + ANCHORS = Path.Combine(_root, "data", "anchors", "basline_anchors.txt"); + ORIGINAL_WEIGHT = Path.Combine(_root, "checkpoint", "yolov3_coco.ckpt"); + DEMO_WEIGHT = Path.Combine(_root, "checkpoint", "yolov3_coco_demo.ckpt"); } } @@ -54,5 +64,31 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO ANNOT_PATH = Path.Combine(_root, "data", "dataset", "voc_train.txt"); } } + + public class TestConfig + { + string _root; + + public int BATCH_SIZE = 2; + public int[] INPUT_SIZE = new int[] { 544 }; + public bool DATA_AUG = false; + public bool WRITE_IMAGE = true; + public string WRITE_IMAGE_PATH; + public string WEIGHT_FILE; + public bool WRITE_IMAGE_SHOW_LABEL = true; + public bool SHOW_LABEL = true; + public int SECOND_STAGE_EPOCHS = 30; + public float SCORE_THRESHOLD = 0.3f; + public float IOU_THRESHOLD = 0.45f; + public string ANNOT_PATH; + + public TestConfig(string root) + { + _root = root; + ANNOT_PATH = Path.Combine(_root, "data", "dataset", "voc_test.txt"); + WRITE_IMAGE_PATH = Path.Combine(_root, "data", "detection"); + WEIGHT_FILE = Path.Combine(_root, "checkpoint", "yolov3_test_loss=9.2099.ckpt-5"); + } + } } }