Browse Source

common.convolutional, residual_block #359

tags/v0.12
Oceania2018 6 years ago
parent
commit
e1db2baca7
9 changed files with 260 additions and 16 deletions
  1. +5
    -4
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  2. +2
    -2
      src/TensorFlowNET.Core/tensorflow.cs
  3. +38
    -1
      test/TensorFlowNET.Examples/ImageProcessing/YOLO/Dataset.cs
  4. +10
    -7
      test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs
  5. +27
    -0
      test/TensorFlowNET.Examples/ImageProcessing/YOLO/Utils.cs
  6. +41
    -1
      test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs
  7. +28
    -0
      test/TensorFlowNET.Examples/ImageProcessing/YOLO/backbone.cs
  8. +72
    -0
      test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs
  9. +37
    -1
      test/TensorFlowNET.Examples/ImageProcessing/YOLO/config.cs

+ 5
- 4
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -5,7 +5,7 @@
<AssemblyName>TensorFlow.NET</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace>
<TargetTensorFlow>1.14.0</TargetTensorFlow>
<Version>0.11.0</Version>
<Version>0.11.1</Version>
<Authors>Haiping Chen, Meinrad Recheis</Authors>
<Company>SciSharp STACK</Company>
<GeneratePackageOnBuild>true</GeneratePackageOnBuild>
@@ -17,15 +17,16 @@
<PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags>
<Description>Google's TensorFlow full binding in .NET Standard.
Docs: https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.11.0.0</AssemblyVersion>
<AssemblyVersion>0.11.1.0</AssemblyVersion>
<PackageReleaseNotes>Changes since v0.10.0:
1. Upgrade NumSharp to v0.20.
2. Add DisposableObject class to manage object lifetime.
3. Add tf.no_op, tf.nn.in_top_k, tf.GraphKeys and tf.trainable_variables.
4. Change tensorflow to non-static class in order to execute some initialization process.
5. Overloade session.run(), make syntax simpler.</PackageReleaseNotes>
5. Overload session.run(), make syntax simpler.
6. Add Local Response Normalization.</PackageReleaseNotes>
<LangVersion>7.3</LangVersion>
<FileVersion>0.11.0.0</FileVersion>
<FileVersion>0.11.1.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile>
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
<SignAssembly>true</SignAssembly>


+ 2
- 2
src/TensorFlowNET.Core/tensorflow.cs View File

@@ -68,9 +68,9 @@ namespace Tensorflow
return defaultSession;
}

public Session Session(Graph graph)
public Session Session(Graph graph, SessionOptions opts = null)
{
return new Session(graph);
return new Session(graph, opts: opts);
}

public Session Session(SessionOptions opts)


+ 38
- 1
test/TensorFlowNET.Examples/ImageProcessing/YOLO/Dataset.cs View File

@@ -1,17 +1,54 @@
using System;
using NumSharp;
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using static Tensorflow.Binding;

namespace TensorFlowNET.Examples.ImageProcessing.YOLO
{
/// <summary>
/// Collects the settings and the annotation lines for one YOLO dataset split.
/// </summary>
/// <remarks>
/// Only construction is implemented here: the constructor copies values from
/// <see cref="Config"/> and reads the annotation file into memory.
/// </remarks>
public class Dataset
{
    string annot_path;
    int[] input_sizes;
    int batch_size;
    bool data_aug;
    int[] train_input_sizes;
    NDArray strides;
    NDArray anchors;
    Dictionary<int, string> classes;
    int num_classes;
    int anchor_per_scale;
    int max_bbox_per_scale;
    string[] annotations;
    int num_samples;
    int batch_count;

    // Not assigned anywhere in this class, so it keeps its initial value of 0.
    // NOTE(review): presumably meant to expose the number of batches - confirm.
    public int Length = 0;

