| @@ -5,7 +5,7 @@ | |||
| <AssemblyName>TensorFlow.NET</AssemblyName> | |||
| <RootNamespace>Tensorflow</RootNamespace> | |||
| <TargetTensorFlow>1.14.0</TargetTensorFlow> | |||
| <Version>0.11.0</Version> | |||
| <Version>0.11.1</Version> | |||
| <Authors>Haiping Chen, Meinrad Recheis</Authors> | |||
| <Company>SciSharp STACK</Company> | |||
| <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | |||
| @@ -17,15 +17,16 @@ | |||
| <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags> | |||
| <Description>Google's TensorFlow full binding in .NET Standard. | |||
| Docs: https://tensorflownet.readthedocs.io</Description> | |||
| <AssemblyVersion>0.11.0.0</AssemblyVersion> | |||
| <AssemblyVersion>0.11.1.0</AssemblyVersion> | |||
| <PackageReleaseNotes>Changes since v0.10.0: | |||
| 1. Upgrade NumSharp to v0.20. | |||
| 2. Add DisposableObject class to manage object lifetime. | |||
| 3. Add tf.no_op, tf.nn.in_top_k, tf.GraphKeys and tf.trainable_variables. | |||
| 4. Change tensorflow to non-static class in order to execute some initialization process. | |||
| 5. Overloade session.run(), make syntax simpler.</PackageReleaseNotes> | |||
| 5. Overload session.run(), make syntax simpler. | |||
| 6. Add Local Response Normalization.</PackageReleaseNotes> | |||
| <LangVersion>7.3</LangVersion> | |||
| <FileVersion>0.11.0.0</FileVersion> | |||
| <FileVersion>0.11.1.0</FileVersion> | |||
| <PackageLicenseFile>LICENSE</PackageLicenseFile> | |||
| <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | |||
| <SignAssembly>true</SignAssembly> | |||
| @@ -68,9 +68,9 @@ namespace Tensorflow | |||
| return defaultSession; | |||
| } | |||
| public Session Session(Graph graph) | |||
| public Session Session(Graph graph, SessionOptions opts = null) | |||
| { | |||
| return new Session(graph); | |||
| return new Session(graph, opts: opts); | |||
| } | |||
| public Session Session(SessionOptions opts) | |||
| @@ -1,17 +1,54 @@ | |||
| using System; | |||
| using NumSharp; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Text; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||
| { | |||
/// <summary>
/// Holds the settings and the annotation lines for one YOLO dataset split.
/// The split is chosen by the <c>dataset_type</c> constructor argument:
/// "train" reads from <c>cfg.TRAIN</c>, any other value reads from <c>cfg.TEST</c>.
/// </summary>
public class Dataset
{
    // Settings taken from the selected (train or test) configuration section.
    string annot_path;
    int[] input_sizes;
    int batch_size;
    bool data_aug;

    // Always taken from the TRAIN section, whichever split was requested.
    int[] train_input_sizes;

    // Model-level settings taken from cfg.YOLO.
    NDArray strides;
    NDArray anchors;
    Dictionary<int, string> classes;
    int num_classes;
    int anchor_per_scale;
    int max_bbox_per_scale;

    // Raw lines of the annotation file and bookkeeping counters.
    string[] annotations;
    int num_samples;
    int batch_count;

    // Initialised to 0; nothing in this class assigns it afterwards.
    public int Length = 0;

    /// <summary>
    /// Copies the relevant configuration values and loads the annotation file.
    /// </summary>
    /// <param name="dataset_type">"train" selects the TRAIN section; anything else selects TEST.</param>
    /// <param name="cfg">Configuration object providing the TRAIN, TEST and YOLO sections.</param>
    public Dataset(string dataset_type, Config cfg)
    {
        bool use_train = dataset_type == "train";
        if (use_train)
        {
            annot_path = cfg.TRAIN.ANNOT_PATH;
            input_sizes = cfg.TRAIN.INPUT_SIZE;
            batch_size = cfg.TRAIN.BATCH_SIZE;
            data_aug = cfg.TRAIN.DATA_AUG;
        }
        else
        {
            annot_path = cfg.TEST.ANNOT_PATH;
            input_sizes = cfg.TEST.INPUT_SIZE;
            batch_size = cfg.TEST.BATCH_SIZE;
            data_aug = cfg.TEST.DATA_AUG;
        }

        train_input_sizes = cfg.TRAIN.INPUT_SIZE;
        strides = np.array(cfg.YOLO.STRIDES);

        // File reads happen in this order: class names, anchors, annotations.
        classes = Utils.read_class_names(cfg.YOLO.CLASSES);
        num_classes = classes.Count;
        anchors = np.array(Utils.get_anchors(cfg.YOLO.ANCHORS));
        anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE;
        max_bbox_per_scale = 150;

        annotations = load_annotations();
        num_samples = len(annotations);
        batch_count = 0;
    }

    /// <summary>
    /// Reads every line of the annotation file at <c>annot_path</c>.
    /// </summary>
    /// <returns>The file's lines, unfiltered and in file order.</returns>
    string[] load_annotations()
    {
        return File.ReadAllLines(annot_path);
    }
}
| } | |||
| @@ -13,7 +13,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||
| /// </summary> | |||
| public class Main : IExample | |||
| { | |||
| public bool Enabled { get; set; } = false; | |||
| public bool Enabled { get; set; } = true; | |||
| public bool IsImportingGraph { get; set; } = false; | |||
| public string Name => "YOLOv3"; | |||
| @@ -41,7 +41,10 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||
| Tensor true_sbboxes; | |||
| Tensor true_mbboxes; | |||
| Tensor true_lbboxes; | |||
| Tensor trainable; | |||
| Tensor trainable; | |||
| Session sess; | |||
| YOLOv3 model; | |||
| #endregion | |||
| public bool Run() | |||
| @@ -50,7 +53,9 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||
| var graph = IsImportingGraph ? ImportGraph() : BuildGraph(); | |||
| using (var sess = tf.Session(graph)) | |||
| var options = new SessionOptions(); | |||
| options.SetConfig(new ConfigProto { AllowSoftPlacement = true }); | |||
| using (var sess = tf.Session(graph, opts: options)) | |||
| { | |||
| Train(sess); | |||
| } | |||
| @@ -86,7 +91,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||
| tf_with(tf.name_scope("define_loss"), scope => | |||
| { | |||
| //model = new YOLOv3(input_data, trainable); | |||
| model = new YOLOv3(cfg, input_data, trainable); | |||
| }); | |||
| return graph; | |||
| @@ -109,9 +114,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||
| string dataDir = Path.Combine(Name, "data"); | |||
| Directory.CreateDirectory(dataDir); | |||
| classes = new Dictionary<int, string>(); | |||
| foreach (var line in File.ReadAllLines(cfg.YOLO.CLASSES)) | |||
| classes[classes.Count] = line; | |||
| classes = Utils.read_class_names(cfg.YOLO.CLASSES); | |||
| num_classes = classes.Count; | |||
| learn_rate_init = cfg.TRAIN.LEARN_RATE_INIT; | |||
| @@ -0,0 +1,27 @@ | |||
| using NumSharp; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using System.Text; | |||
| namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||
| { | |||
/// <summary>
/// File-reading helpers shared by the YOLO example classes.
/// </summary>
class Utils
{
    /// <summary>
    /// Reads a text file and maps each zero-based line index to the text of that line.
    /// </summary>
    /// <param name="file">Path of the file to read; one entry per line.</param>
    /// <returns>Dictionary keyed 0, 1, 2, ... in file order, holding the line contents.</returns>
    public static Dictionary<int, string> read_class_names(string file)
    {
        var classes = new Dictionary<int, string>();
        foreach (var line in File.ReadAllLines(file))
            classes[classes.Count] = line;
        return classes;
    }

    /// <summary>
    /// Reads a comma-separated list of numbers from a file and returns it reshaped to (3, 3, 2).
    /// </summary>
    /// <param name="file">Path of the file holding the comma-separated values.</param>
    /// <returns>An NDArray of the parsed values with shape (3, 3, 2).</returns>
    /// <exception cref="FormatException">An entry is not a valid number.</exception>
    public static NDArray get_anchors(string file)
    {
        // The file is machine-readable data, so parse with the invariant culture:
        // under a culture whose decimal separator is ',' the default float.Parse
        // would reject or misread values written with '.'.
        var values = File.ReadAllText(file).Split(',')
            .Select(x => float.Parse(x, System.Globalization.CultureInfo.InvariantCulture))
            .ToArray();
        return np.array(values).reshape(3, 3, 2);
    }
}
| } | |||
| @@ -1,10 +1,50 @@ | |||
| using System; | |||
| using NumSharp; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||
| { | |||
/// <summary>
/// YOLOv3 model wrapper: stores configuration values and runs the network
/// construction once from the constructor.
/// </summary>
public class YOLOv3
{
    // Constructor arguments.
    Config cfg;
    Tensor trainable;
    Tensor input_data;

    // Values read from cfg.YOLO.
    Dictionary<int, string> classes;
    int num_class;
    NDArray strides;
    NDArray anchors;
    int anchor_per_scale;
    float iou_loss_thresh;
    string upsample_method;

    // Assigned from the tuple returned by __build_nework.
    Tensor conv_lbbox;
    Tensor conv_mbbox;
    Tensor conv_sbbox;

    /// <summary>
    /// Stores the arguments, loads class names and anchors from the configured
    /// files, then builds the network.
    /// </summary>
    /// <param name="cfg_">Configuration object; its YOLO section is read here.</param>
    /// <param name="input_data_">Input tensor handed to the network builder.</param>
    /// <param name="trainable_">Tensor forwarded to the backbone as its trainable argument.</param>
    public YOLOv3(Config cfg_, Tensor input_data_, Tensor trainable_)
    {
        this.cfg = cfg_;
        this.input_data = input_data_;
        this.trainable = trainable_;

        var yolo = this.cfg.YOLO;

        this.classes = Utils.read_class_names(yolo.CLASSES);
        this.num_class = len(this.classes);
        this.strides = np.array(yolo.STRIDES);
        this.anchors = Utils.get_anchors(yolo.ANCHORS);
        this.anchor_per_scale = yolo.ANCHOR_PER_SCALE;
        this.iou_loss_thresh = yolo.IOU_LOSS_THRESH;
        this.upsample_method = yolo.UPSAMPLE_METHOD;

        var built = __build_nework(this.input_data);
        this.conv_lbbox = built.Item1;
        this.conv_mbbox = built.Item2;
        this.conv_sbbox = built.Item3;
    }

    /// <summary>
    /// Runs the darknet53 backbone on <paramref name="input_data"/> and returns
    /// the current values of the three conv_*bbox fields.
    /// </summary>
    /// <param name="input_data">Input tensor passed to the backbone.</param>
    /// <returns>The tuple (conv_lbbox, conv_mbbox, conv_sbbox) as currently stored on this instance.</returns>
    private (Tensor, Tensor, Tensor) __build_nework(Tensor input_data)
    {
        // The backbone outputs are not used further in this method.
        var (route_1, route_2, backbone_output) = backbone.darknet53(input_data, trainable);

        return (conv_lbbox, conv_mbbox, conv_sbbox);
    }
}
| } | |||
| @@ -0,0 +1,28 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||
| { | |||
/// <summary>
/// Backbone network builder for the YOLO example.
/// </summary>
class backbone
{
    /// <summary>
    /// Builds the layers under the "darknet" variable scope: two convolutions
    /// (the second with downsample enabled) followed by one residual block.
    /// </summary>
    /// <param name="input_data">Input tensor fed to the first convolution.</param>
    /// <param name="trainable">Tensor forwarded to every layer as its trainable argument.</param>
    /// <returns>A 3-tuple whose elements are all the output of the last layer built here.</returns>
    public static (Tensor, Tensor, Tensor) darknet53(Tensor input_data, Tensor trainable)
    {
        return tf_with(tf.variable_scope("darknet"), scope =>
        {
            Tensor net = common.convolutional(
                input_data,
                filters_shape: new int[] { 3, 3, 3, 32 },
                trainable: trainable,
                name: "conv0");

            net = common.convolutional(
                net,
                filters_shape: new int[] { 3, 3, 32, 64 },
                trainable: trainable,
                name: "conv1",
                downsample: true);

            foreach (var block_index in range(1))
            {
                net = common.residual_block(net, 64, 32, 64,
                    trainable: trainable,
                    name: $"residual{block_index + 0}");
            }

            // All three returned slots currently carry the same tensor.
            return (net, net, net);
        });
    }
}
| } | |||
| @@ -0,0 +1,72 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||
| { | |||
/// <summary>
/// Layer-building helpers used by the YOLO backbone.
/// </summary>
class common
{
    /// <summary>
    /// Builds a 2-D convolution under the variable scope <paramref name="name"/>,
    /// followed by batch normalization and, optionally, a leaky ReLU (alpha 0.1).
    /// </summary>
    /// <param name="input_data">Input tensor of the convolution.</param>
    /// <param name="filters_shape">Shape used for the "weight" variable.</param>
    /// <param name="trainable">Tensor passed to batch normalization as its training argument.</param>
    /// <param name="name">Variable scope name for this layer.</param>
    /// <param name="downsample">Must be false; true is not implemented.</param>
    /// <param name="activate">When true, leaky ReLU is applied to the result.</param>
    /// <param name="bn">Must be true; false is not implemented.</param>
    /// <returns>The output tensor of the layer.</returns>
    /// <exception cref="NotImplementedException">
    /// <paramref name="downsample"/> is true, or <paramref name="bn"/> is false.
    /// </exception>
    public static Tensor convolutional(Tensor input_data, int[] filters_shape, Tensor trainable,
        string name, bool downsample = false, bool activate = true,
        bool bn = true)
    {
        return tf_with(tf.variable_scope(name), scope =>
        {
            if (downsample)
                throw new NotImplementedException("");

            var unit_strides = new int[] { 1, 1, 1, 1 };
            var same_padding = "SAME";

            // The weight variable is always created with trainable: true.
            var kernel = tf.get_variable(name: "weight", dtype: tf.float32, trainable: true,
                shape: filters_shape, initializer: tf.random_normal_initializer(stddev: 0.01f));

            var feature_map = tf.nn.conv2d(input: input_data, filter: kernel,
                strides: unit_strides, padding: same_padding);

            // Checked only after the variable and convolution have been created.
            if (!bn)
                throw new NotImplementedException("");

            feature_map = tf.layers.batch_normalization(feature_map,
                beta_initializer: tf.zeros_initializer,
                gamma_initializer: tf.ones_initializer,
                moving_mean_initializer: tf.zeros_initializer,
                moving_variance_initializer: tf.ones_initializer,
                training: trainable);

            return activate ? tf.nn.leaky_relu(feature_map, alpha: 0.1f) : feature_map;
        });
    }

    /// <summary>
    /// Builds a residual block under the variable scope <paramref name="name"/>:
    /// a 1x1 convolution, a 3x3 convolution, then the sum of that result and the block input.
    /// </summary>
    /// <param name="input_data">Block input; also used as the shortcut term of the sum.</param>
    /// <param name="input_channel">Third entry of the 1x1 convolution's filter shape.</param>
    /// <param name="filter_num1">Last entry of the 1x1 filter shape and third entry of the 3x3 filter shape.</param>
    /// <param name="filter_num2">Last entry of the 3x3 filter shape.</param>
    /// <param name="trainable">Tensor forwarded to both convolutions.</param>
    /// <param name="name">Variable scope name for this block.</param>
    /// <returns>The second convolution's output plus the block input.</returns>
    public static Tensor residual_block(Tensor input_data, int input_channel, int filter_num1,
        int filter_num2, Tensor trainable, string name)
    {
        return tf_with(tf.variable_scope(name), scope =>
        {
            var first = convolutional(input_data,
                filters_shape: new int[] { 1, 1, input_channel, filter_num1 },
                trainable: trainable, name: "conv1");

            var second = convolutional(first,
                filters_shape: new int[] { 3, 3, filter_num1, filter_num2 },
                trainable: trainable, name: "conv2");

            return second + input_data;
        });
    }
}
| } | |||
| @@ -9,12 +9,13 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||
| { | |||
| public YoloConfig YOLO; | |||
| public TrainConfig TRAIN; | |||
| public TrainConfig TEST; | |||
| public TestConfig TEST; | |||
| public Config(string root) | |||
| { | |||
| YOLO = new YoloConfig(root); | |||
| TRAIN = new TrainConfig(root); | |||
| TEST = new TestConfig(root); | |||
| } | |||
| public class YoloConfig | |||
| @@ -22,13 +23,22 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||
| string _root; | |||
| public string CLASSES; | |||
| public string ANCHORS; | |||
| public float MOVING_AVE_DECAY = 0.9995f; | |||
| public int[] STRIDES = new int[] { 8, 16, 32 }; | |||
| public int ANCHOR_PER_SCALE = 3; | |||
| public float IOU_LOSS_THRESH = 0.5f; | |||
| public string UPSAMPLE_METHOD = "resize"; | |||
| public string ORIGINAL_WEIGHT; | |||
| public string DEMO_WEIGHT; | |||
| public YoloConfig(string root) | |||
| { | |||
| _root = root; | |||
| CLASSES = Path.Combine(_root, "data", "classes", "coco.names"); | |||
| ANCHORS = Path.Combine(_root, "data", "anchors", "basline_anchors.txt"); | |||
| ORIGINAL_WEIGHT = Path.Combine(_root, "checkpoint", "yolov3_coco.ckpt"); | |||
| DEMO_WEIGHT = Path.Combine(_root, "checkpoint", "yolov3_coco_demo.ckpt"); | |||
| } | |||
| } | |||
| @@ -54,5 +64,31 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||
| ANNOT_PATH = Path.Combine(_root, "data", "dataset", "voc_train.txt"); | |||
| } | |||
| } | |||
/// <summary>
/// Settings for the test split. Scalar options have fixed defaults; the path
/// fields are derived from the root directory given to the constructor.
/// </summary>
public class TestConfig
{
    string _root;

