
yolo.build_nework() #359

tags/v0.12
Oceania2018 committed 6 years ago
commit f1b7e1182a
17 changed files with 257 additions and 31 deletions
1. +1 -1    src/TensorFlowHub/TensorFlowHub.csproj
2. +5 -0    src/TensorFlowNET.Core/APIs/tf.array.cs
3. +6 -1    src/TensorFlowNET.Core/APIs/tf.math.cs
4. +2 -2    src/TensorFlowNET.Core/APIs/tf.reshape.cs
5. +2 -6    src/TensorFlowNET.Core/APIs/tf.tile.cs
6. +2 -7    src/TensorFlowNET.Core/Operations/gen_array_ops.cs
7. +8 -1    src/TensorFlowNET.Core/Operations/math_ops.cs
8. +1 -1    src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
9. +10 -3   src/TensorFlowNET.Core/Tensors/Tensor.Index.cs
10. +8 -0   src/TensorFlowNET.Core/Tensors/TensorShape.cs
11. +2 -0   src/TensorFlowNET.Core/ops.cs
12. +6 -1   test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs
13. +6 -0   test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs
14. +186 -3 test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs
15. +1 -0   test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs
16. +10 -5  test/TensorFlowNET.Examples/TextProcessing/CnnTextClassification.cs
17. +1 -0   test/TensorFlowNET.UnitTest/ImageTest.cs

+1 -1  src/TensorFlowHub/TensorFlowHub.csproj

@@ -17,6 +17,6 @@
<PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&amp;v=4</PackageIconUrl>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="NumSharp" Version="0.20.0" />
<PackageReference Include="NumSharp" Version="0.20.1" />
</ItemGroup>
</Project>

+5 -0  src/TensorFlowNET.Core/APIs/tf.array.cs

@@ -22,6 +22,11 @@ namespace Tensorflow
{
public partial class tensorflow
{
/// <summary>
/// A convenient alias for None, useful for indexing arrays.
/// </summary>
public string newaxis = "";

/// <summary>
/// Concatenates tensors along one dimension.
/// </summary>

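The new tf.newaxis field is the empty string that the string-based tensor indexer treats as an "insert a dimension here" marker, mirroring Python's None. A minimal sketch of how the YOLO code below uses it; the NewAxisSketch wrapper and the commented shapes are illustrative assumptions, not part of this commit:

    using Tensorflow;
    using static Tensorflow.Binding;

    static class NewAxisSketch
    {
        public static void Run()
        {
            // 1-D tensor [0, 1, 2, 3]
            var v = tf.range(0, limit: 4, dtype: tf.int32);

            // ":" keeps the existing axis, tf.newaxis inserts a new one:
            // row should come out with shape (1, 4) and col with shape (4, 1)
            var row = v[tf.newaxis, ":"];
            var col = v[":", tf.newaxis];
        }
    }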

+6 -1  src/TensorFlowNET.Core/APIs/tf.math.cs

@@ -315,6 +315,9 @@ namespace Tensorflow
public Tensor pow<T1, T2>(T1 x, T2 y)
=> gen_math_ops.pow(x, y);

public Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range")
=> math_ops.range(start, limit: limit, delta: delta, dtype: dtype, name: name);

/// <summary>
/// Computes the sum of elements across dimensions of a tensor.
/// </summary>
@@ -325,10 +328,12 @@ namespace Tensorflow
{
if(!axis.HasValue && reduction_indices.HasValue)
return math_ops.reduce_sum(input, reduction_indices.Value);
else if (axis.HasValue && !reduction_indices.HasValue)
return math_ops.reduce_sum(input, axis.Value);
return math_ops.reduce_sum(input);
}

public Tensor reduce_sum(Tensor input, int axis, int? reduction_indices = null)
public Tensor reduce_sum(Tensor input, int[] axis, int? reduction_indices = null)
=> math_ops.reduce_sum(input, axis);

/// <summary>

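tf.range is now surfaced on the public API and reduce_sum gains an int[]-axis overload; the YOLO loss below relies on it to reduce over axes { 1, 2, 3, 4 }. A rough usage sketch built only from calls added or touched in this commit (wrapper name and commented values are assumptions):

    using Tensorflow;
    using static Tensorflow.Binding;

    static class ReduceSumSketch
    {
        public static void Run()
        {
            var idx = tf.range(0, limit: 10, delta: 2, dtype: tf.int32);   // [0, 2, 4, 6, 8]

            var m = tf.reshape(tf.range(0, limit: 6, dtype: tf.int32), new[] { 2, 3 }); // [[0,1,2],[3,4,5]]
            var total = tf.reduce_sum(m, new[] { 0, 1 });                  // scalar 15
            var rows  = tf.reduce_sum(m, new[] { 1 });                     // [3, 12]
        }
    }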

+2 -2  src/TensorFlowNET.Core/APIs/tf.reshape.cs

@@ -18,8 +18,8 @@ namespace Tensorflow
{
public partial class tensorflow
{
public Tensor reshape(Tensor tensor,
Tensor shape,
public Tensor reshape<T1, T2>(T1 tensor,
T2 shape,
string name = null) => gen_array_ops.reshape(tensor, shape, name);

public Tensor reshape(Tensor tensor,

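Making reshape generic over both arguments is what lets decode() further down pass a mixed object[] shape, a batch-size Tensor alongside int literals, and have the op helper pack it into a single shape tensor (see the new object[] case already in ops.cs). A sketch of that pattern; the DynamicReshapeSketch wrapper and the use of tf.constant for the dynamic dimension are assumptions:

    using Tensorflow;
    using static Tensorflow.Binding;

    static class DynamicReshapeSketch
    {
        public static void Run()
        {
            var flat = tf.range(0, limit: 12, dtype: tf.int32);

            // a dimension known only at run time, as a scalar Tensor
            var batch = tf.constant(3);

            // mixed shape: Tensor + int literals, packed into one shape tensor by the op helper
            var grid = tf.reshape(flat, new object[] { batch, 2, 2 });
        }
    }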

+2 -6  src/TensorFlowNET.Core/APIs/tf.tile.cs

@@ -20,12 +20,8 @@ namespace Tensorflow
{
public partial class tensorflow
{
public Tensor tile(Tensor input,
Tensor multiples,
public Tensor tile<T>(Tensor input,
T multiples,
string name = null) => gen_array_ops.tile(input, multiples, name);
public Tensor tile(NDArray input,
int[] multiples,
string name = null) => gen_array_ops.tile(input, multiples, name);

}
}

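tile gets the same treatment: one generic overload instead of separate Tensor/NDArray ones, so the multiples argument can be an object[] mixing a runtime Tensor with int constants, exactly as decode() does with batch_size. A minimal sketch (wrapper name is illustrative):

    using Tensorflow;
    using static Tensorflow.Binding;

    static class TileSketch
    {
        public static void Run()
        {
            var row  = tf.range(0, limit: 4, dtype: tf.int32)[tf.newaxis, ":"]; // shape (1, 4)
            var reps = tf.constant(3);                                          // runtime repeat count

            // repeat the row 'reps' times along axis 0, once along axis 1
            var grid = tf.tile(row, new object[] { reps, 1 });
        }
    }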
+2 -7  src/TensorFlowNET.Core/Operations/gen_array_ops.cs

@@ -224,7 +224,7 @@ namespace Tensorflow
public static Tensor reshape<T1, T2>(T1 tensor, T2 shape, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Reshape", name, new { tensor, shape });
return _op.outputs[0];
return _op.output;
}

public static Tensor reshape(Tensor tensor, int[] shape, string name = null)
@@ -334,12 +334,7 @@ namespace Tensorflow
return _op.outputs;
}

public static Tensor tile(Tensor input, Tensor multiples, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Tile", name, new { input, multiples });
return _op.outputs[0];
}
public static Tensor tile(NDArray input, int[] multiples, string name = null)
public static Tensor tile<T>(Tensor input, T multiples, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Tile", name, new { input, multiples });
return _op.outputs[0];


+8 -1  src/TensorFlowNET.Core/Operations/math_ops.cs

@@ -422,6 +422,13 @@ namespace Tensorflow
return _may_reduce_to_scalar(keepdims, axis, m);
}

public static Tensor reduce_sum(Tensor input_tensor, int[] axis, bool keepdims = false, string name = null)
{
var r = _ReductionDims(input_tensor, axis);
var m = gen_math_ops._sum(input_tensor, r, keep_dims: keepdims, name: name);
return _may_reduce_to_scalar(keepdims, axis, m);
}

public static Tensor reduce_sum(Tensor input_tensor, int axis, bool keepdims = false, string name = null)
{
var m = gen_math_ops._sum(input_tensor, axis, keep_dims: keepdims, name: name);
@@ -492,7 +499,7 @@ namespace Tensorflow
public static Tensor rsqrt(Tensor x, string name = null)
=> gen_math_ops.rsqrt(x, name: name);

public static Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range" )
public static Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range")
{
if(limit == null)
{


+1 -1  src/TensorFlowNET.Core/TensorFlowNET.Core.csproj

@@ -62,7 +62,7 @@ Docs: https://tensorflownet.readthedocs.io</Description>

<ItemGroup>
<PackageReference Include="Google.Protobuf" Version="3.5.1" />
<PackageReference Include="NumSharp" Version="0.20.0" />
<PackageReference Include="NumSharp" Version="0.20.1" />
</ItemGroup>

<ItemGroup>


+10 -3  src/TensorFlowNET.Core/Tensors/Tensor.Index.cs

@@ -31,7 +31,7 @@ namespace Tensorflow
{
get
{
var slice_spec = slices.Select(x => x == null ? null : new Slice(x)).ToArray();
var slice_spec = slices.Select(x => new Slice(x)).ToArray();
var begin = new List<int>();
var end = new List<int>();
var strides = new List<int>();
@@ -43,13 +43,20 @@ namespace Tensorflow

foreach (var s in slice_spec)
{
if(s == null)
if(s.IsNewAxis)
{
begin.Add(0);
end.Add(0);
strides.Add(0);
strides.Add(1);
new_axis_mask |= (1 << index);
}
else if (s.IsEllipsis)
{
begin.Add(0);
end.Add(0);
strides.Add(1);
ellipsis_mask |= (1 << index);
}
else
{
if (s.Start.HasValue)

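With tf.newaxis mapped to new_axis_mask and "..." to ellipsis_mask, the string slice specs used throughout YOLOv3.cs ("0:2", "...", "5:", tf.newaxis) all go through this indexer. A small sketch of the resulting behaviour; the wrapper and the shapes in the comments are expectations, not verified against this build:

    using Tensorflow;
    using static Tensorflow.Binding;

    static class SliceSketch
    {
        public static void Run()
        {
            var m = tf.reshape(tf.range(0, limit: 6, dtype: tf.int32), new[] { 2, 3 });

            var firstTwoCols = m[":", "0:2"];           // expected shape (2, 2)
            var tailCols     = m["...", "1:"];          // ellipsis keeps leading axes, (2, 2)
            var expanded     = m[":", tf.newaxis, ":"]; // inserted axis, (2, 1, 3)
        }
    }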

+8 -0  src/TensorFlowNET.Core/Tensors/TensorShape.cs

@@ -143,6 +143,14 @@ namespace Tensorflow
}
}

public TensorShape merge_with(TensorShape other)
{
if (dims.Length == 0)
return other;

throw new NotImplementedException("merge_with");
}

public override string ToString()
{
return shape.ToString();

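merge_with currently handles only the case where this shape has no known dims (fully unknown), returning the other shape unchanged; every other combination still throws. A sketch of the contract as committed; the TensorShape constructors used here are assumptions about the existing class:

    using Tensorflow;

    static class MergeWithSketch
    {
        public static void Run()
        {
            var unknown = new TensorShape(new int[0]); // no known dimensions
            var known   = new TensorShape(2, 3);       // fully defined

            var merged = unknown.merge_with(known);    // returns 'known' per the new code

            // known.merge_with(unknown) would currently throw NotImplementedException("merge_with")
        }
    }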

+2 -0  src/TensorFlowNET.Core/ops.cs

@@ -505,6 +505,8 @@ namespace Tensorflow
return varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref);
case ResourceVariable varVal:
return null;
case TensorShape ts:
return constant_op.constant(ts.dims, dtype: dtype, name: name);
case object[] objects:
return array_ops._autopacking_conversion_function(objects, dtype: dtype, name: name);
default:

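The new TensorShape case means a shape object can be passed anywhere an op input is auto-converted to a tensor; it is materialised as a constant holding its dims. Combined with the generic reshape above, that allows reshaping by a statically known shape. A hedged sketch, with the same TensorShape constructor assumption as before:

    using Tensorflow;
    using static Tensorflow.Binding;

    static class ShapeToTensorSketch
    {
        public static void Run()
        {
            var x = tf.range(0, limit: 6, dtype: tf.int32);
            var target = new TensorShape(2, 3);

            // TensorShape -> constant_op.constant(ts.dims) inside convert_to_tensor,
            // so it can be fed straight into the Reshape op
            var y = tf.reshape(x, target);
        }
    }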

+6 -1  test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs

@@ -16,6 +16,7 @@

using NumSharp;
using System;
using System.Diagnostics;
using Tensorflow;
using Tensorflow.Hub;
using static Tensorflow.Binding;
@@ -135,6 +136,9 @@ namespace TensorFlowNET.Examples
float loss_val = 100.0f;
float accuracy_val = 0f;

var sw = new Stopwatch();
sw.Start();

foreach (var epoch in range(epochs))
{
print($"Training epoch: {epoch + 1}");
@@ -154,7 +158,8 @@ namespace TensorFlowNET.Examples
{
// Calculate and display the batch loss and accuracy
(loss_val, accuracy_val) = sess.run((loss, accuracy), (x, x_batch), (y, y_batch));
print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")}");
print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")} {sw.ElapsedMilliseconds}ms");
sw.Restart();
}
}



+6 -0  test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs

@@ -45,6 +45,8 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO

Session sess;
YOLOv3 model;
VariableV1[] net_var;
Tensor giou_loss, conf_loss, prob_loss;
#endregion

public bool Run()
@@ -92,6 +94,10 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
tf_with(tf.name_scope("define_loss"), scope =>
{
model = new YOLOv3(cfg, input_data, trainable);
net_var = tf.global_variables();
(giou_loss, conf_loss, prob_loss) = model.compute_loss(
label_sbbox, label_mbbox, label_lbbox,
true_sbboxes, true_mbboxes, true_lbboxes);
});

tf_with(tf.name_scope("define_weight_decay"), scope =>


+186 -3  test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs

@@ -22,6 +22,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
Tensor conv_lbbox;
Tensor conv_mbbox;
Tensor conv_sbbox;
Tensor pred_sbbox;

public YOLOv3(Config cfg_, Tensor input_data_, Tensor trainable_)
{
@@ -40,17 +41,17 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO

tf_with(tf.variable_scope("pred_sbbox"), scope =>
{
// pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]);
pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]);
});

tf_with(tf.variable_scope("pred_mbbox"), scope =>
{
// pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]);
pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]);
});

tf_with(tf.variable_scope("pred_lbbox"), scope =>
{
// pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]);
pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]);
});
}

@@ -71,7 +72,189 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
input_data = common.convolutional(input_data, new[] { 1, 1, 512, 256 }, trainable, "conv57");
input_data = common.upsample(input_data, name: "upsample0", method: upsample_method);

tf_with(tf.variable_scope("route_1"), delegate
{
input_data = tf.concat(new[] { input_data, route_2 }, axis: -1);
});

input_data = common.convolutional(input_data, new[] { 1, 1, 768, 256 }, trainable, "conv58");
input_data = common.convolutional(input_data, new[] { 3, 3, 256, 512 }, trainable, "conv59");
input_data = common.convolutional(input_data, new[] { 1, 1, 512, 256 }, trainable, "conv60");
input_data = common.convolutional(input_data, new[] { 3, 3, 256, 512 }, trainable, "conv61");
input_data = common.convolutional(input_data, new[] { 1, 1, 512, 256 }, trainable, "conv62");

var conv_mobj_branch = common.convolutional(input_data, new[] { 3, 3, 256, 512 }, trainable, name: "conv_mobj_branch");
conv_mbbox = common.convolutional(conv_mobj_branch, new[] { 1, 1, 512, 3 * (num_class + 5) },
trainable: trainable, name: "conv_mbbox", activate: false, bn: false);

input_data = common.convolutional(input_data, new[] { 1, 1, 256, 128 }, trainable, "conv63");
input_data = common.upsample(input_data, name: "upsample1", method: upsample_method);

tf_with(tf.variable_scope("route_2"), delegate
{
input_data = tf.concat(new[] { input_data, route_1 }, axis: -1);
});

input_data = common.convolutional(input_data, new[] { 1, 1, 384, 128 }, trainable, "conv64");
input_data = common.convolutional(input_data, new[] { 3, 3, 128, 256 }, trainable, "conv65");
input_data = common.convolutional(input_data, new[] { 1, 1, 256, 128 }, trainable, "conv66");
input_data = common.convolutional(input_data, new[] { 3, 3, 128, 256 }, trainable, "conv67");
input_data = common.convolutional(input_data, new[] { 1, 1, 256, 128 }, trainable, "conv68");

var conv_sobj_branch = common.convolutional(input_data, new[] { 3, 3, 128, 256 }, trainable, name: "conv_sobj_branch");
conv_sbbox = common.convolutional(conv_sobj_branch, new[] { 1, 1, 256, 3 * (num_class + 5) },
trainable: trainable, name: "conv_sbbox", activate: false, bn: false);
return (conv_lbbox, conv_mbbox, conv_sbbox);
}

private Tensor decode(Tensor conv_output, NDArray anchors, int stride)
{
var conv_shape = tf.shape(conv_output);
var batch_size = conv_shape[0];
var output_size = conv_shape[1];
anchor_per_scale = len(anchors);

conv_output = tf.reshape(conv_output, new object[] { batch_size, output_size, output_size, anchor_per_scale, 5 + num_class });

var conv_raw_dxdy = conv_output[":", ":", ":", ":", "0:2"];
var conv_raw_dwdh = conv_output[":", ":", ":", ":", "2:4"];
var conv_raw_conf = conv_output[":", ":", ":", ":", "4:5"];
var conv_raw_prob = conv_output[":", ":", ":", ":", "5:"];

var y = tf.tile(tf.range(output_size, dtype: tf.int32)[":", tf.newaxis], new object[] { 1, output_size });
var x = tf.tile(tf.range(output_size, dtype: tf.int32)[tf.newaxis, ":"], new object[] { output_size, 1 });

var xy_grid = tf.concat(new[] { x[":", ":", tf.newaxis], y[":", ":", tf.newaxis] }, axis: -1);
xy_grid = tf.tile(xy_grid[tf.newaxis, ":", ":", tf.newaxis, ":"], new object[] { batch_size, 1, 1, anchor_per_scale, 1 });
xy_grid = tf.cast(xy_grid, tf.float32);

var pred_xy = (tf.sigmoid(conv_raw_dxdy) + xy_grid) * stride;
var pred_wh = (tf.exp(conv_raw_dwdh) * anchors) * stride;
var pred_xywh = tf.concat(new[] { pred_xy, pred_wh }, axis: -1);

var pred_conf = tf.sigmoid(conv_raw_conf);
var pred_prob = tf.sigmoid(conv_raw_prob);

return tf.concat(new[] { pred_xywh, pred_conf, pred_prob }, axis: -1);
}

public (Tensor, Tensor, Tensor) compute_loss(Tensor label_sbbox, Tensor label_mbbox, Tensor label_lbbox,
Tensor true_sbbox, Tensor true_mbbox, Tensor true_lbbox)
{
Tensor giou_loss = null, conf_loss = null, prob_loss = null;
(Tensor, Tensor, Tensor) loss_sbbox = (null, null, null);

tf_with(tf.name_scope("smaller_box_loss"), delegate
{
loss_sbbox = loss_layer(conv_sbbox, pred_sbbox, label_sbbox, true_sbbox,
anchors: anchors[0], stride: strides[0]);
});

return (giou_loss, conf_loss, prob_loss);
}

public (Tensor, Tensor, Tensor) loss_layer(Tensor conv, Tensor pred, Tensor label, Tensor bboxes, NDArray anchors, int stride)
{
var conv_shape = tf.shape(conv);
var batch_size = conv_shape[0];
var output_size = conv_shape[1];
var input_size = stride * output_size;
conv = tf.reshape(conv, new object[] {batch_size, output_size, output_size,
anchor_per_scale, 5 + num_class });
var conv_raw_conf = conv[":", ":", ":", ":", "4:5"];
var conv_raw_prob = conv[":", ":", ":", ":", "5:"];

var pred_xywh = pred[":", ":", ":", ":", "0:4"];
var pred_conf = pred[":", ":", ":", ":", "4:5"];

var label_xywh = label[":", ":", ":", ":", "0:4"];
var respond_bbox = label[":", ":", ":", ":", "4:5"];
var label_prob = label[":", ":", ":", ":", "5:"];

var giou = tf.expand_dims(bbox_giou(pred_xywh, label_xywh), axis: -1);
input_size = tf.cast(input_size, tf.float32);

var bbox_loss_scale = 2.0 - 1.0 * label_xywh[":", ":", ":", ":", "2:3"] * label_xywh[":", ":", ":", ":", "3:4"] / (tf.sqrt(input_size));
var giou_loss = respond_bbox * bbox_loss_scale * (1 - giou);

var iou = bbox_iou(pred_xywh[":", ":", ":", ":", tf.newaxis, ":"], bboxes[":", tf.newaxis, tf.newaxis, tf.newaxis, ":", ":"]);
var max_iou = tf.expand_dims(tf.reduce_max(iou, axis: new[] { -1 }), axis: -1);

var respond_bgd = (1.0 - respond_bbox) * tf.cast(max_iou < iou_loss_thresh, tf.float32);

var conf_focal = focal(respond_bbox, pred_conf);

var conf_loss = conf_focal * (
respond_bbox * tf.nn.sigmoid_cross_entropy_with_logits(labels: respond_bbox, logits: conv_raw_conf) +
respond_bgd * tf.nn.sigmoid_cross_entropy_with_logits(labels: respond_bbox, logits: conv_raw_conf));

var prob_loss = respond_bbox * tf.nn.sigmoid_cross_entropy_with_logits(labels: label_prob, logits: conv_raw_prob);

giou_loss = tf.reduce_mean(tf.reduce_sum(giou_loss, axis: new[] { 1, 2, 3, 4 }));
conf_loss = tf.reduce_mean(tf.reduce_sum(conf_loss, axis: new[] { 1, 2, 3, 4 }));
prob_loss = tf.reduce_mean(tf.reduce_sum(prob_loss, axis: new[] { 1, 2, 3, 4 }));

return (giou_loss, conf_loss, prob_loss);
}

public Tensor focal(Tensor target, Tensor actual, int alpha = 1, int gamma = 2)
{
var focal_loss = alpha * tf.pow(tf.abs(target - actual), gamma);
return focal_loss;
}

public Tensor bbox_giou(Tensor boxes1, Tensor boxes2)
{
boxes1 = tf.concat(new[] { boxes1["...", ":2"] - boxes1["...", "2:"] * 0.5,
boxes1["...", ":2"] + boxes1["...", "2:"] * 0.5}, axis: -1);
boxes2 = tf.concat(new[] { boxes2["...", ":2"] - boxes2["...", "2:"] * 0.5,
boxes2["...", ":2"] + boxes2["...", "2:"] * 0.5}, axis: -1);

boxes1 = tf.concat(new[] { tf.minimum(boxes1["...", ":2"], boxes1["...", "2:"]),
tf.maximum(boxes1["...", ":2"], boxes1["...", "2:"])}, axis: -1);
boxes2 = tf.concat(new[] { tf.minimum(boxes2["...", ":2"], boxes2["...", "2:"]),
tf.maximum(boxes2["...", ":2"], boxes2["...", "2:"])}, axis: -1);

var boxes1_area = (boxes1["...", "2"] - boxes1["...", "0"]) * (boxes1["...", "3"] - boxes1["...", "1"]);
var boxes2_area = (boxes2["...", "2"] - boxes2["...", "0"]) * (boxes2["...", "3"] - boxes2["...", "1"]);

var left_up = tf.maximum(boxes1["...", ":2"], boxes2["...", ":2"]);
var right_down = tf.minimum(boxes1["...", "2:"], boxes2["...", "2:"]);

var inter_section = tf.maximum(right_down - left_up, 0.0f);
var inter_area = inter_section["...", "0"] * inter_section["...", "1"];
var union_area = boxes1_area + boxes2_area - inter_area;
var iou = inter_area / union_area;

var enclose_left_up = tf.minimum(boxes1["...", ":2"], boxes2["...", ":2"]);
var enclose_right_down = tf.maximum(boxes1["...", "2:"], boxes2["...", "2:"]);
var enclose = tf.maximum(enclose_right_down - enclose_left_up, 0.0);
var enclose_area = enclose["...", "0"] * enclose["...", "1"];
var giou = iou - 1.0 * (enclose_area - union_area) / enclose_area;

return giou;
}

public Tensor bbox_iou(Tensor boxes1, Tensor boxes2)
{
var boxes1_area = boxes1["...", "2"] * boxes1["...", "3"];
var boxes2_area = boxes2["...", "2"] * boxes2["...", "3"];

boxes1 = tf.concat(new[] { boxes1["...", ":2"] - boxes1["...", "2:"] * 0.5,
boxes1["...", ":2"] + boxes1["...", "2:"] * 0.5}, axis: -1);
boxes2 = tf.concat(new[] { boxes2["...", ":2"] - boxes2["...", "2:"] * 0.5,
boxes2["...", ":2"] + boxes2["...", "2:"] * 0.5}, axis: -1);

var left_up = tf.maximum(boxes1["...", ":2"], boxes2["...", ":2"]);
var right_down = tf.minimum(boxes1["...", "2:"], boxes2["...", "2:"]);

var inter_section = tf.maximum(right_down - left_up, 0.0);
var inter_area = inter_section["...", "0"] * inter_section["...", "1"];
var union_area = boxes1_area + boxes2_area - inter_area;
var iou = 1.0 * inter_area / union_area;

return iou;
}
}
}

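For reference, bbox_giou above implements Generalized IoU: GIoU = IoU - (C - U) / C, where U is the union area of the two boxes and C is the area of the smallest box enclosing both. A plain C# restatement of the same arithmetic on two (x1, y1, x2, y2) boxes, independent of the tensor code in this commit:

    using System;

    static class GiouSketch
    {
        // boxes are (x1, y1, x2, y2) with x1 < x2 and y1 < y2
        public static float Giou(float[] a, float[] b)
        {
            float areaA = (a[2] - a[0]) * (a[3] - a[1]);
            float areaB = (b[2] - b[0]) * (b[3] - b[1]);

            // intersection and union
            float ix1 = Math.Max(a[0], b[0]), iy1 = Math.Max(a[1], b[1]);
            float ix2 = Math.Min(a[2], b[2]), iy2 = Math.Min(a[3], b[3]);
            float inter = Math.Max(ix2 - ix1, 0f) * Math.Max(iy2 - iy1, 0f);
            float union = areaA + areaB - inter;
            float iou = inter / union;

            // smallest enclosing box
            float ex1 = Math.Min(a[0], b[0]), ey1 = Math.Min(a[1], b[1]);
            float ex2 = Math.Max(a[2], b[2]), ey2 = Math.Max(a[3], b[3]);
            float enclose = (ex2 - ex1) * (ey2 - ey1);

            return iou - (enclose - union) / enclose;
        }
    }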
+1 -0  test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs

@@ -68,6 +68,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
tf_with(tf.variable_scope(name), delegate
{
var input_shape = tf.shape(input_data);
output = tf.image.resize_nearest_neighbor(input_data, new Tensor[] { input_shape[1] * 2, input_shape[2] * 2 });
});
}
else if(method == "deconv")


+10 -5  test/TensorFlowNET.Examples/TextProcessing/CnnTextClassification.cs

@@ -197,7 +197,6 @@ namespace TensorFlowNET.Examples
public void Train(Session sess)
{
var graph = tf.get_default_graph();
var stopwatch = Stopwatch.StartNew();

sess.run(tf.global_variables_initializer());
var saver = tf.train.Saver(tf.global_variables());
@@ -212,14 +211,20 @@ namespace TensorFlowNET.Examples
Operation optimizer = graph.OperationByName("loss/Adam");
Tensor global_step = graph.OperationByName("Variable");
Tensor accuracy = graph.OperationByName("accuracy/accuracy");
stopwatch = Stopwatch.StartNew();

var sw = new Stopwatch();
sw.Start();

int step = 0;
foreach (var (x_batch, y_batch, total) in train_batches)
{
(_, step, loss_value) = sess.run((optimizer, global_step, loss),
(model_x, x_batch), (model_y, y_batch), (is_training, true));
if (step == 1 || step % 10 == 0)
Console.WriteLine($"Training on batch {step}/{total} loss: {loss_value.ToString("0.0000")}.");
if (step % 10 == 0)
{
Console.WriteLine($"Training on batch {step}/{total} loss: {loss_value.ToString("0.0000")} {sw.ElapsedMilliseconds}ms.");
sw.Restart();
}

if (step % 100 == 0)
{
@@ -242,7 +247,7 @@ namespace TensorFlowNET.Examples
var valid_accuracy = sum_accuracy / cnt;

print($"\nValidation Accuracy = {valid_accuracy.ToString("P")}\n");
// Save model
if (valid_accuracy > max_accuracy)
{


+1 -0  test/TensorFlowNET.UnitTest/ImageTest.cs

@@ -24,6 +24,7 @@ namespace TensorFlowNET.UnitTest
contents = tf.read_file(imgPath);
}

[Ignore("")]
[TestMethod]
public void decode_image()
{