    /// <summary>
    /// Builds the dataset description for the requested split.
    /// </summary>
    /// <param name="dataset_type">
    /// The exact string "train" selects <c>cfg.TRAIN</c>; any other value selects <c>cfg.TEST</c>.
    /// </param>
    /// <param name="cfg">Configuration supplying paths, sizes and YOLO constants.</param>
    public Dataset(string dataset_type, Config cfg)
    {
        // Evaluate the split test once instead of repeating it for every setting.
        bool is_train = dataset_type == "train";

        annot_path = is_train ? cfg.TRAIN.ANNOT_PATH : cfg.TEST.ANNOT_PATH;
        input_sizes = is_train ? cfg.TRAIN.INPUT_SIZE : cfg.TEST.INPUT_SIZE;
        batch_size = is_train ? cfg.TRAIN.BATCH_SIZE : cfg.TEST.BATCH_SIZE;
        data_aug = is_train ? cfg.TRAIN.DATA_AUG : cfg.TEST.DATA_AUG;

        // Always taken from the TRAIN section, regardless of the split.
        train_input_sizes = cfg.TRAIN.INPUT_SIZE;
        strides = np.array(cfg.YOLO.STRIDES);

        classes = Utils.read_class_names(cfg.YOLO.CLASSES);
        num_classes = classes.Count;

        // get_anchors already returns an NDArray, so it is stored directly
        // (same as the YOLOv3 constructor) rather than re-wrapped in np.array.
        anchors = Utils.get_anchors(cfg.YOLO.ANCHORS);
        anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE;
        max_bbox_per_scale = 150;

        annotations = load_annotations();
        num_samples = len(annotations);
        batch_count = 0;
    }

    /// <summary>
    /// Reads every line of the annotation file at <c>annot_path</c>.
    /// </summary>
    /// <returns>One array element per line of the file.</returns>
    string[] load_annotations()
    {
        return File.ReadAllLines(annot_path);
    }
}
}

+ 10
- 7
test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs View File

@@ -13,7 +13,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
/// </summary>
public class Main : IExample
{
public bool Enabled { get; set; } = false;
public bool Enabled { get; set; } = true;
public bool IsImportingGraph { get; set; } = false;
public string Name => "YOLOv3";

@@ -41,7 +41,10 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
Tensor true_sbboxes;
Tensor true_mbboxes;
Tensor true_lbboxes;
Tensor trainable;
Tensor trainable;

Session sess;
YOLOv3 model;
#endregion

public bool Run()
@@ -50,7 +53,9 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO

var graph = IsImportingGraph ? ImportGraph() : BuildGraph();

using (var sess = tf.Session(graph))
var options = new SessionOptions();
options.SetConfig(new ConfigProto { AllowSoftPlacement = true });
using (var sess = tf.Session(graph, opts: options))
{
Train(sess);
}
@@ -86,7 +91,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO

tf_with(tf.name_scope("define_loss"), scope =>
{
//model = new YOLOv3(input_data, trainable);
model = new YOLOv3(cfg, input_data, trainable);
});

return graph;
@@ -109,9 +114,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
string dataDir = Path.Combine(Name, "data");
Directory.CreateDirectory(dataDir);

classes = new Dictionary<int, string>();
foreach (var line in File.ReadAllLines(cfg.YOLO.CLASSES))
classes[classes.Count] = line;
classes = Utils.read_class_names(cfg.YOLO.CLASSES);
num_classes = classes.Count;

learn_rate_init = cfg.TRAIN.LEARN_RATE_INIT;


+ 27
- 0
test/TensorFlowNET.Examples/ImageProcessing/YOLO/Utils.cs View File

@@ -0,0 +1,27 @@
using NumSharp;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;

namespace TensorFlowNET.Examples.ImageProcessing.YOLO
{
/// <summary>
/// File-reading helpers shared by the YOLO example classes.
/// </summary>
class Utils
{
    /// <summary>
    /// Reads a class-name file and numbers its lines from zero.
    /// </summary>
    /// <param name="file">Path of a text file with one class name per line.</param>
    /// <returns>A dictionary mapping the zero-based line index to the line text.</returns>
    public static Dictionary<int, string> read_class_names(string file)
    {
        var classes = new Dictionary<int, string>();
        foreach (var line in File.ReadAllLines(file))
        {
            // The current count is the next free zero-based index.
            classes[classes.Count] = line;
        }
        return classes;
    }