    // Fixed defaults.
    public int BATCH_SIZE = 2;
    public int[] INPUT_SIZE = new int[] { 544 };
    public bool DATA_AUG = false;
    public bool WRITE_IMAGE = true;

    // Set in the constructor, relative to the root directory.
    public string WRITE_IMAGE_PATH;
    public string WEIGHT_FILE;

    public bool WRITE_IMAGE_SHOW_LABEL = true;
    public bool SHOW_LABEL = true;
    public int SECOND_STAGE_EPOCHS = 30;
    public float SCORE_THRESHOLD = 0.3f;
    public float IOU_THRESHOLD = 0.45f;

    // Set in the constructor, relative to the root directory.
    public string ANNOT_PATH;

    /// <summary>
    /// Stores the root directory and builds the three path fields from it.
    /// </summary>
    /// <param name="root">Directory that the data and checkpoint paths are combined with.</param>
    public TestConfig(string root)
    {
        _root = root;

        string data_dir = Path.Combine(_root, "data");
        string checkpoint_dir = Path.Combine(_root, "checkpoint");

        ANNOT_PATH = Path.Combine(data_dir, "dataset", "voc_test.txt");
        WRITE_IMAGE_PATH = Path.Combine(data_dir, "detection");
        WEIGHT_FILE = Path.Combine(checkpoint_dir, "yolov3_test_loss=9.2099.ckpt-5");
    }
}
| } | |||
| } | |||