diff --git a/src/TensorFlowHub/TensorFlowHub.csproj b/src/TensorFlowHub/TensorFlowHub.csproj
index 16e22183..45501233 100644
--- a/src/TensorFlowHub/TensorFlowHub.csproj
+++ b/src/TensorFlowHub/TensorFlowHub.csproj
@@ -17,6 +17,6 @@
https://avatars3.githubusercontent.com/u/44989469?s=200&v=4
-
+
\ No newline at end of file
diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs
index f8493c40..f2402a6f 100644
--- a/src/TensorFlowNET.Core/APIs/tf.array.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.array.cs
@@ -22,6 +22,11 @@ namespace Tensorflow
{
public partial class tensorflow
{
+ ///
+ /// A convenient alias for None, useful for indexing arrays.
+ ///
+ public string newaxis = "";
+
///
/// Concatenates tensors along one dimension.
///
diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs
index ddfa71ec..ec081cc4 100644
--- a/src/TensorFlowNET.Core/APIs/tf.math.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.math.cs
@@ -315,6 +315,9 @@ namespace Tensorflow
public Tensor pow(T1 x, T2 y)
=> gen_math_ops.pow(x, y);
+ public Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range")
+ => math_ops.range(start, limit: limit, delta: delta, dtype: dtype, name: name);
+
///
/// Computes the sum of elements across dimensions of a tensor.
///
@@ -325,10 +328,12 @@ namespace Tensorflow
{
if(!axis.HasValue && reduction_indices.HasValue)
return math_ops.reduce_sum(input, reduction_indices.Value);
+ else if (axis.HasValue && !reduction_indices.HasValue)
+ return math_ops.reduce_sum(input, axis.Value);
return math_ops.reduce_sum(input);
}
- public Tensor reduce_sum(Tensor input, int axis, int? reduction_indices = null)
+ public Tensor reduce_sum(Tensor input, int[] axis, int? reduction_indices = null)
=> math_ops.reduce_sum(input, axis);
///
diff --git a/src/TensorFlowNET.Core/APIs/tf.reshape.cs b/src/TensorFlowNET.Core/APIs/tf.reshape.cs
index 78a00432..b6924709 100644
--- a/src/TensorFlowNET.Core/APIs/tf.reshape.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.reshape.cs
@@ -18,8 +18,8 @@ namespace Tensorflow
{
public partial class tensorflow
{
- public Tensor reshape(Tensor tensor,
- Tensor shape,
+ public Tensor reshape(T1 tensor,
+ T2 shape,
string name = null) => gen_array_ops.reshape(tensor, shape, name);
public Tensor reshape(Tensor tensor,
diff --git a/src/TensorFlowNET.Core/APIs/tf.tile.cs b/src/TensorFlowNET.Core/APIs/tf.tile.cs
index 21017a17..0995dc27 100644
--- a/src/TensorFlowNET.Core/APIs/tf.tile.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.tile.cs
@@ -20,12 +20,8 @@ namespace Tensorflow
{
public partial class tensorflow
{
- public Tensor tile(Tensor input,
- Tensor multiples,
+ public Tensor tile(Tensor input,
+ T multiples,
string name = null) => gen_array_ops.tile(input, multiples, name);
- public Tensor tile(NDArray input,
- int[] multiples,
- string name = null) => gen_array_ops.tile(input, multiples, name);
-
}
}
diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
index 61fa956b..ea020599 100644
--- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
@@ -224,7 +224,7 @@ namespace Tensorflow
public static Tensor reshape(T1 tensor, T2 shape, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Reshape", name, new { tensor, shape });
- return _op.outputs[0];
+ return _op.output;
}
public static Tensor reshape(Tensor tensor, int[] shape, string name = null)
@@ -334,12 +334,7 @@ namespace Tensorflow
return _op.outputs;
}
- public static Tensor tile(Tensor input, Tensor multiples, string name = null)
- {
- var _op = _op_def_lib._apply_op_helper("Tile", name, new { input, multiples });
- return _op.outputs[0];
- }
- public static Tensor tile(NDArray input, int[] multiples, string name = null)
+ public static Tensor tile(Tensor input, T multiples, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Tile", name, new { input, multiples });
return _op.outputs[0];
diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs
index f5cfdb37..2f8accd7 100644
--- a/src/TensorFlowNET.Core/Operations/math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/math_ops.cs
@@ -422,6 +422,13 @@ namespace Tensorflow
return _may_reduce_to_scalar(keepdims, axis, m);
}
+ public static Tensor reduce_sum(Tensor input_tensor, int[] axis, bool keepdims = false, string name = null)
+ {
+ var r = _ReductionDims(input_tensor, axis);
+ var m = gen_math_ops._sum(input_tensor, r, keep_dims: keepdims, name: name);
+ return _may_reduce_to_scalar(keepdims, axis, m);
+ }
+
public static Tensor reduce_sum(Tensor input_tensor, int axis, bool keepdims = false, string name = null)
{
var m = gen_math_ops._sum(input_tensor, axis, keep_dims: keepdims, name: name);
@@ -492,7 +499,7 @@ namespace Tensorflow
public static Tensor rsqrt(Tensor x, string name = null)
=> gen_math_ops.rsqrt(x, name: name);
- public static Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range" )
+ public static Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range")
{
if(limit == null)
{
diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
index 7ac5063c..9eae0cd9 100644
--- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
+++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
@@ -62,7 +62,7 @@ Docs: https://tensorflownet.readthedocs.io
-
+
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs
index 6632550f..552bddfd 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs
@@ -31,7 +31,7 @@ namespace Tensorflow
{
get
{
- var slice_spec = slices.Select(x => x == null ? null : new Slice(x)).ToArray();
+ var slice_spec = slices.Select(x => new Slice(x)).ToArray();
var begin = new List();
var end = new List();
var strides = new List();
@@ -43,13 +43,20 @@ namespace Tensorflow
foreach (var s in slice_spec)
{
- if(s == null)
+ if(s.IsNewAxis)
{
begin.Add(0);
end.Add(0);
- strides.Add(0);
+ strides.Add(1);
new_axis_mask |= (1 << index);
}
+ else if (s.IsEllipsis)
+ {
+ begin.Add(0);
+ end.Add(0);
+ strides.Add(1);
+ ellipsis_mask |= (1 << index);
+ }
else
{
if (s.Start.HasValue)
diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
index 8c2f543e..75d5ad07 100644
--- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs
+++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
@@ -143,6 +143,14 @@ namespace Tensorflow
}
}
+ public TensorShape merge_with(TensorShape other)
+ {
+ if (dims.Length == 0)
+ return other;
+
+ throw new NotImplementedException("merge_with");
+ }
+
public override string ToString()
{
return shape.ToString();
diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs
index 4708730b..aadf3b08 100644
--- a/src/TensorFlowNET.Core/ops.cs
+++ b/src/TensorFlowNET.Core/ops.cs
@@ -505,6 +505,8 @@ namespace Tensorflow
return varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref);
case ResourceVariable varVal:
return null;
+ case TensorShape ts:
+ return constant_op.constant(ts.dims, dtype: dtype, name: name);
case object[] objects:
return array_ops._autopacking_conversion_function(objects, dtype: dtype, name: name);
default:
diff --git a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs
index a68b58ef..b0ebf446 100644
--- a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs
+++ b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs
@@ -16,6 +16,7 @@
using NumSharp;
using System;
+using System.Diagnostics;
using Tensorflow;
using Tensorflow.Hub;
using static Tensorflow.Binding;
@@ -135,6 +136,9 @@ namespace TensorFlowNET.Examples
float loss_val = 100.0f;
float accuracy_val = 0f;
+ var sw = new Stopwatch();
+ sw.Start();
+
foreach (var epoch in range(epochs))
{
print($"Training epoch: {epoch + 1}");
@@ -154,7 +158,8 @@ namespace TensorFlowNET.Examples
{
// Calculate and display the batch loss and accuracy
(loss_val, accuracy_val) = sess.run((loss, accuracy), (x, x_batch), (y, y_batch));
- print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")}");
+ print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")} {sw.ElapsedMilliseconds}ms");
+ sw.Restart();
}
}
diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs
index c3201f8c..f10ac7d1 100644
--- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs
+++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs
@@ -45,6 +45,8 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
Session sess;
YOLOv3 model;
+ VariableV1[] net_var;
+ Tensor giou_loss, conf_loss, prob_loss;
#endregion
public bool Run()
@@ -92,6 +94,10 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
tf_with(tf.name_scope("define_loss"), scope =>
{
model = new YOLOv3(cfg, input_data, trainable);
+ net_var = tf.global_variables();
+ (giou_loss, conf_loss, prob_loss) = model.compute_loss(
+ label_sbbox, label_mbbox, label_lbbox,
+ true_sbboxes, true_mbboxes, true_lbboxes);
});
tf_with(tf.name_scope("define_weight_decay"), scope =>
diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs
index a197aef4..1cff167f 100644
--- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs
+++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs
@@ -22,6 +22,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
Tensor conv_lbbox;
Tensor conv_mbbox;
Tensor conv_sbbox;
+ Tensor pred_sbbox;
public YOLOv3(Config cfg_, Tensor input_data_, Tensor trainable_)
{
@@ -40,17 +41,17 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
tf_with(tf.variable_scope("pred_sbbox"), scope =>
{
- // pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]);
+ pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]);
});
tf_with(tf.variable_scope("pred_mbbox"), scope =>
{
- // pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]);
+ pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]);
});
tf_with(tf.variable_scope("pred_lbbox"), scope =>
{
- // pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]);
+ pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]);
});
}
@@ -71,7 +72,189 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
input_data = common.convolutional(input_data, new[] { 1, 1, 512, 256 }, trainable, "conv57");
input_data = common.upsample(input_data, name: "upsample0", method: upsample_method);
+ tf_with(tf.variable_scope("route_1"), delegate
+ {
+ input_data = tf.concat(new[] { input_data, route_2 }, axis: -1);
+ });
+
+ input_data = common.convolutional(input_data, new[] { 1, 1, 768, 256 }, trainable, "conv58");
+ input_data = common.convolutional(input_data, new[] { 3, 3, 256, 512 }, trainable, "conv59");
+ input_data = common.convolutional(input_data, new[] { 1, 1, 512, 256 }, trainable, "conv60");
+ input_data = common.convolutional(input_data, new[] { 3, 3, 256, 512 }, trainable, "conv61");
+ input_data = common.convolutional(input_data, new[] { 1, 1, 512, 256 }, trainable, "conv62");
+
+ var conv_mobj_branch = common.convolutional(input_data, new[] { 3, 3, 256, 512 }, trainable, name: "conv_mobj_branch");
+ conv_mbbox = common.convolutional(conv_mobj_branch, new[] { 1, 1, 512, 3 * (num_class + 5) },
+ trainable: trainable, name: "conv_mbbox", activate: false, bn: false);
+
+ input_data = common.convolutional(input_data, new[] { 1, 1, 256, 128 }, trainable, "conv63");
+ input_data = common.upsample(input_data, name: "upsample1", method: upsample_method);
+
+ tf_with(tf.variable_scope("route_2"), delegate
+ {
+ input_data = tf.concat(new[] { input_data, route_1 }, axis: -1);
+ });
+
+ input_data = common.convolutional(input_data, new[] { 1, 1, 384, 128 }, trainable, "conv64");
+ input_data = common.convolutional(input_data, new[] { 3, 3, 128, 256 }, trainable, "conv65");
+ input_data = common.convolutional(input_data, new[] { 1, 1, 256, 128 }, trainable, "conv66");
+ input_data = common.convolutional(input_data, new[] { 3, 3, 128, 256 }, trainable, "conv67");
+ input_data = common.convolutional(input_data, new[] { 1, 1, 256, 128 }, trainable, "conv68");
+
+ var conv_sobj_branch = common.convolutional(input_data, new[] { 3, 3, 128, 256 }, trainable, name: "conv_sobj_branch");
+ conv_sbbox = common.convolutional(conv_sobj_branch, new[] { 1, 1, 256, 3 * (num_class + 5) },
+ trainable: trainable, name: "conv_sbbox", activate: false, bn: false);
+
return (conv_lbbox, conv_mbbox, conv_sbbox);
}
+
+ private Tensor decode(Tensor conv_output, NDArray anchors, int stride)
+ {
+ var conv_shape = tf.shape(conv_output);
+ var batch_size = conv_shape[0];
+ var output_size = conv_shape[1];
+ anchor_per_scale = len(anchors);
+
+ conv_output = tf.reshape(conv_output, new object[] { batch_size, output_size, output_size, anchor_per_scale, 5 + num_class });
+
+ var conv_raw_dxdy = conv_output[":", ":", ":", ":", "0:2"];
+ var conv_raw_dwdh = conv_output[":", ":", ":", ":", "2:4"];
+ var conv_raw_conf = conv_output[":", ":", ":", ":", "4:5"];
+ var conv_raw_prob = conv_output[":", ":", ":", ":", "5:"];
+
+ var y = tf.tile(tf.range(output_size, dtype: tf.int32)[":", tf.newaxis], new object[] { 1, output_size });
+ var x = tf.tile(tf.range(output_size, dtype: tf.int32)[tf.newaxis, ":"], new object[] { output_size, 1 });
+
+ var xy_grid = tf.concat(new[] { x[":", ":", tf.newaxis], y[":", ":", tf.newaxis] }, axis: -1);
+ xy_grid = tf.tile(xy_grid[tf.newaxis, ":", ":", tf.newaxis, ":"], new object[] { batch_size, 1, 1, anchor_per_scale, 1 });
+ xy_grid = tf.cast(xy_grid, tf.float32);
+
+ var pred_xy = (tf.sigmoid(conv_raw_dxdy) + xy_grid) * stride;
+ var pred_wh = (tf.exp(conv_raw_dwdh) * anchors) * stride;
+ var pred_xywh = tf.concat(new[] { pred_xy, pred_wh }, axis: -1);
+
+ var pred_conf = tf.sigmoid(conv_raw_conf);
+ var pred_prob = tf.sigmoid(conv_raw_prob);
+
+ return tf.concat(new[] { pred_xywh, pred_conf, pred_prob }, axis: -1);
+ }
+
+ public (Tensor, Tensor, Tensor) compute_loss(Tensor label_sbbox, Tensor label_mbbox, Tensor label_lbbox,
+ Tensor true_sbbox, Tensor true_mbbox, Tensor true_lbbox)
+ {
+ Tensor giou_loss = null, conf_loss = null, prob_loss = null;
+ (Tensor, Tensor, Tensor) loss_sbbox = (null, null, null);
+
+ tf_with(tf.name_scope("smaller_box_loss"), delegate
+ {
+ loss_sbbox = loss_layer(conv_sbbox, pred_sbbox, label_sbbox, true_sbbox,
+ anchors: anchors[0], stride: strides[0]);
+ });
+
+ return (giou_loss, conf_loss, prob_loss);
+ }
+
+ public (Tensor, Tensor, Tensor) loss_layer(Tensor conv, Tensor pred, Tensor label, Tensor bboxes, NDArray anchors, int stride)
+ {
+ var conv_shape = tf.shape(conv);
+ var batch_size = conv_shape[0];
+ var output_size = conv_shape[1];
+ var input_size = stride * output_size;
+ conv = tf.reshape(conv, new object[] {batch_size, output_size, output_size,
+ anchor_per_scale, 5 + num_class });
+ var conv_raw_conf = conv[":", ":", ":", ":", "4:5"];
+ var conv_raw_prob = conv[":", ":", ":", ":", "5:"];
+
+ var pred_xywh = pred[":", ":", ":", ":", "0:4"];
+ var pred_conf = pred[":", ":", ":", ":", "4:5"];
+
+ var label_xywh = label[":", ":", ":", ":", "0:4"];
+ var respond_bbox = label[":", ":", ":", ":", "4:5"];
+ var label_prob = label[":", ":", ":", ":", "5:"];
+
+ var giou = tf.expand_dims(bbox_giou(pred_xywh, label_xywh), axis: -1);
+ input_size = tf.cast(input_size, tf.float32);
+
+ var bbox_loss_scale = 2.0 - 1.0 * label_xywh[":", ":", ":", ":", "2:3"] * label_xywh[":", ":", ":", ":", "3:4"] / (tf.sqrt(input_size));
+ var giou_loss = respond_bbox * bbox_loss_scale * (1 - giou);
+
+ var iou = bbox_iou(pred_xywh[":", ":", ":", ":", tf.newaxis, ":"], bboxes[":", tf.newaxis, tf.newaxis, tf.newaxis, ":", ":"]);
+ var max_iou = tf.expand_dims(tf.reduce_max(iou, axis: new[] { -1 }), axis: -1);
+
+ var respond_bgd = (1.0 - respond_bbox) * tf.cast(max_iou < iou_loss_thresh, tf.float32);
+
+ var conf_focal = focal(respond_bbox, pred_conf);
+
+ var conf_loss = conf_focal * (
+ respond_bbox * tf.nn.sigmoid_cross_entropy_with_logits(labels: respond_bbox, logits: conv_raw_conf) +
+ respond_bgd * tf.nn.sigmoid_cross_entropy_with_logits(labels: respond_bbox, logits: conv_raw_conf));
+
+ var prob_loss = respond_bbox * tf.nn.sigmoid_cross_entropy_with_logits(labels: label_prob, logits: conv_raw_prob);
+
+ giou_loss = tf.reduce_mean(tf.reduce_sum(giou_loss, axis: new[] { 1, 2, 3, 4 }));
+ conf_loss = tf.reduce_mean(tf.reduce_sum(conf_loss, axis: new[] { 1, 2, 3, 4 }));
+ prob_loss = tf.reduce_mean(tf.reduce_sum(prob_loss, axis: new[] { 1, 2, 3, 4 }));
+
+ return (giou_loss, conf_loss, prob_loss);
+ }
+
+ public Tensor focal(Tensor target, Tensor actual, int alpha = 1, int gamma = 2)
+ {
+ var focal_loss = alpha * tf.pow(tf.abs(target - actual), gamma);
+ return focal_loss;
+ }
+
+ public Tensor bbox_giou(Tensor boxes1, Tensor boxes2)
+ {
+ boxes1 = tf.concat(new[] { boxes1["...", ":2"] - boxes1["...", "2:"] * 0.5,
+ boxes1["...", ":2"] + boxes1["...", "2:"] * 0.5}, axis: -1);
+ boxes2 = tf.concat(new[] { boxes2["...", ":2"] - boxes2["...", "2:"] * 0.5,
+ boxes2["...", ":2"] + boxes2["...", "2:"] * 0.5}, axis: -1);
+
+ boxes1 = tf.concat(new[] { tf.minimum(boxes1["...", ":2"], boxes1["...", "2:"]),
+ tf.maximum(boxes1["...", ":2"], boxes1["...", "2:"])}, axis: -1);
+ boxes2 = tf.concat(new[] { tf.minimum(boxes2["...", ":2"], boxes2["...", "2:"]),
+ tf.maximum(boxes2["...", ":2"], boxes2["...", "2:"])}, axis: -1);
+
+ var boxes1_area = (boxes1["...", "2"] - boxes1["...", "0"]) * (boxes1["...", "3"] - boxes1["...", "1"]);
+ var boxes2_area = (boxes2["...", "2"] - boxes2["...", "0"]) * (boxes2["...", "3"] - boxes2["...", "1"]);
+
+ var left_up = tf.maximum(boxes1["...", ":2"], boxes2["...", ":2"]);
+ var right_down = tf.minimum(boxes1["...", "2:"], boxes2["...", "2:"]);
+
+ var inter_section = tf.maximum(right_down - left_up, 0.0f);
+ var inter_area = inter_section["...", "0"] * inter_section["...", "1"];
+ var union_area = boxes1_area + boxes2_area - inter_area;
+ var iou = inter_area / union_area;
+
+ var enclose_left_up = tf.minimum(boxes1["...", ":2"], boxes2["...", ":2"]);
+ var enclose_right_down = tf.maximum(boxes1["...", "2:"], boxes2["...", "2:"]);
+ var enclose = tf.maximum(enclose_right_down - enclose_left_up, 0.0);
+ var enclose_area = enclose["...", "0"] * enclose["...", "1"];
+ var giou = iou - 1.0 * (enclose_area - union_area) / enclose_area;
+
+ return giou;
+ }
+
+ public Tensor bbox_iou(Tensor boxes1, Tensor boxes2)
+ {
+ var boxes1_area = boxes1["...", "2"] * boxes1["...", "3"];
+ var boxes2_area = boxes2["...", "2"] * boxes2["...", "3"];
+
+ boxes1 = tf.concat(new[] { boxes1["...", ":2"] - boxes1["...", "2:"] * 0.5,
+ boxes1["...", ":2"] + boxes1["...", "2:"] * 0.5}, axis: -1);
+ boxes2 = tf.concat(new[] { boxes2["...", ":2"] - boxes2["...", "2:"] * 0.5,
+ boxes2["...", ":2"] + boxes2["...", "2:"] * 0.5}, axis: -1);
+
+ var left_up = tf.maximum(boxes1["...", ":2"], boxes2["...", ":2"]);
+ var right_down = tf.minimum(boxes1["...", "2:"], boxes2["...", "2:"]);
+
+ var inter_section = tf.maximum(right_down - left_up, 0.0);
+ var inter_area = inter_section["...", "0"] * inter_section["...", "1"];
+ var union_area = boxes1_area + boxes2_area - inter_area;
+ var iou = 1.0 * inter_area / union_area;
+
+ return iou;
+ }
}
}
diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs
index 375d68a0..06a261ce 100644
--- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs
+++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs
@@ -68,6 +68,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
tf_with(tf.variable_scope(name), delegate
{
var input_shape = tf.shape(input_data);
+ output = tf.image.resize_nearest_neighbor(input_data, new Tensor[] { input_shape[1] * 2, input_shape[2] * 2 });
});
}
else if(method == "deconv")
diff --git a/test/TensorFlowNET.Examples/TextProcessing/CnnTextClassification.cs b/test/TensorFlowNET.Examples/TextProcessing/CnnTextClassification.cs
index 3519d972..23d1bec4 100644
--- a/test/TensorFlowNET.Examples/TextProcessing/CnnTextClassification.cs
+++ b/test/TensorFlowNET.Examples/TextProcessing/CnnTextClassification.cs
@@ -197,7 +197,6 @@ namespace TensorFlowNET.Examples
public void Train(Session sess)
{
var graph = tf.get_default_graph();
- var stopwatch = Stopwatch.StartNew();
sess.run(tf.global_variables_initializer());
var saver = tf.train.Saver(tf.global_variables());
@@ -212,14 +211,20 @@ namespace TensorFlowNET.Examples
Operation optimizer = graph.OperationByName("loss/Adam");
Tensor global_step = graph.OperationByName("Variable");
Tensor accuracy = graph.OperationByName("accuracy/accuracy");
- stopwatch = Stopwatch.StartNew();
+
+ var sw = new Stopwatch();
+ sw.Start();
+
int step = 0;
foreach (var (x_batch, y_batch, total) in train_batches)
{
(_, step, loss_value) = sess.run((optimizer, global_step, loss),
(model_x, x_batch), (model_y, y_batch), (is_training, true));
- if (step == 1 || step % 10 == 0)
- Console.WriteLine($"Training on batch {step}/{total} loss: {loss_value.ToString("0.0000")}.");
+ if (step % 10 == 0)
+ {
+ Console.WriteLine($"Training on batch {step}/{total} loss: {loss_value.ToString("0.0000")} {sw.ElapsedMilliseconds}ms.");
+ sw.Restart();
+ }
if (step % 100 == 0)
{
@@ -242,7 +247,7 @@ namespace TensorFlowNET.Examples
var valid_accuracy = sum_accuracy / cnt;
print($"\nValidation Accuracy = {valid_accuracy.ToString("P")}\n");
-
+
// Save model
if (valid_accuracy > max_accuracy)
{
diff --git a/test/TensorFlowNET.UnitTest/ImageTest.cs b/test/TensorFlowNET.UnitTest/ImageTest.cs
index 7f3f4e3a..dd0b8b38 100644
--- a/test/TensorFlowNET.UnitTest/ImageTest.cs
+++ b/test/TensorFlowNET.UnitTest/ImageTest.cs
@@ -24,6 +24,7 @@ namespace TensorFlowNET.UnitTest
contents = tf.read_file(imgPath);
}
+ [Ignore("")]
[TestMethod]
public void decode_image()
{