    /// <summary>
    /// Reads comma-separated anchor values and reshapes them to (3, 3, 2).
    /// </summary>
    /// <param name="file">Path of a text file containing comma-separated numbers.</param>
    /// <returns>An <see cref="NDArray"/> of the parsed values with shape (3, 3, 2).</returns>
    public static NDArray get_anchors(string file)
    {
        // The anchors file is machine-readable data, so it is parsed with the
        // invariant culture. Parsing with the current culture would misread or
        // reject "1.25" on systems whose decimal separator is not '.'.
        var values = File.ReadAllText(file).Split(',')
            .Select(x => float.Parse(x, System.Globalization.CultureInfo.InvariantCulture))
            .ToArray();

        return np.array(values).reshape(3, 3, 2);
    }
}
}

+ 41
- 1
test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs View File

@@ -1,10 +1,50 @@
using System;
using NumSharp;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.Examples.ImageProcessing.YOLO
{
/// <summary>
/// YOLOv3 model wrapper: captures the configuration values it needs and
/// starts building the network graph from the input tensor.
/// </summary>
public class YOLOv3
{
    Config cfg;
    Tensor trainable;
    Tensor input_data;
    Dictionary<int, string> classes;
    int num_class;
    NDArray strides;
    NDArray anchors;
    int anchor_per_scale;
    float iou_loss_thresh;
    string upsample_method;
    Tensor conv_lbbox;
    Tensor conv_mbbox;
    Tensor conv_sbbox;

    /// <summary>
    /// Stores the inputs, loads class names and anchors from the configured
    /// files, then builds the network.
    /// </summary>
    /// <param name="cfg_">Configuration; only its <c>YOLO</c> section is read here.</param>
    /// <param name="input_data_">Input tensor handed to the backbone.</param>
    /// <param name="trainable_">Tensor forwarded to the backbone as its <c>trainable</c> argument.</param>
    public YOLOv3(Config cfg_, Tensor input_data_, Tensor trainable_)
    {
        cfg = cfg_;
        input_data = input_data_;
        trainable = trainable_;

        var yolo = cfg.YOLO;
        classes = Utils.read_class_names(yolo.CLASSES);
        num_class = len(classes);
        strides = np.array(yolo.STRIDES);
        anchors = Utils.get_anchors(yolo.ANCHORS);
        anchor_per_scale = yolo.ANCHOR_PER_SCALE;
        iou_loss_thresh = yolo.IOU_LOSS_THRESH;
        upsample_method = yolo.UPSAMPLE_METHOD;

        var heads = build_network(input_data);
        conv_lbbox = heads.Item1;
        conv_mbbox = heads.Item2;
        conv_sbbox = heads.Item3;
    }

    /// <summary>
    /// Runs the darknet53 backbone on the input and returns the three
    /// detection-head tensors.
    /// </summary>
    /// <remarks>
    /// In this revision the heads are not computed yet: the method returns the
    /// current values of the <c>conv_lbbox</c>, <c>conv_mbbox</c> and
    /// <c>conv_sbbox</c> fields, which nothing has assigned at this point.
    /// </remarks>
    private (Tensor, Tensor, Tensor) build_network(Tensor input_data)
    {
        var stages = backbone.darknet53(input_data, trainable);

        // Intermediate backbone outputs; not consumed yet.
        Tensor route_1 = stages.Item1;
        Tensor route_2 = stages.Item2;
        input_data = stages.Item3;

        return (conv_lbbox, conv_mbbox, conv_sbbox);
    }
}
}

+ 28
- 0
test/TensorFlowNET.Examples/ImageProcessing/YOLO/backbone.cs View File

@@ -0,0 +1,28 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.Examples.ImageProcessing.YOLO
{
/// <summary>
/// Feature-extraction backbone for the YOLO example.
/// </summary>
class backbone
{
    /// <summary>
    /// Builds the (currently partial) darknet53 layers inside the "darknet"
    /// variable scope.
    /// </summary>
    /// <param name="input_data">Input tensor fed to the first convolution.</param>
    /// <param name="trainable">Tensor forwarded to every layer as its <c>trainable</c> argument.</param>
    /// <returns>
    /// A tuple (route_1, route_2, output). In this revision all three items
    /// are the same tensor: the output of the last residual block.
    /// </returns>
    public static (Tensor, Tensor, Tensor) darknet53(Tensor input_data, Tensor trainable)
    {
        return tf_with(tf.variable_scope("darknet"), scope =>
        {
            var net = common.convolutional(input_data,
                filters_shape: new[] { 3, 3, 3, 32 },
                trainable: trainable,
                name: "conv0");

            // NOTE(review): common.convolutional as written in this change throws
            // NotImplementedException when downsample is true, so this call
            // cannot succeed until that branch is implemented.
            net = common.convolutional(net,
                filters_shape: new[] { 3, 3, 32, 64 },
                trainable: trainable,
                name: "conv1",
                downsample: true);

            foreach (var block in range(1))
            {
                net = common.residual_block(net, 64, 32, 64,
                    trainable: trainable,
                    name: $"residual{block + 0}");
            }

            // Both routes are placeholders pointing at the final tensor for now.
            var route_1 = net;
            var route_2 = net;

            return (route_1, route_2, net);
        });
    }
}
}

+ 72
- 0
test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs View File

@@ -0,0 +1,72 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.Examples.ImageProcessing.YOLO
{
/// <summary>
/// Layer-building helpers used by the YOLO backbone.
/// </summary>
class common
{
    /// <summary>
    /// Builds a 2-D convolution, followed by batch normalization and an
    /// optional leaky-ReLU activation, inside a variable scope named
    /// <paramref name="name"/>.
    /// </summary>
    /// <param name="input_data">Tensor passed as the convolution input.</param>
    /// <param name="filters_shape">Shape of the "weight" variable used as the convolution filter.</param>
    /// <param name="trainable">Tensor passed as the <c>training</c> argument of batch normalization.</param>
    /// <param name="name">Variable scope name for this layer.</param>
    /// <param name="downsample">Must be <c>false</c>; <c>true</c> is not implemented.</param>
    /// <param name="activate">When <c>true</c>, applies leaky ReLU with alpha 0.1.</param>
    /// <param name="bn">Must be <c>true</c>; <c>false</c> is not implemented.</param>
    /// <returns>The output tensor of the layer.</returns>
    /// <exception cref="NotImplementedException">
    /// Thrown when <paramref name="downsample"/> is <c>true</c> or <paramref name="bn"/> is <c>false</c>.
    /// </exception>
    public static Tensor convolutional(Tensor input_data, int[] filters_shape, Tensor trainable,
        string name, bool downsample = false, bool activate = true,
        bool bn = true)
    {
        return tf_with(tf.variable_scope(name), scope =>
        {
            int[] strides;
            string padding;

            if (downsample)
            {
                // Name the unsupported option and the layer so the failure is
                // traceable from the exception alone.
                throw new NotImplementedException(
                    $"common.convolutional: downsample = true is not implemented yet (layer '{name}').");
            }
            else
            {
                strides = new int[] { 1, 1, 1, 1 };
                padding = "SAME";
            }

            // The weight variable is always created as trainable; the
            // 'trainable' tensor parameter only drives batch normalization below.
            var weight = tf.get_variable(name: "weight", dtype: tf.float32, trainable: true,
                shape: filters_shape, initializer: tf.random_normal_initializer(stddev: 0.01f));

            var conv = tf.nn.conv2d(input: input_data, filter: weight, strides: strides, padding: padding);

            if (bn)
            {
                conv = tf.layers.batch_normalization(conv, beta_initializer: tf.zeros_initializer,
                    gamma_initializer: tf.ones_initializer,
                    moving_mean_initializer: tf.zeros_initializer,
                    moving_variance_initializer: tf.ones_initializer, training: trainable);
            }
            else
            {
                throw new NotImplementedException(
                    $"common.convolutional: bn = false is not implemented yet (layer '{name}').");
            }

            if (activate)
                conv = tf.nn.leaky_relu(conv, alpha: 0.1f);

            return conv;
        });
    }

    /// <summary>
    /// Builds a residual block: a 1x1 convolution followed by a 3x3
    /// convolution, whose result is added to the block input.
    /// </summary>
    /// <param name="input_data">Block input, also used as the shortcut branch.</param>
    /// <param name="input_channel">Third entry of the first filter shape.</param>
    /// <param name="filter_num1">Last entry of the first filter shape and third entry of the second.</param>
    /// <param name="filter_num2">Last entry of the second filter shape.</param>
    /// <param name="trainable">Tensor forwarded to both convolutions.</param>
    /// <param name="name">Variable scope name for this block.</param>
    /// <returns>The sum of the second convolution's output and the original input.</returns>
    public static Tensor residual_block(Tensor input_data, int input_channel, int filter_num1,
        int filter_num2, Tensor trainable, string name)
    {
        // Captured before input_data is reassigned inside the scope.
        var short_cut = input_data;

        return tf_with(tf.variable_scope(name), scope =>
        {
            input_data = convolutional(input_data, filters_shape: new int[] { 1, 1, input_channel, filter_num1 },
                trainable: trainable, name: "conv1");
            input_data = convolutional(input_data, filters_shape: new int[] { 3, 3, filter_num1, filter_num2 },
                trainable: trainable, name: "conv2");

            // NOTE(review): assumes the conv2 output and the shortcut have
            // compatible shapes for element-wise addition - confirm.
            var residual_output = input_data + short_cut;

            return residual_output;
        });
    }
}
}

+ 37
- 1
test/TensorFlowNET.Examples/ImageProcessing/YOLO/config.cs View File

@@ -9,12 +9,13 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
{
public YoloConfig YOLO;
public TrainConfig TRAIN;
public TrainConfig TEST;
public TestConfig TEST;

public Config(string root)
{
YOLO = new YoloConfig(root);
TRAIN = new TrainConfig(root);
TEST = new TestConfig(root);
}

public class YoloConfig
@@ -22,13 +23,22 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
string _root;

public string CLASSES;
public string ANCHORS;
public float MOVING_AVE_DECAY = 0.9995f;
public int[] STRIDES = new int[] { 8, 16, 32 };
public int ANCHOR_PER_SCALE = 3;
public float IOU_LOSS_THRESH = 0.5f;
public string UPSAMPLE_METHOD = "resize";
public string ORIGINAL_WEIGHT;
public string DEMO_WEIGHT;

public YoloConfig(string root)
{
_root = root;
CLASSES = Path.Combine(_root, "data", "classes", "coco.names");
ANCHORS = Path.Combine(_root, "data", "anchors", "basline_anchors.txt");
ORIGINAL_WEIGHT = Path.Combine(_root, "checkpoint", "yolov3_coco.ckpt");
DEMO_WEIGHT = Path.Combine(_root, "checkpoint", "yolov3_coco_demo.ckpt");
}
}

@@ -54,5 +64,31 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
ANNOT_PATH = Path.Combine(_root, "data", "dataset", "voc_train.txt");
}
}

