| @@ -62,6 +62,33 @@ A FIFOQueue that supports batching variable-sized tensors by padding. A `Padding | |||
| A queue implementation that dequeues elements in prioritized order. A `PriorityQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `PriorityQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `types`, and whose shapes are optionally described by the `shapes` argument. | |||
| ```csharp | |||
| [TestMethod] | |||
| public void PriorityQueue() | |||
| { | |||
| var queue = tf.PriorityQueue(3, tf.@string); | |||
| var init = queue.enqueue_many(new[] { 2L, 4L, 3L }, new[] { "p1", "p2", "p3" }); | |||
| var x = queue.dequeue(); | |||
| using (var sess = tf.Session()) | |||
| { | |||
| init.run(); | |||
| // output will 2, 3, 4 | |||
| var result = sess.run(x); | |||
| Assert.AreEqual(result[0].GetInt64(), 2L); | |||
| result = sess.run(x); | |||
| Assert.AreEqual(result[0].GetInt64(), 3L); | |||
| result = sess.run(x); | |||
| Assert.AreEqual(result[0].GetInt64(), 4L); | |||
| } | |||
| } | |||
| ``` | |||
| #### RandomShuffleQueue | |||
| A queue implementation that dequeues elements in a random order. A `RandomShuffleQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `RandomShuffleQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `dtypes`, and whose shapes are optionally described by the `shapes` argument. | |||
| @@ -50,7 +50,7 @@ namespace Tensorflow | |||
| string shared_name = null, | |||
| string name = "padding_fifo_queue") | |||
| => new PaddingFIFOQueue(capacity, | |||
| new [] { dtype }, | |||
| new[] { dtype }, | |||
| new[] { shape }, | |||
| shared_name: shared_name, | |||
| name: name); | |||
| @@ -86,7 +86,26 @@ namespace Tensorflow | |||
| => new FIFOQueue(capacity, | |||
| new[] { dtype }, | |||
| new[] { shape ?? new TensorShape() }, | |||
| new[] { name }, | |||
| shared_name: shared_name, | |||
| name: name); | |||
| /// <summary> | |||
| /// Creates a queue that dequeues elements in a first-in first-out order. | |||
| /// </summary> | |||
| /// <param name="capacity"></param> | |||
| /// <param name="dtype"></param> | |||
| /// <param name="shape"></param> | |||
| /// <param name="shared_name"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public PriorityQueue PriorityQueue(int capacity, | |||
| TF_DataType dtype, | |||
| TensorShape shape = null, | |||
| string shared_name = null, | |||
| string name = "priority_queue") | |||
| => new PriorityQueue(capacity, | |||
| new[] { dtype }, | |||
| new[] { shape ?? new TensorShape() }, | |||
| shared_name: shared_name, | |||
| name: name); | |||
| } | |||
| @@ -0,0 +1,66 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Queues | |||
| { | |||
| public class PriorityQueue : QueueBase | |||
| { | |||
| public PriorityQueue(int capacity, | |||
| TF_DataType[] dtypes, | |||
| TensorShape[] shapes, | |||
| string[] names = null, | |||
| string shared_name = null, | |||
| string name = "priority_queue") | |||
| : base(dtypes: dtypes, shapes: shapes, names: names) | |||
| { | |||
| _queue_ref = gen_data_flow_ops.priority_queue_v2( | |||
| component_types: dtypes, | |||
| shapes: shapes, | |||
| capacity: capacity, | |||
| shared_name: shared_name, | |||
| name: name); | |||
| _name = _queue_ref.op.name.Split('/').Last(); | |||
| var dtypes1 = dtypes.ToList(); | |||
| dtypes1.Insert(0, TF_DataType.TF_INT64); | |||
| _dtypes = dtypes1.ToArray(); | |||
| var shapes1 = shapes.ToList(); | |||
| shapes1.Insert(0, new TensorShape()); | |||
| _shapes = shapes1.ToArray(); | |||
| } | |||
| public Operation enqueue_many<T>(long[] indexes, T[] vals, string name = null) | |||
| { | |||
| return tf_with(ops.name_scope(name, $"{_name}_EnqueueMany", vals), scope => | |||
| { | |||
| var vals_tensor1 = _check_enqueue_dtypes(indexes); | |||
| var vals_tensor2 = _check_enqueue_dtypes(vals); | |||
| var tensors = new List<Tensor>(); | |||
| tensors.AddRange(vals_tensor1); | |||
| tensors.AddRange(vals_tensor2); | |||
| return gen_data_flow_ops.queue_enqueue_many_v2(_queue_ref, tensors.ToArray(), name: scope); | |||
| }); | |||
| } | |||
| public Tensor[] dequeue(string name = null) | |||
| { | |||
| Tensor[] ret; | |||
| if (name == null) | |||
| name = $"{_name}_Dequeue"; | |||
| if (_queue_ref.dtype == TF_DataType.TF_RESOURCE) | |||
| ret = gen_data_flow_ops.queue_dequeue_v2(_queue_ref, _dtypes, name: name); | |||
| else | |||
| ret = gen_data_flow_ops.queue_dequeue(_queue_ref, _dtypes, name: name); | |||
| return ret; | |||
| } | |||
| } | |||
| } | |||
| @@ -42,7 +42,7 @@ namespace Tensorflow.Queues | |||
| }); | |||
| } | |||
| private Tensor[] _check_enqueue_dtypes(object vals) | |||
| protected Tensor[] _check_enqueue_dtypes(object vals) | |||
| { | |||
| var tensors = new List<Tensor>(); | |||
| @@ -56,12 +56,10 @@ namespace Tensorflow.Queues | |||
| } | |||
| break; | |||
| case int[] vals1: | |||
| tensors.Add(ops.convert_to_tensor(vals1, dtype: _dtypes[0], name: $"component_0")); | |||
| break; | |||
| default: | |||
| throw new NotImplementedException(""); | |||
| var dtype1 = GetType().Name == "PriorityQueue" ? _dtypes[1] : _dtypes[0]; | |||
| tensors.Add(ops.convert_to_tensor(vals, dtype: dtype1, name: $"component_0")); | |||
| break; | |||
| } | |||
| return tensors.ToArray(); | |||
| @@ -0,0 +1,28 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| namespace Tensorflow.Queues | |||
| { | |||
| public class RandomShuffleQueue : QueueBase | |||
| { | |||
| public RandomShuffleQueue(int capacity, | |||
| TF_DataType[] dtypes, | |||
| TensorShape[] shapes, | |||
| string[] names = null, | |||
| string shared_name = null, | |||
| string name = "randomshuffle_fifo_queue") | |||
| : base(dtypes: dtypes, shapes: shapes, names: names) | |||
| { | |||
| _queue_ref = gen_data_flow_ops.padding_fifo_queue_v2( | |||
| component_types: dtypes, | |||
| shapes: shapes, | |||
| capacity: capacity, | |||
| shared_name: shared_name, | |||
| name: name); | |||
| _name = _queue_ref.op.name.Split('/').Last(); | |||
| } | |||
| } | |||
| } | |||
| @@ -77,6 +77,22 @@ namespace Tensorflow | |||
| return _op.output; | |||
| } | |||
| public static Tensor priority_queue_v2(TF_DataType[] component_types, TensorShape[] shapes, | |||
| int capacity = -1, string container = "", string shared_name = "", | |||
| string name = null) | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("PriorityQueueV2", name, new | |||
| { | |||
| component_types, | |||
| shapes, | |||
| capacity, | |||
| container, | |||
| shared_name | |||
| }); | |||
| return _op.output; | |||
| } | |||
| public static Operation queue_enqueue(Tensor handle, Tensor[] components, int timeout_ms = -1, string name = null) | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("QueueEnqueue", name, new | |||
| @@ -10,11 +10,12 @@ namespace Tensorflow.Models.ObjectDetection | |||
| { | |||
| ImageResizerBuilder _image_resizer_builder; | |||
| FasterRCNNFeatureExtractor _feature_extractor; | |||
| AnchorGeneratorBuilder anchor_generator_builder; | |||
| AnchorGeneratorBuilder _anchor_generator_builder; | |||
| public ModelBuilder() | |||
| { | |||
| _image_resizer_builder = new ImageResizerBuilder(); | |||
| _anchor_generator_builder = new AnchorGeneratorBuilder(); | |||
| } | |||
| /// <summary> | |||
| @@ -51,7 +52,7 @@ namespace Tensorflow.Models.ObjectDetection | |||
| inplace_batchnorm_update: frcnn_config.InplaceBatchnormUpdate); | |||
| var number_of_stages = frcnn_config.NumberOfStages; | |||
| var first_stage_anchor_generator = anchor_generator_builder.build(frcnn_config.FirstStageAnchorGenerator); | |||
| var first_stage_anchor_generator = _anchor_generator_builder.build(frcnn_config.FirstStageAnchorGenerator); | |||
| var first_stage_atrous_rate = frcnn_config.FirstStageAtrousRate; | |||
| return new FasterRCNNMetaArch(new FasterRCNNInitArgs | |||
| @@ -1,143 +1,133 @@ | |||
| { | |||
| "model": { | |||
| "fasterRcnn": { | |||
| "numClasses": 20, | |||
| "imageResizer": { | |||
| "keepAspectRatioResizer": { | |||
| "minDimension": 600, | |||
| "maxDimension": 1024 | |||
| # Faster R-CNN with Resnet-101 (v1), configured for Pascal VOC Dataset. | |||
| # Users should configure the fine_tune_checkpoint field in the train config as | |||
| # well as the label_map_path and input_path fields in the train_input_reader and | |||
| # eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that | |||
| # should be configured. | |||
| model { | |||
| faster_rcnn { | |||
| num_classes: 20 | |||
| image_resizer { | |||
| keep_aspect_ratio_resizer { | |||
| min_dimension: 600 | |||
| max_dimension: 1024 | |||
| } | |||
| } | |||
| feature_extractor { | |||
| type: 'faster_rcnn_resnet101' | |||
| first_stage_features_stride: 16 | |||
| } | |||
| first_stage_anchor_generator { | |||
| grid_anchor_generator { | |||
| scales: [0.25, 0.5, 1.0, 2.0] | |||
| aspect_ratios: [0.5, 1.0, 2.0] | |||
| height_stride: 16 | |||
| width_stride: 16 | |||
| } | |||
| } | |||
| first_stage_box_predictor_conv_hyperparams { | |||
| op: CONV | |||
| regularizer { | |||
| l2_regularizer { | |||
| weight: 0.0 | |||
| } | |||
| }, | |||
| "featureExtractor": { | |||
| "type": "faster_rcnn_resnet101", | |||
| "firstStageFeaturesStride": 16 | |||
| }, | |||
| "firstStageAnchorGenerator": { | |||
| "gridAnchorGenerator": { | |||
| "heightStride": 16, | |||
| "widthStride": 16, | |||
| "scales": [ | |||
| 0.25, | |||
| 0.5, | |||
| 1.0, | |||
| 2.0 | |||
| ], | |||
| "aspectRatios": [ | |||
| 0.5, | |||
| 1.0, | |||
| 2.0 | |||
| ] | |||
| } | |||
| initializer { | |||
| truncated_normal_initializer { | |||
| stddev: 0.01 | |||
| } | |||
| }, | |||
| "firstStageBoxPredictorConvHyperparams": { | |||
| "op": "CONV", | |||
| "regularizer": { | |||
| "l2Regularizer": { | |||
| "weight": 0.0 | |||
| } | |||
| }, | |||
| "initializer": { | |||
| "truncatedNormalInitializer": { | |||
| "stddev": 0.009999999776482582 | |||
| } | |||
| } | |||
| first_stage_nms_score_threshold: 0.0 | |||
| first_stage_nms_iou_threshold: 0.7 | |||
| first_stage_max_proposals: 300 | |||
| first_stage_localization_loss_weight: 2.0 | |||
| first_stage_objectness_loss_weight: 1.0 | |||
| initial_crop_size: 14 | |||
| maxpool_kernel_size: 2 | |||
| maxpool_stride: 2 | |||
| second_stage_box_predictor { | |||
| mask_rcnn_box_predictor { | |||
| use_dropout: false | |||
| dropout_keep_probability: 1.0 | |||
| fc_hyperparams { | |||
| op: FC | |||
| regularizer { | |||
| l2_regularizer { | |||
| weight: 0.0 | |||
| } | |||
| } | |||
| } | |||
| }, | |||
| "firstStageNmsScoreThreshold": 0.0, | |||
| "firstStageNmsIouThreshold": 0.699999988079071, | |||
| "firstStageMaxProposals": 300, | |||
| "firstStageLocalizationLossWeight": 2.0, | |||
| "firstStageObjectnessLossWeight": 1.0, | |||
| "initialCropSize": 14, | |||
| "maxpoolKernelSize": 2, | |||
| "maxpoolStride": 2, | |||
| "secondStageBoxPredictor": { | |||
| "maskRcnnBoxPredictor": { | |||
| "fcHyperparams": { | |||
| "op": "FC", | |||
| "regularizer": { | |||
| "l2Regularizer": { | |||
| "weight": 0.0 | |||
| } | |||
| }, | |||
| "initializer": { | |||
| "varianceScalingInitializer": { | |||
| "factor": 1.0, | |||
| "uniform": true, | |||
| "mode": "FAN_AVG" | |||
| } | |||
| initializer { | |||
| variance_scaling_initializer { | |||
| factor: 1.0 | |||
| uniform: true | |||
| mode: FAN_AVG | |||
| } | |||
| }, | |||
| "useDropout": false, | |||
| "dropoutKeepProbability": 1.0 | |||
| } | |||
| } | |||
| }, | |||
| "secondStagePostProcessing": { | |||
| "batchNonMaxSuppression": { | |||
| "scoreThreshold": 0.0, | |||
| "iouThreshold": 0.6000000238418579, | |||
| "maxDetectionsPerClass": 100, | |||
| "maxTotalDetections": 300 | |||
| }, | |||
| "scoreConverter": "SOFTMAX" | |||
| }, | |||
| "secondStageLocalizationLossWeight": 2.0, | |||
| "secondStageClassificationLossWeight": 1.0 | |||
| } | |||
| } | |||
| }, | |||
| "trainConfig": { | |||
| "batchSize": 1, | |||
| "dataAugmentationOptions": [ | |||
| { | |||
| "randomHorizontalFlip": {} | |||
| second_stage_post_processing { | |||
| batch_non_max_suppression { | |||
| score_threshold: 0.0 | |||
| iou_threshold: 0.6 | |||
| max_detections_per_class: 100 | |||
| max_total_detections: 300 | |||
| } | |||
| ], | |||
| "optimizer": { | |||
| "momentumOptimizer": { | |||
| "learningRate": { | |||
| "manualStepLearningRate": { | |||
| "initialLearningRate": 9.999999747378752e-05, | |||
| "schedule": [ | |||
| { | |||
| "step": 500000, | |||
| "learningRate": 9.999999747378752e-06 | |||
| }, | |||
| { | |||
| "step": 700000, | |||
| "learningRate": 9.999999974752427e-07 | |||
| } | |||
| ] | |||
| } | |||
| }, | |||
| "momentumOptimizerValue": 0.8999999761581421 | |||
| }, | |||
| "useMovingAverage": false | |||
| }, | |||
| "gradientClippingByNorm": 10.0, | |||
| "fineTuneCheckpoint": "D:/tmp/faster_rcnn_resnet101_coco/model.ckpt", | |||
| "fromDetectionCheckpoint": true, | |||
| "numSteps": 800000 | |||
| }, | |||
| "trainInputReader": { | |||
| "labelMapPath": "D:/Projects/PythonLab/tf-models/research/object_detection/data/pascal_label_map.pbtxt", | |||
| "tfRecordInputReader": { | |||
| "inputPath": [ | |||
| "D:/Projects/PythonLab/tf-models/research/object_detection/data/pascal_train.record" | |||
| ] | |||
| score_converter: SOFTMAX | |||
| } | |||
| }, | |||
| "evalConfig": { | |||
| "numExamples": 4952 | |||
| }, | |||
| "evalInputReader": [ | |||
| { | |||
| "labelMapPath": "D:/Projects/PythonLab/tf-models/research/object_detection/data/pascal_label_map.pbtxt", | |||
| "shuffle": false, | |||
| "numReaders": 1, | |||
| "tfRecordInputReader": { | |||
| "inputPath": [ | |||
| "D:/Projects/PythonLab/tf-models/research/object_detection/data/pascal_val.record" | |||
| ] | |||
| second_stage_localization_loss_weight: 2.0 | |||
| second_stage_classification_loss_weight: 1.0 | |||
| } | |||
| } | |||
| train_config: { | |||
| batch_size: 1 | |||
| optimizer { | |||
| momentum_optimizer: { | |||
| learning_rate: { | |||
| manual_step_learning_rate { | |||
| initial_learning_rate: 0.0001 | |||
| schedule { | |||
| step: 500000 | |||
| learning_rate: .00001 | |||
| } | |||
| schedule { | |||
| step: 700000 | |||
| learning_rate: .000001 | |||
| } | |||
| } | |||
| } | |||
| momentum_optimizer_value: 0.9 | |||
| } | |||
| use_moving_average: false | |||
| } | |||
| gradient_clipping_by_norm: 10.0 | |||
| fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt" | |||
| from_detection_checkpoint: true | |||
| num_steps: 800000 | |||
| data_augmentation_options { | |||
| random_horizontal_flip { | |||
| } | |||
| ] | |||
| } | |||
| } | |||
| } | |||
| train_input_reader: { | |||
| tf_record_input_reader { | |||
| input_path: "PATH_TO_BE_CONFIGURED/pascal_train.record" | |||
| } | |||
| label_map_path: "PATH_TO_BE_CONFIGURED/pascal_label_map.pbtxt" | |||
| } | |||
| eval_config: { | |||
| num_examples: 4952 | |||
| } | |||
| eval_input_reader: { | |||
| tf_record_input_reader { | |||
| input_path: "PATH_TO_BE_CONFIGURED/pascal_val.record" | |||
| } | |||
| label_map_path: "PATH_TO_BE_CONFIGURED/pascal_label_map.pbtxt" | |||
| shuffle: false | |||
| num_readers: 1 | |||
| } | |||
| @@ -1,4 +1,5 @@ | |||
| using System; | |||
| using Protobuf.Text; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Text; | |||
| @@ -10,8 +11,8 @@ namespace Tensorflow.Models.ObjectDetection.Utils | |||
| { | |||
| public static TrainEvalPipelineConfig get_configs_from_pipeline_file(string pipeline_config_path) | |||
| { | |||
| var json = File.ReadAllText(pipeline_config_path); | |||
| var pipeline_config = TrainEvalPipelineConfig.Parser.ParseJson(json); | |||
| var config = File.ReadAllText(pipeline_config_path); | |||
| var pipeline_config = TrainEvalPipelineConfig.Parser.ParseText(config); | |||
| return pipeline_config; | |||
| } | |||
| @@ -4,6 +4,16 @@ | |||
| <TargetFramework>netcoreapp2.2</TargetFramework> | |||
| <AssemblyName>TensorFlow.Models</AssemblyName> | |||
| <RootNamespace>Tensorflow.Models</RootNamespace> | |||
| <Version>0.0.1</Version> | |||
| <Authors>Haiping Chen</Authors> | |||
| <Company>SciSharp STACK</Company> | |||
| <PackageProjectUrl>https://github.com/SciSharp/TensorFlow.NET</PackageProjectUrl> | |||
| <RepositoryUrl>https://github.com/SciSharp/TensorFlow.NET</RepositoryUrl> | |||
| <RepositoryType>git</RepositoryType> | |||
| <PackageTags>TensorFlow</PackageTags> | |||
| <Description>Models and examples built with TensorFlow.</Description> | |||
| <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | |||
| <Copyright>Apache2</Copyright> | |||
| </PropertyGroup> | |||
| <ItemGroup> | |||
| @@ -16,6 +26,10 @@ | |||
| </Content> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| <PackageReference Include="Protobuf.Text" Version="0.3.1" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| <ProjectReference Include="..\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | |||
| </ItemGroup> | |||
| @@ -70,5 +70,27 @@ namespace TensorFlowNET.UnitTest | |||
| // until queue has more element. | |||
| } | |||
| } | |||
| [TestMethod] | |||
| public void PriorityQueue() | |||
| { | |||
| var queue = tf.PriorityQueue(3, tf.@string); | |||
| var init = queue.enqueue_many(new[] { 2L, 4L, 3L }, new[] { "p1", "p2", "p3" }); | |||
| var x = queue.dequeue(); | |||
| using (var sess = tf.Session()) | |||
| { | |||
| init.run(); | |||
| var result = sess.run(x); | |||
| Assert.AreEqual(result[0].GetInt64(), 2L); | |||
| result = sess.run(x); | |||
| Assert.AreEqual(result[0].GetInt64(), 3L); | |||
| result = sess.run(x); | |||
| Assert.AreEqual(result[0].GetInt64(), 4L); | |||
| } | |||
| } | |||
| } | |||
| } | |||