diff --git a/docs/source/Queue.md b/docs/source/Queue.md
index bd73fd5a..b846278b 100644
--- a/docs/source/Queue.md
+++ b/docs/source/Queue.md
@@ -62,6 +62,33 @@ A FIFOQueue that supports batching variable-sized tensors by padding. A `Padding
A queue implementation that dequeues elements in prioritized order. A `PriorityQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `PriorityQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `types`, and whose shapes are optionally described by the `shapes` argument.
+```csharp
+[TestMethod]
+public void PriorityQueue()
+{
+ var queue = tf.PriorityQueue(3, tf.@string);
+ var init = queue.enqueue_many(new[] { 2L, 4L, 3L }, new[] { "p1", "p2", "p3" });
+ var x = queue.dequeue();
+
+ using (var sess = tf.Session())
+ {
+ init.run();
+
+ // output will 2, 3, 4
+ var result = sess.run(x);
+ Assert.AreEqual(result[0].GetInt64(), 2L);
+
+ result = sess.run(x);
+ Assert.AreEqual(result[0].GetInt64(), 3L);
+
+ result = sess.run(x);
+ Assert.AreEqual(result[0].GetInt64(), 4L);
+ }
+}
+```
+
+
+
#### RandomShuffleQueue
A queue implementation that dequeues elements in a random order. A `RandomShuffleQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `RandomShuffleQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `dtypes`, and whose shapes are optionally described by the `shapes` argument.
diff --git a/src/TensorFlowNET.Core/APIs/tf.queue.cs b/src/TensorFlowNET.Core/APIs/tf.queue.cs
index f81f5726..1a9641b4 100644
--- a/src/TensorFlowNET.Core/APIs/tf.queue.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.queue.cs
@@ -50,7 +50,7 @@ namespace Tensorflow
string shared_name = null,
string name = "padding_fifo_queue")
=> new PaddingFIFOQueue(capacity,
- new [] { dtype },
+ new[] { dtype },
new[] { shape },
shared_name: shared_name,
name: name);
@@ -86,7 +86,26 @@ namespace Tensorflow
=> new FIFOQueue(capacity,
new[] { dtype },
new[] { shape ?? new TensorShape() },
- new[] { name },
+ shared_name: shared_name,
+ name: name);
+
+ ///
+ /// Creates a queue that dequeues elements in a first-in first-out order.
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public PriorityQueue PriorityQueue(int capacity,
+ TF_DataType dtype,
+ TensorShape shape = null,
+ string shared_name = null,
+ string name = "priority_queue")
+ => new PriorityQueue(capacity,
+ new[] { dtype },
+ new[] { shape ?? new TensorShape() },
shared_name: shared_name,
name: name);
}
diff --git a/src/TensorFlowNET.Core/Operations/Queues/PriorityQueue.cs b/src/TensorFlowNET.Core/Operations/Queues/PriorityQueue.cs
new file mode 100644
index 00000000..b41e1a0c
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/Queues/PriorityQueue.cs
@@ -0,0 +1,66 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.Queues
+{
+ public class PriorityQueue : QueueBase
+ {
+ public PriorityQueue(int capacity,
+ TF_DataType[] dtypes,
+ TensorShape[] shapes,
+ string[] names = null,
+ string shared_name = null,
+ string name = "priority_queue")
+ : base(dtypes: dtypes, shapes: shapes, names: names)
+ {
+ _queue_ref = gen_data_flow_ops.priority_queue_v2(
+ component_types: dtypes,
+ shapes: shapes,
+ capacity: capacity,
+ shared_name: shared_name,
+ name: name);
+
+ _name = _queue_ref.op.name.Split('/').Last();
+
+ var dtypes1 = dtypes.ToList();
+ dtypes1.Insert(0, TF_DataType.TF_INT64);
+ _dtypes = dtypes1.ToArray();
+
+ var shapes1 = shapes.ToList();
+ shapes1.Insert(0, new TensorShape());
+ _shapes = shapes1.ToArray();
+ }
+
+ public Operation enqueue_many(long[] indexes, T[] vals, string name = null)
+ {
+ return tf_with(ops.name_scope(name, $"{_name}_EnqueueMany", vals), scope =>
+ {
+ var vals_tensor1 = _check_enqueue_dtypes(indexes);
+ var vals_tensor2 = _check_enqueue_dtypes(vals);
+
+ var tensors = new List();
+ tensors.AddRange(vals_tensor1);
+ tensors.AddRange(vals_tensor2);
+
+ return gen_data_flow_ops.queue_enqueue_many_v2(_queue_ref, tensors.ToArray(), name: scope);
+ });
+ }
+
+ public Tensor[] dequeue(string name = null)
+ {
+ Tensor[] ret;
+ if (name == null)
+ name = $"{_name}_Dequeue";
+
+ if (_queue_ref.dtype == TF_DataType.TF_RESOURCE)
+ ret = gen_data_flow_ops.queue_dequeue_v2(_queue_ref, _dtypes, name: name);
+ else
+ ret = gen_data_flow_ops.queue_dequeue(_queue_ref, _dtypes, name: name);
+
+ return ret;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/Queues/QueueBase.cs b/src/TensorFlowNET.Core/Operations/Queues/QueueBase.cs
index 0eb5816d..38821d9d 100644
--- a/src/TensorFlowNET.Core/Operations/Queues/QueueBase.cs
+++ b/src/TensorFlowNET.Core/Operations/Queues/QueueBase.cs
@@ -42,7 +42,7 @@ namespace Tensorflow.Queues
});
}
- private Tensor[] _check_enqueue_dtypes(object vals)
+ protected Tensor[] _check_enqueue_dtypes(object vals)
{
var tensors = new List();
@@ -56,12 +56,10 @@ namespace Tensorflow.Queues
}
break;
- case int[] vals1:
- tensors.Add(ops.convert_to_tensor(vals1, dtype: _dtypes[0], name: $"component_0"));
- break;
-
default:
- throw new NotImplementedException("");
+ var dtype1 = GetType().Name == "PriorityQueue" ? _dtypes[1] : _dtypes[0];
+ tensors.Add(ops.convert_to_tensor(vals, dtype: dtype1, name: $"component_0"));
+ break;
}
return tensors.ToArray();
diff --git a/src/TensorFlowNET.Core/Operations/Queues/RandomShuffleQueue.cs b/src/TensorFlowNET.Core/Operations/Queues/RandomShuffleQueue.cs
new file mode 100644
index 00000000..5765f081
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/Queues/RandomShuffleQueue.cs
@@ -0,0 +1,28 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace Tensorflow.Queues
+{
+ public class RandomShuffleQueue : QueueBase
+ {
+ public RandomShuffleQueue(int capacity,
+ TF_DataType[] dtypes,
+ TensorShape[] shapes,
+ string[] names = null,
+ string shared_name = null,
+ string name = "randomshuffle_fifo_queue")
+ : base(dtypes: dtypes, shapes: shapes, names: names)
+ {
+ _queue_ref = gen_data_flow_ops.padding_fifo_queue_v2(
+ component_types: dtypes,
+ shapes: shapes,
+ capacity: capacity,
+ shared_name: shared_name,
+ name: name);
+
+ _name = _queue_ref.op.name.Split('/').Last();
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs
index 4fd394d2..b752268f 100644
--- a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs
@@ -77,6 +77,22 @@ namespace Tensorflow
return _op.output;
}
+ public static Tensor priority_queue_v2(TF_DataType[] component_types, TensorShape[] shapes,
+ int capacity = -1, string container = "", string shared_name = "",
+ string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("PriorityQueueV2", name, new
+ {
+ component_types,
+ shapes,
+ capacity,
+ container,
+ shared_name
+ });
+
+ return _op.output;
+ }
+
public static Operation queue_enqueue(Tensor handle, Tensor[] components, int timeout_ms = -1, string name = null)
{
var _op = _op_def_lib._apply_op_helper("QueueEnqueue", name, new
diff --git a/src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs b/src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs
index 2f2d3d85..596a7532 100644
--- a/src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs
+++ b/src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs
@@ -10,11 +10,12 @@ namespace Tensorflow.Models.ObjectDetection
{
ImageResizerBuilder _image_resizer_builder;
FasterRCNNFeatureExtractor _feature_extractor;
- AnchorGeneratorBuilder anchor_generator_builder;
+ AnchorGeneratorBuilder _anchor_generator_builder;
public ModelBuilder()
{
_image_resizer_builder = new ImageResizerBuilder();
+ _anchor_generator_builder = new AnchorGeneratorBuilder();
}
///
@@ -51,7 +52,7 @@ namespace Tensorflow.Models.ObjectDetection
inplace_batchnorm_update: frcnn_config.InplaceBatchnormUpdate);
var number_of_stages = frcnn_config.NumberOfStages;
- var first_stage_anchor_generator = anchor_generator_builder.build(frcnn_config.FirstStageAnchorGenerator);
+ var first_stage_anchor_generator = _anchor_generator_builder.build(frcnn_config.FirstStageAnchorGenerator);
var first_stage_atrous_rate = frcnn_config.FirstStageAtrousRate;
return new FasterRCNNMetaArch(new FasterRCNNInitArgs
diff --git a/src/TensorFlowNET.Models/ObjectDetection/Models/faster_rcnn_resnet101_voc07.config b/src/TensorFlowNET.Models/ObjectDetection/Models/faster_rcnn_resnet101_voc07.config
index d5ec5f38..7458f4a5 100644
--- a/src/TensorFlowNET.Models/ObjectDetection/Models/faster_rcnn_resnet101_voc07.config
+++ b/src/TensorFlowNET.Models/ObjectDetection/Models/faster_rcnn_resnet101_voc07.config
@@ -1,143 +1,133 @@
-{
- "model": {
- "fasterRcnn": {
- "numClasses": 20,
- "imageResizer": {
- "keepAspectRatioResizer": {
- "minDimension": 600,
- "maxDimension": 1024
+# Faster R-CNN with Resnet-101 (v1), configured for Pascal VOC Dataset.
+# Users should configure the fine_tune_checkpoint field in the train config as
+# well as the label_map_path and input_path fields in the train_input_reader and
+# eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that
+# should be configured.
+
+model {
+ faster_rcnn {
+ num_classes: 20
+ image_resizer {
+ keep_aspect_ratio_resizer {
+ min_dimension: 600
+ max_dimension: 1024
+ }
+ }
+ feature_extractor {
+ type: 'faster_rcnn_resnet101'
+ first_stage_features_stride: 16
+ }
+ first_stage_anchor_generator {
+ grid_anchor_generator {
+ scales: [0.25, 0.5, 1.0, 2.0]
+ aspect_ratios: [0.5, 1.0, 2.0]
+ height_stride: 16
+ width_stride: 16
+ }
+ }
+ first_stage_box_predictor_conv_hyperparams {
+ op: CONV
+ regularizer {
+ l2_regularizer {
+ weight: 0.0
}
- },
- "featureExtractor": {
- "type": "faster_rcnn_resnet101",
- "firstStageFeaturesStride": 16
- },
- "firstStageAnchorGenerator": {
- "gridAnchorGenerator": {
- "heightStride": 16,
- "widthStride": 16,
- "scales": [
- 0.25,
- 0.5,
- 1.0,
- 2.0
- ],
- "aspectRatios": [
- 0.5,
- 1.0,
- 2.0
- ]
+ }
+ initializer {
+ truncated_normal_initializer {
+ stddev: 0.01
}
- },
- "firstStageBoxPredictorConvHyperparams": {
- "op": "CONV",
- "regularizer": {
- "l2Regularizer": {
- "weight": 0.0
- }
- },
- "initializer": {
- "truncatedNormalInitializer": {
- "stddev": 0.009999999776482582
+ }
+ }
+ first_stage_nms_score_threshold: 0.0
+ first_stage_nms_iou_threshold: 0.7
+ first_stage_max_proposals: 300
+ first_stage_localization_loss_weight: 2.0
+ first_stage_objectness_loss_weight: 1.0
+ initial_crop_size: 14
+ maxpool_kernel_size: 2
+ maxpool_stride: 2
+ second_stage_box_predictor {
+ mask_rcnn_box_predictor {
+ use_dropout: false
+ dropout_keep_probability: 1.0
+ fc_hyperparams {
+ op: FC
+ regularizer {
+ l2_regularizer {
+ weight: 0.0
+ }
}
- }
- },
- "firstStageNmsScoreThreshold": 0.0,
- "firstStageNmsIouThreshold": 0.699999988079071,
- "firstStageMaxProposals": 300,
- "firstStageLocalizationLossWeight": 2.0,
- "firstStageObjectnessLossWeight": 1.0,
- "initialCropSize": 14,
- "maxpoolKernelSize": 2,
- "maxpoolStride": 2,
- "secondStageBoxPredictor": {
- "maskRcnnBoxPredictor": {
- "fcHyperparams": {
- "op": "FC",
- "regularizer": {
- "l2Regularizer": {
- "weight": 0.0
- }
- },
- "initializer": {
- "varianceScalingInitializer": {
- "factor": 1.0,
- "uniform": true,
- "mode": "FAN_AVG"
- }
+ initializer {
+ variance_scaling_initializer {
+ factor: 1.0
+ uniform: true
+ mode: FAN_AVG
}
- },
- "useDropout": false,
- "dropoutKeepProbability": 1.0
+ }
}
- },
- "secondStagePostProcessing": {
- "batchNonMaxSuppression": {
- "scoreThreshold": 0.0,
- "iouThreshold": 0.6000000238418579,
- "maxDetectionsPerClass": 100,
- "maxTotalDetections": 300
- },
- "scoreConverter": "SOFTMAX"
- },
- "secondStageLocalizationLossWeight": 2.0,
- "secondStageClassificationLossWeight": 1.0
+ }
}
- },
- "trainConfig": {
- "batchSize": 1,
- "dataAugmentationOptions": [
- {
- "randomHorizontalFlip": {}
+ second_stage_post_processing {
+ batch_non_max_suppression {
+ score_threshold: 0.0
+ iou_threshold: 0.6
+ max_detections_per_class: 100
+ max_total_detections: 300
}
- ],
- "optimizer": {
- "momentumOptimizer": {
- "learningRate": {
- "manualStepLearningRate": {
- "initialLearningRate": 9.999999747378752e-05,
- "schedule": [
- {
- "step": 500000,
- "learningRate": 9.999999747378752e-06
- },
- {
- "step": 700000,
- "learningRate": 9.999999974752427e-07
- }
- ]
- }
- },
- "momentumOptimizerValue": 0.8999999761581421
- },
- "useMovingAverage": false
- },
- "gradientClippingByNorm": 10.0,
- "fineTuneCheckpoint": "D:/tmp/faster_rcnn_resnet101_coco/model.ckpt",
- "fromDetectionCheckpoint": true,
- "numSteps": 800000
- },
- "trainInputReader": {
- "labelMapPath": "D:/Projects/PythonLab/tf-models/research/object_detection/data/pascal_label_map.pbtxt",
- "tfRecordInputReader": {
- "inputPath": [
- "D:/Projects/PythonLab/tf-models/research/object_detection/data/pascal_train.record"
- ]
+ score_converter: SOFTMAX
}
- },
- "evalConfig": {
- "numExamples": 4952
- },
- "evalInputReader": [
- {
- "labelMapPath": "D:/Projects/PythonLab/tf-models/research/object_detection/data/pascal_label_map.pbtxt",
- "shuffle": false,
- "numReaders": 1,
- "tfRecordInputReader": {
- "inputPath": [
- "D:/Projects/PythonLab/tf-models/research/object_detection/data/pascal_val.record"
- ]
+ second_stage_localization_loss_weight: 2.0
+ second_stage_classification_loss_weight: 1.0
+ }
+}
+
+train_config: {
+ batch_size: 1
+ optimizer {
+ momentum_optimizer: {
+ learning_rate: {
+ manual_step_learning_rate {
+ initial_learning_rate: 0.0001
+ schedule {
+ step: 500000
+ learning_rate: .00001
+ }
+ schedule {
+ step: 700000
+ learning_rate: .000001
+ }
+ }
}
+ momentum_optimizer_value: 0.9
+ }
+ use_moving_average: false
+ }
+ gradient_clipping_by_norm: 10.0
+ fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"
+ from_detection_checkpoint: true
+ num_steps: 800000
+ data_augmentation_options {
+ random_horizontal_flip {
}
- ]
-}
\ No newline at end of file
+ }
+}
+
+train_input_reader: {
+ tf_record_input_reader {
+ input_path: "PATH_TO_BE_CONFIGURED/pascal_train.record"
+ }
+ label_map_path: "PATH_TO_BE_CONFIGURED/pascal_label_map.pbtxt"
+}
+
+eval_config: {
+ num_examples: 4952
+}
+
+eval_input_reader: {
+ tf_record_input_reader {
+ input_path: "PATH_TO_BE_CONFIGURED/pascal_val.record"
+ }
+ label_map_path: "PATH_TO_BE_CONFIGURED/pascal_label_map.pbtxt"
+ shuffle: false
+ num_readers: 1
+}
diff --git a/src/TensorFlowNET.Models/ObjectDetection/Utils/ConfigUtil.cs b/src/TensorFlowNET.Models/ObjectDetection/Utils/ConfigUtil.cs
index a8b3876e..2a6a672e 100644
--- a/src/TensorFlowNET.Models/ObjectDetection/Utils/ConfigUtil.cs
+++ b/src/TensorFlowNET.Models/ObjectDetection/Utils/ConfigUtil.cs
@@ -1,4 +1,5 @@
-using System;
+using Protobuf.Text;
+using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
@@ -10,8 +11,8 @@ namespace Tensorflow.Models.ObjectDetection.Utils
{
public static TrainEvalPipelineConfig get_configs_from_pipeline_file(string pipeline_config_path)
{
- var json = File.ReadAllText(pipeline_config_path);
- var pipeline_config = TrainEvalPipelineConfig.Parser.ParseJson(json);
+ var config = File.ReadAllText(pipeline_config_path);
+ var pipeline_config = TrainEvalPipelineConfig.Parser.ParseText(config);
return pipeline_config;
}
diff --git a/src/TensorFlowNET.Models/TensorFlowNET.Models.csproj b/src/TensorFlowNET.Models/TensorFlowNET.Models.csproj
index 291c7c03..aae55be9 100644
--- a/src/TensorFlowNET.Models/TensorFlowNET.Models.csproj
+++ b/src/TensorFlowNET.Models/TensorFlowNET.Models.csproj
@@ -4,6 +4,16 @@
netcoreapp2.2
TensorFlow.Models
Tensorflow.Models
+ 0.0.1
+ Haiping Chen
+ SciSharp STACK
+ https://github.com/SciSharp/TensorFlow.NET
+ https://github.com/SciSharp/TensorFlow.NET
+ git
+ TensorFlow
+ Models and examples built with TensorFlow.
+ true
+ Apache2
@@ -16,6 +26,10 @@
+
+
+
+
diff --git a/test/TensorFlowNET.UnitTest/QueueTest.cs b/test/TensorFlowNET.UnitTest/QueueTest.cs
index 14afbae5..d546d961 100644
--- a/test/TensorFlowNET.UnitTest/QueueTest.cs
+++ b/test/TensorFlowNET.UnitTest/QueueTest.cs
@@ -70,5 +70,27 @@ namespace TensorFlowNET.UnitTest
// until queue has more element.
}
}
+
+ [TestMethod]
+ public void PriorityQueue()
+ {
+ var queue = tf.PriorityQueue(3, tf.@string);
+ var init = queue.enqueue_many(new[] { 2L, 4L, 3L }, new[] { "p1", "p2", "p3" });
+ var x = queue.dequeue();
+
+ using (var sess = tf.Session())
+ {
+ init.run();
+
+ var result = sess.run(x);
+ Assert.AreEqual(result[0].GetInt64(), 2L);
+
+ result = sess.run(x);
+ Assert.AreEqual(result[0].GetInt64(), 3L);
+
+ result = sess.run(x);
+ Assert.AreEqual(result[0].GetInt64(), 4L);
+ }
+ }
}
}