/// <summary>
/// Settings for the test/evaluation split. Paths are resolved relative to
/// the root directory given to the constructor.
/// </summary>
public class TestConfig
{
    string _root;

    // Paths, filled in by the constructor.
    public string ANNOT_PATH;
    public string WRITE_IMAGE_PATH;
    public string WEIGHT_FILE;

    // Input settings.
    public int BATCH_SIZE = 2;
    public int[] INPUT_SIZE = new int[] { 544 };
    public bool DATA_AUG = false;

    // Output and display switches.
    public bool WRITE_IMAGE = true;
    public bool WRITE_IMAGE_SHOW_LABEL = true;
    public bool SHOW_LABEL = true;

    // Numeric constants.
    public int SECOND_STAGE_EPOCHS = 30;
    public float SCORE_THRESHOLD = 0.3f;
    public float IOU_THRESHOLD = 0.45f;

    /// <summary>
    /// Creates the test configuration and derives its file paths from <paramref name="root"/>.
    /// </summary>
    /// <param name="root">Base directory that contains the "data" and "checkpoint" folders.</param>
    public TestConfig(string root)
    {
        _root = root;

        var data_dir = Path.Combine(_root, "data");
        ANNOT_PATH = Path.Combine(data_dir, "dataset", "voc_test.txt");
        WRITE_IMAGE_PATH = Path.Combine(data_dir, "detection");
        WEIGHT_FILE = Path.Combine(_root, "checkpoint", "yolov3_test_loss=9.2099.ckpt-5");
    }
}
}
}

Loading…
Cancel
Save