From f1b7e1182ae383bcdfe20506853c80119ebd94b4 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 1 Sep 2019 19:31:40 -0500 Subject: [PATCH] yolo.build_nework() #359 --- src/TensorFlowHub/TensorFlowHub.csproj | 2 +- src/TensorFlowNET.Core/APIs/tf.array.cs | 5 + src/TensorFlowNET.Core/APIs/tf.math.cs | 7 +- src/TensorFlowNET.Core/APIs/tf.reshape.cs | 4 +- src/TensorFlowNET.Core/APIs/tf.tile.cs | 8 +- .../Operations/gen_array_ops.cs | 9 +- src/TensorFlowNET.Core/Operations/math_ops.cs | 9 +- .../TensorFlowNET.Core.csproj | 2 +- .../Tensors/Tensor.Index.cs | 13 +- src/TensorFlowNET.Core/Tensors/TensorShape.cs | 8 + src/TensorFlowNET.Core/ops.cs | 2 + .../ImageProcessing/DigitRecognitionNN.cs | 7 +- .../ImageProcessing/YOLO/Main.cs | 6 + .../ImageProcessing/YOLO/YOLOv3.cs | 189 +++++++++++++++++- .../ImageProcessing/YOLO/common.cs | 1 + .../TextProcessing/CnnTextClassification.cs | 15 +- test/TensorFlowNET.UnitTest/ImageTest.cs | 1 + 17 files changed, 257 insertions(+), 31 deletions(-) diff --git a/src/TensorFlowHub/TensorFlowHub.csproj b/src/TensorFlowHub/TensorFlowHub.csproj index 16e22183..45501233 100644 --- a/src/TensorFlowHub/TensorFlowHub.csproj +++ b/src/TensorFlowHub/TensorFlowHub.csproj @@ -17,6 +17,6 @@ https://avatars3.githubusercontent.com/u/44989469?s=200&v=4 - + \ No newline at end of file diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs index f8493c40..f2402a6f 100644 --- a/src/TensorFlowNET.Core/APIs/tf.array.cs +++ b/src/TensorFlowNET.Core/APIs/tf.array.cs @@ -22,6 +22,11 @@ namespace Tensorflow { public partial class tensorflow { + /// + /// A convenient alias for None, useful for indexing arrays. + /// + public string newaxis = ""; + /// /// Concatenates tensors along one dimension. /// diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index ddfa71ec..ec081cc4 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -315,6 +315,9 @@ namespace Tensorflow public Tensor pow(T1 x, T2 y) => gen_math_ops.pow(x, y); + public Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range") + => math_ops.range(start, limit: limit, delta: delta, dtype: dtype, name: name); + /// /// Computes the sum of elements across dimensions of a tensor. /// @@ -325,10 +328,12 @@ namespace Tensorflow { if(!axis.HasValue && reduction_indices.HasValue) return math_ops.reduce_sum(input, reduction_indices.Value); + else if (axis.HasValue && !reduction_indices.HasValue) + return math_ops.reduce_sum(input, axis.Value); return math_ops.reduce_sum(input); } - public Tensor reduce_sum(Tensor input, int axis, int? reduction_indices = null) + public Tensor reduce_sum(Tensor input, int[] axis, int? reduction_indices = null) => math_ops.reduce_sum(input, axis); /// diff --git a/src/TensorFlowNET.Core/APIs/tf.reshape.cs b/src/TensorFlowNET.Core/APIs/tf.reshape.cs index 78a00432..b6924709 100644 --- a/src/TensorFlowNET.Core/APIs/tf.reshape.cs +++ b/src/TensorFlowNET.Core/APIs/tf.reshape.cs @@ -18,8 +18,8 @@ namespace Tensorflow { public partial class tensorflow { - public Tensor reshape(Tensor tensor, - Tensor shape, + public Tensor reshape(T1 tensor, + T2 shape, string name = null) => gen_array_ops.reshape(tensor, shape, name); public Tensor reshape(Tensor tensor, diff --git a/src/TensorFlowNET.Core/APIs/tf.tile.cs b/src/TensorFlowNET.Core/APIs/tf.tile.cs index 21017a17..0995dc27 100644 --- a/src/TensorFlowNET.Core/APIs/tf.tile.cs +++ b/src/TensorFlowNET.Core/APIs/tf.tile.cs @@ -20,12 +20,8 @@ namespace Tensorflow { public partial class tensorflow { - public Tensor tile(Tensor input, - Tensor multiples, + public Tensor tile(Tensor input, + T multiples, string name = null) => gen_array_ops.tile(input, multiples, name); - public Tensor tile(NDArray input, - int[] multiples, - string name = null) => gen_array_ops.tile(input, multiples, name); - } } diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 61fa956b..ea020599 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -224,7 +224,7 @@ namespace Tensorflow public static Tensor reshape(T1 tensor, T2 shape, string name = null) { var _op = _op_def_lib._apply_op_helper("Reshape", name, new { tensor, shape }); - return _op.outputs[0]; + return _op.output; } public static Tensor reshape(Tensor tensor, int[] shape, string name = null) @@ -334,12 +334,7 @@ namespace Tensorflow return _op.outputs; } - public static Tensor tile(Tensor input, Tensor multiples, string name = null) - { - var _op = _op_def_lib._apply_op_helper("Tile", name, new { input, multiples }); - return _op.outputs[0]; - } - public static Tensor tile(NDArray input, int[] multiples, string name = null) + public static Tensor tile(Tensor input, T multiples, string name = null) { var _op = _op_def_lib._apply_op_helper("Tile", name, new { input, multiples }); return _op.outputs[0]; diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index f5cfdb37..2f8accd7 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -422,6 +422,13 @@ namespace Tensorflow return _may_reduce_to_scalar(keepdims, axis, m); } + public static Tensor reduce_sum(Tensor input_tensor, int[] axis, bool keepdims = false, string name = null) + { + var r = _ReductionDims(input_tensor, axis); + var m = gen_math_ops._sum(input_tensor, r, keep_dims: keepdims, name: name); + return _may_reduce_to_scalar(keepdims, axis, m); + } + public static Tensor reduce_sum(Tensor input_tensor, int axis, bool keepdims = false, string name = null) { var m = gen_math_ops._sum(input_tensor, axis, keep_dims: keepdims, name: name); @@ -492,7 +499,7 @@ namespace Tensorflow public static Tensor rsqrt(Tensor x, string name = null) => gen_math_ops.rsqrt(x, name: name); - public static Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range" ) + public static Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range") { if(limit == null) { diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 7ac5063c..9eae0cd9 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -62,7 +62,7 @@ Docs: https://tensorflownet.readthedocs.io - + diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs index 6632550f..552bddfd 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs @@ -31,7 +31,7 @@ namespace Tensorflow { get { - var slice_spec = slices.Select(x => x == null ? null : new Slice(x)).ToArray(); + var slice_spec = slices.Select(x => new Slice(x)).ToArray(); var begin = new List(); var end = new List(); var strides = new List(); @@ -43,13 +43,20 @@ namespace Tensorflow foreach (var s in slice_spec) { - if(s == null) + if(s.IsNewAxis) { begin.Add(0); end.Add(0); - strides.Add(0); + strides.Add(1); new_axis_mask |= (1 << index); } + else if (s.IsEllipsis) + { + begin.Add(0); + end.Add(0); + strides.Add(1); + ellipsis_mask |= (1 << index); + } else { if (s.Start.HasValue) diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs index 8c2f543e..75d5ad07 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs @@ -143,6 +143,14 @@ namespace Tensorflow } } + public TensorShape merge_with(TensorShape other) + { + if (dims.Length == 0) + return other; + + throw new NotImplementedException("merge_with"); + } + public override string ToString() { return shape.ToString(); diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 4708730b..aadf3b08 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -505,6 +505,8 @@ namespace Tensorflow return varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref); case ResourceVariable varVal: return null; + case TensorShape ts: + return constant_op.constant(ts.dims, dtype: dtype, name: name); case object[] objects: return array_ops._autopacking_conversion_function(objects, dtype: dtype, name: name); default: diff --git a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs index a68b58ef..b0ebf446 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs @@ -16,6 +16,7 @@ using NumSharp; using System; +using System.Diagnostics; using Tensorflow; using Tensorflow.Hub; using static Tensorflow.Binding; @@ -135,6 +136,9 @@ namespace TensorFlowNET.Examples float loss_val = 100.0f; float accuracy_val = 0f; + var sw = new Stopwatch(); + sw.Start(); + foreach (var epoch in range(epochs)) { print($"Training epoch: {epoch + 1}"); @@ -154,7 +158,8 @@ namespace TensorFlowNET.Examples { // Calculate and display the batch loss and accuracy (loss_val, accuracy_val) = sess.run((loss, accuracy), (x, x_batch), (y, y_batch)); - print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")}"); + print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")} {sw.ElapsedMilliseconds}ms"); + sw.Restart(); } } diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs index c3201f8c..f10ac7d1 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs @@ -45,6 +45,8 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO Session sess; YOLOv3 model; + VariableV1[] net_var; + Tensor giou_loss, conf_loss, prob_loss; #endregion public bool Run() @@ -92,6 +94,10 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO tf_with(tf.name_scope("define_loss"), scope => { model = new YOLOv3(cfg, input_data, trainable); + net_var = tf.global_variables(); + (giou_loss, conf_loss, prob_loss) = model.compute_loss( + label_sbbox, label_mbbox, label_lbbox, + true_sbboxes, true_mbboxes, true_lbboxes); }); tf_with(tf.name_scope("define_weight_decay"), scope => diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs index a197aef4..1cff167f 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs @@ -22,6 +22,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO Tensor conv_lbbox; Tensor conv_mbbox; Tensor conv_sbbox; + Tensor pred_sbbox; public YOLOv3(Config cfg_, Tensor input_data_, Tensor trainable_) { @@ -40,17 +41,17 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO tf_with(tf.variable_scope("pred_sbbox"), scope => { - // pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]); + pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]); }); tf_with(tf.variable_scope("pred_mbbox"), scope => { - // pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]); + pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]); }); tf_with(tf.variable_scope("pred_lbbox"), scope => { - // pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]); + pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]); }); } @@ -71,7 +72,189 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO input_data = common.convolutional(input_data, new[] { 1, 1, 512, 256 }, trainable, "conv57"); input_data = common.upsample(input_data, name: "upsample0", method: upsample_method); + tf_with(tf.variable_scope("route_1"), delegate + { + input_data = tf.concat(new[] { input_data, route_2 }, axis: -1); + }); + + input_data = common.convolutional(input_data, new[] { 1, 1, 768, 256 }, trainable, "conv58"); + input_data = common.convolutional(input_data, new[] { 3, 3, 256, 512 }, trainable, "conv59"); + input_data = common.convolutional(input_data, new[] { 1, 1, 512, 256 }, trainable, "conv60"); + input_data = common.convolutional(input_data, new[] { 3, 3, 256, 512 }, trainable, "conv61"); + input_data = common.convolutional(input_data, new[] { 1, 1, 512, 256 }, trainable, "conv62"); + + var conv_mobj_branch = common.convolutional(input_data, new[] { 3, 3, 256, 512 }, trainable, name: "conv_mobj_branch"); + conv_mbbox = common.convolutional(conv_mobj_branch, new[] { 1, 1, 512, 3 * (num_class + 5) }, + trainable: trainable, name: "conv_mbbox", activate: false, bn: false); + + input_data = common.convolutional(input_data, new[] { 1, 1, 256, 128 }, trainable, "conv63"); + input_data = common.upsample(input_data, name: "upsample1", method: upsample_method); + + tf_with(tf.variable_scope("route_2"), delegate + { + input_data = tf.concat(new[] { input_data, route_1 }, axis: -1); + }); + + input_data = common.convolutional(input_data, new[] { 1, 1, 384, 128 }, trainable, "conv64"); + input_data = common.convolutional(input_data, new[] { 3, 3, 128, 256 }, trainable, "conv65"); + input_data = common.convolutional(input_data, new[] { 1, 1, 256, 128 }, trainable, "conv66"); + input_data = common.convolutional(input_data, new[] { 3, 3, 128, 256 }, trainable, "conv67"); + input_data = common.convolutional(input_data, new[] { 1, 1, 256, 128 }, trainable, "conv68"); + + var conv_sobj_branch = common.convolutional(input_data, new[] { 3, 3, 128, 256 }, trainable, name: "conv_sobj_branch"); + conv_sbbox = common.convolutional(conv_sobj_branch, new[] { 1, 1, 256, 3 * (num_class + 5) }, + trainable: trainable, name: "conv_sbbox", activate: false, bn: false); + return (conv_lbbox, conv_mbbox, conv_sbbox); } + + private Tensor decode(Tensor conv_output, NDArray anchors, int stride) + { + var conv_shape = tf.shape(conv_output); + var batch_size = conv_shape[0]; + var output_size = conv_shape[1]; + anchor_per_scale = len(anchors); + + conv_output = tf.reshape(conv_output, new object[] { batch_size, output_size, output_size, anchor_per_scale, 5 + num_class }); + + var conv_raw_dxdy = conv_output[":", ":", ":", ":", "0:2"]; + var conv_raw_dwdh = conv_output[":", ":", ":", ":", "2:4"]; + var conv_raw_conf = conv_output[":", ":", ":", ":", "4:5"]; + var conv_raw_prob = conv_output[":", ":", ":", ":", "5:"]; + + var y = tf.tile(tf.range(output_size, dtype: tf.int32)[":", tf.newaxis], new object[] { 1, output_size }); + var x = tf.tile(tf.range(output_size, dtype: tf.int32)[tf.newaxis, ":"], new object[] { output_size, 1 }); + + var xy_grid = tf.concat(new[] { x[":", ":", tf.newaxis], y[":", ":", tf.newaxis] }, axis: -1); + xy_grid = tf.tile(xy_grid[tf.newaxis, ":", ":", tf.newaxis, ":"], new object[] { batch_size, 1, 1, anchor_per_scale, 1 }); + xy_grid = tf.cast(xy_grid, tf.float32); + + var pred_xy = (tf.sigmoid(conv_raw_dxdy) + xy_grid) * stride; + var pred_wh = (tf.exp(conv_raw_dwdh) * anchors) * stride; + var pred_xywh = tf.concat(new[] { pred_xy, pred_wh }, axis: -1); + + var pred_conf = tf.sigmoid(conv_raw_conf); + var pred_prob = tf.sigmoid(conv_raw_prob); + + return tf.concat(new[] { pred_xywh, pred_conf, pred_prob }, axis: -1); + } + + public (Tensor, Tensor, Tensor) compute_loss(Tensor label_sbbox, Tensor label_mbbox, Tensor label_lbbox, + Tensor true_sbbox, Tensor true_mbbox, Tensor true_lbbox) + { + Tensor giou_loss = null, conf_loss = null, prob_loss = null; + (Tensor, Tensor, Tensor) loss_sbbox = (null, null, null); + + tf_with(tf.name_scope("smaller_box_loss"), delegate + { + loss_sbbox = loss_layer(conv_sbbox, pred_sbbox, label_sbbox, true_sbbox, + anchors: anchors[0], stride: strides[0]); + }); + + return (giou_loss, conf_loss, prob_loss); + } + + public (Tensor, Tensor, Tensor) loss_layer(Tensor conv, Tensor pred, Tensor label, Tensor bboxes, NDArray anchors, int stride) + { + var conv_shape = tf.shape(conv); + var batch_size = conv_shape[0]; + var output_size = conv_shape[1]; + var input_size = stride * output_size; + conv = tf.reshape(conv, new object[] {batch_size, output_size, output_size, + anchor_per_scale, 5 + num_class }); + var conv_raw_conf = conv[":", ":", ":", ":", "4:5"]; + var conv_raw_prob = conv[":", ":", ":", ":", "5:"]; + + var pred_xywh = pred[":", ":", ":", ":", "0:4"]; + var pred_conf = pred[":", ":", ":", ":", "4:5"]; + + var label_xywh = label[":", ":", ":", ":", "0:4"]; + var respond_bbox = label[":", ":", ":", ":", "4:5"]; + var label_prob = label[":", ":", ":", ":", "5:"]; + + var giou = tf.expand_dims(bbox_giou(pred_xywh, label_xywh), axis: -1); + input_size = tf.cast(input_size, tf.float32); + + var bbox_loss_scale = 2.0 - 1.0 * label_xywh[":", ":", ":", ":", "2:3"] * label_xywh[":", ":", ":", ":", "3:4"] / (tf.sqrt(input_size)); + var giou_loss = respond_bbox * bbox_loss_scale * (1 - giou); + + var iou = bbox_iou(pred_xywh[":", ":", ":", ":", tf.newaxis, ":"], bboxes[":", tf.newaxis, tf.newaxis, tf.newaxis, ":", ":"]); + var max_iou = tf.expand_dims(tf.reduce_max(iou, axis: new[] { -1 }), axis: -1); + + var respond_bgd = (1.0 - respond_bbox) * tf.cast(max_iou < iou_loss_thresh, tf.float32); + + var conf_focal = focal(respond_bbox, pred_conf); + + var conf_loss = conf_focal * ( + respond_bbox * tf.nn.sigmoid_cross_entropy_with_logits(labels: respond_bbox, logits: conv_raw_conf) + + respond_bgd * tf.nn.sigmoid_cross_entropy_with_logits(labels: respond_bbox, logits: conv_raw_conf)); + + var prob_loss = respond_bbox * tf.nn.sigmoid_cross_entropy_with_logits(labels: label_prob, logits: conv_raw_prob); + + giou_loss = tf.reduce_mean(tf.reduce_sum(giou_loss, axis: new[] { 1, 2, 3, 4 })); + conf_loss = tf.reduce_mean(tf.reduce_sum(conf_loss, axis: new[] { 1, 2, 3, 4 })); + prob_loss = tf.reduce_mean(tf.reduce_sum(prob_loss, axis: new[] { 1, 2, 3, 4 })); + + return (giou_loss, conf_loss, prob_loss); + } + + public Tensor focal(Tensor target, Tensor actual, int alpha = 1, int gamma = 2) + { + var focal_loss = alpha * tf.pow(tf.abs(target - actual), gamma); + return focal_loss; + } + + public Tensor bbox_giou(Tensor boxes1, Tensor boxes2) + { + boxes1 = tf.concat(new[] { boxes1["...", ":2"] - boxes1["...", "2:"] * 0.5, + boxes1["...", ":2"] + boxes1["...", "2:"] * 0.5}, axis: -1); + boxes2 = tf.concat(new[] { boxes2["...", ":2"] - boxes2["...", "2:"] * 0.5, + boxes2["...", ":2"] + boxes2["...", "2:"] * 0.5}, axis: -1); + + boxes1 = tf.concat(new[] { tf.minimum(boxes1["...", ":2"], boxes1["...", "2:"]), + tf.maximum(boxes1["...", ":2"], boxes1["...", "2:"])}, axis: -1); + boxes2 = tf.concat(new[] { tf.minimum(boxes2["...", ":2"], boxes2["...", "2:"]), + tf.maximum(boxes2["...", ":2"], boxes2["...", "2:"])}, axis: -1); + + var boxes1_area = (boxes1["...", "2"] - boxes1["...", "0"]) * (boxes1["...", "3"] - boxes1["...", "1"]); + var boxes2_area = (boxes2["...", "2"] - boxes2["...", "0"]) * (boxes2["...", "3"] - boxes2["...", "1"]); + + var left_up = tf.maximum(boxes1["...", ":2"], boxes2["...", ":2"]); + var right_down = tf.minimum(boxes1["...", "2:"], boxes2["...", "2:"]); + + var inter_section = tf.maximum(right_down - left_up, 0.0f); + var inter_area = inter_section["...", "0"] * inter_section["...", "1"]; + var union_area = boxes1_area + boxes2_area - inter_area; + var iou = inter_area / union_area; + + var enclose_left_up = tf.minimum(boxes1["...", ":2"], boxes2["...", ":2"]); + var enclose_right_down = tf.maximum(boxes1["...", "2:"], boxes2["...", "2:"]); + var enclose = tf.maximum(enclose_right_down - enclose_left_up, 0.0); + var enclose_area = enclose["...", "0"] * enclose["...", "1"]; + var giou = iou - 1.0 * (enclose_area - union_area) / enclose_area; + + return giou; + } + + public Tensor bbox_iou(Tensor boxes1, Tensor boxes2) + { + var boxes1_area = boxes1["...", "2"] * boxes1["...", "3"]; + var boxes2_area = boxes2["...", "2"] * boxes2["...", "3"]; + + boxes1 = tf.concat(new[] { boxes1["...", ":2"] - boxes1["...", "2:"] * 0.5, + boxes1["...", ":2"] + boxes1["...", "2:"] * 0.5}, axis: -1); + boxes2 = tf.concat(new[] { boxes2["...", ":2"] - boxes2["...", "2:"] * 0.5, + boxes2["...", ":2"] + boxes2["...", "2:"] * 0.5}, axis: -1); + + var left_up = tf.maximum(boxes1["...", ":2"], boxes2["...", ":2"]); + var right_down = tf.minimum(boxes1["...", "2:"], boxes2["...", "2:"]); + + var inter_section = tf.maximum(right_down - left_up, 0.0); + var inter_area = inter_section["...", "0"] * inter_section["...", "1"]; + var union_area = boxes1_area + boxes2_area - inter_area; + var iou = 1.0 * inter_area / union_area; + + return iou; + } } } diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs index 375d68a0..06a261ce 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs @@ -68,6 +68,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO tf_with(tf.variable_scope(name), delegate { var input_shape = tf.shape(input_data); + output = tf.image.resize_nearest_neighbor(input_data, new Tensor[] { input_shape[1] * 2, input_shape[2] * 2 }); }); } else if(method == "deconv") diff --git a/test/TensorFlowNET.Examples/TextProcessing/CnnTextClassification.cs b/test/TensorFlowNET.Examples/TextProcessing/CnnTextClassification.cs index 3519d972..23d1bec4 100644 --- a/test/TensorFlowNET.Examples/TextProcessing/CnnTextClassification.cs +++ b/test/TensorFlowNET.Examples/TextProcessing/CnnTextClassification.cs @@ -197,7 +197,6 @@ namespace TensorFlowNET.Examples public void Train(Session sess) { var graph = tf.get_default_graph(); - var stopwatch = Stopwatch.StartNew(); sess.run(tf.global_variables_initializer()); var saver = tf.train.Saver(tf.global_variables()); @@ -212,14 +211,20 @@ namespace TensorFlowNET.Examples Operation optimizer = graph.OperationByName("loss/Adam"); Tensor global_step = graph.OperationByName("Variable"); Tensor accuracy = graph.OperationByName("accuracy/accuracy"); - stopwatch = Stopwatch.StartNew(); + + var sw = new Stopwatch(); + sw.Start(); + int step = 0; foreach (var (x_batch, y_batch, total) in train_batches) { (_, step, loss_value) = sess.run((optimizer, global_step, loss), (model_x, x_batch), (model_y, y_batch), (is_training, true)); - if (step == 1 || step % 10 == 0) - Console.WriteLine($"Training on batch {step}/{total} loss: {loss_value.ToString("0.0000")}."); + if (step % 10 == 0) + { + Console.WriteLine($"Training on batch {step}/{total} loss: {loss_value.ToString("0.0000")} {sw.ElapsedMilliseconds}ms."); + sw.Restart(); + } if (step % 100 == 0) { @@ -242,7 +247,7 @@ namespace TensorFlowNET.Examples var valid_accuracy = sum_accuracy / cnt; print($"\nValidation Accuracy = {valid_accuracy.ToString("P")}\n"); - + // Save model if (valid_accuracy > max_accuracy) { diff --git a/test/TensorFlowNET.UnitTest/ImageTest.cs b/test/TensorFlowNET.UnitTest/ImageTest.cs index 7f3f4e3a..dd0b8b38 100644 --- a/test/TensorFlowNET.UnitTest/ImageTest.cs +++ b/test/TensorFlowNET.UnitTest/ImageTest.cs @@ -24,6 +24,7 @@ namespace TensorFlowNET.UnitTest contents = tf.read_file(imgPath); } + [Ignore("")] [TestMethod] public void decode_image() {