| @@ -62,6 +62,33 @@ A FIFOQueue that supports batching variable-sized tensors by padding. A `Padding | |||||
| A queue implementation that dequeues elements in prioritized order. A `PriorityQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `PriorityQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `types`, and whose shapes are optionally described by the `shapes` argument. | A queue implementation that dequeues elements in prioritized order. A `PriorityQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `PriorityQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `types`, and whose shapes are optionally described by the `shapes` argument. | ||||
| ```csharp | |||||
| [TestMethod] | |||||
| public void PriorityQueue() | |||||
| { | |||||
| var queue = tf.PriorityQueue(3, tf.@string); | |||||
| var init = queue.enqueue_many(new[] { 2L, 4L, 3L }, new[] { "p1", "p2", "p3" }); | |||||
| var x = queue.dequeue(); | |||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| init.run(); | |||||
| // output will 2, 3, 4 | |||||
| var result = sess.run(x); | |||||
| Assert.AreEqual(result[0].GetInt64(), 2L); | |||||
| result = sess.run(x); | |||||
| Assert.AreEqual(result[0].GetInt64(), 3L); | |||||
| result = sess.run(x); | |||||
| Assert.AreEqual(result[0].GetInt64(), 4L); | |||||
| } | |||||
| } | |||||
| ``` | |||||
| #### RandomShuffleQueue | #### RandomShuffleQueue | ||||
| A queue implementation that dequeues elements in a random order. A `RandomShuffleQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `RandomShuffleQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `dtypes`, and whose shapes are optionally described by the `shapes` argument. | A queue implementation that dequeues elements in a random order. A `RandomShuffleQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `RandomShuffleQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `dtypes`, and whose shapes are optionally described by the `shapes` argument. | ||||
| @@ -50,7 +50,7 @@ namespace Tensorflow | |||||
| string shared_name = null, | string shared_name = null, | ||||
| string name = "padding_fifo_queue") | string name = "padding_fifo_queue") | ||||
| => new PaddingFIFOQueue(capacity, | => new PaddingFIFOQueue(capacity, | ||||
| new [] { dtype }, | |||||
| new[] { dtype }, | |||||
| new[] { shape }, | new[] { shape }, | ||||
| shared_name: shared_name, | shared_name: shared_name, | ||||
| name: name); | name: name); | ||||
| @@ -86,7 +86,26 @@ namespace Tensorflow | |||||
| => new FIFOQueue(capacity, | => new FIFOQueue(capacity, | ||||
| new[] { dtype }, | new[] { dtype }, | ||||
| new[] { shape ?? new TensorShape() }, | new[] { shape ?? new TensorShape() }, | ||||
| new[] { name }, | |||||
| shared_name: shared_name, | |||||
| name: name); | |||||
| /// <summary> | |||||
| /// Creates a queue that dequeues elements in a first-in first-out order. | |||||
| /// </summary> | |||||
| /// <param name="capacity"></param> | |||||
| /// <param name="dtype"></param> | |||||
| /// <param name="shape"></param> | |||||
| /// <param name="shared_name"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public PriorityQueue PriorityQueue(int capacity, | |||||
| TF_DataType dtype, | |||||
| TensorShape shape = null, | |||||
| string shared_name = null, | |||||
| string name = "priority_queue") | |||||
| => new PriorityQueue(capacity, | |||||
| new[] { dtype }, | |||||
| new[] { shape ?? new TensorShape() }, | |||||
| shared_name: shared_name, | shared_name: shared_name, | ||||
| name: name); | name: name); | ||||
| } | } | ||||
| @@ -0,0 +1,66 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Queues | |||||
| { | |||||
| public class PriorityQueue : QueueBase | |||||
| { | |||||
| public PriorityQueue(int capacity, | |||||
| TF_DataType[] dtypes, | |||||
| TensorShape[] shapes, | |||||
| string[] names = null, | |||||
| string shared_name = null, | |||||
| string name = "priority_queue") | |||||
| : base(dtypes: dtypes, shapes: shapes, names: names) | |||||
| { | |||||
| _queue_ref = gen_data_flow_ops.priority_queue_v2( | |||||
| component_types: dtypes, | |||||
| shapes: shapes, | |||||
| capacity: capacity, | |||||
| shared_name: shared_name, | |||||
| name: name); | |||||
| _name = _queue_ref.op.name.Split('/').Last(); | |||||
| var dtypes1 = dtypes.ToList(); | |||||
| dtypes1.Insert(0, TF_DataType.TF_INT64); | |||||
| _dtypes = dtypes1.ToArray(); | |||||
| var shapes1 = shapes.ToList(); | |||||
| shapes1.Insert(0, new TensorShape()); | |||||
| _shapes = shapes1.ToArray(); | |||||
| } | |||||
| public Operation enqueue_many<T>(long[] indexes, T[] vals, string name = null) | |||||
| { | |||||
| return tf_with(ops.name_scope(name, $"{_name}_EnqueueMany", vals), scope => | |||||
| { | |||||
| var vals_tensor1 = _check_enqueue_dtypes(indexes); | |||||
| var vals_tensor2 = _check_enqueue_dtypes(vals); | |||||
| var tensors = new List<Tensor>(); | |||||
| tensors.AddRange(vals_tensor1); | |||||
| tensors.AddRange(vals_tensor2); | |||||
| return gen_data_flow_ops.queue_enqueue_many_v2(_queue_ref, tensors.ToArray(), name: scope); | |||||
| }); | |||||
| } | |||||
| public Tensor[] dequeue(string name = null) | |||||
| { | |||||
| Tensor[] ret; | |||||
| if (name == null) | |||||
| name = $"{_name}_Dequeue"; | |||||
| if (_queue_ref.dtype == TF_DataType.TF_RESOURCE) | |||||
| ret = gen_data_flow_ops.queue_dequeue_v2(_queue_ref, _dtypes, name: name); | |||||
| else | |||||
| ret = gen_data_flow_ops.queue_dequeue(_queue_ref, _dtypes, name: name); | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -42,7 +42,7 @@ namespace Tensorflow.Queues | |||||
| }); | }); | ||||
| } | } | ||||
| private Tensor[] _check_enqueue_dtypes(object vals) | |||||
| protected Tensor[] _check_enqueue_dtypes(object vals) | |||||
| { | { | ||||
| var tensors = new List<Tensor>(); | var tensors = new List<Tensor>(); | ||||
| @@ -56,12 +56,10 @@ namespace Tensorflow.Queues | |||||
| } | } | ||||
| break; | break; | ||||
| case int[] vals1: | |||||
| tensors.Add(ops.convert_to_tensor(vals1, dtype: _dtypes[0], name: $"component_0")); | |||||
| break; | |||||
| default: | default: | ||||
| throw new NotImplementedException(""); | |||||
| var dtype1 = GetType().Name == "PriorityQueue" ? _dtypes[1] : _dtypes[0]; | |||||
| tensors.Add(ops.convert_to_tensor(vals, dtype: dtype1, name: $"component_0")); | |||||
| break; | |||||
| } | } | ||||
| return tensors.ToArray(); | return tensors.ToArray(); | ||||
| @@ -0,0 +1,28 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Queues | |||||
| { | |||||
| public class RandomShuffleQueue : QueueBase | |||||
| { | |||||
| public RandomShuffleQueue(int capacity, | |||||
| TF_DataType[] dtypes, | |||||
| TensorShape[] shapes, | |||||
| string[] names = null, | |||||
| string shared_name = null, | |||||
| string name = "randomshuffle_fifo_queue") | |||||
| : base(dtypes: dtypes, shapes: shapes, names: names) | |||||
| { | |||||
| _queue_ref = gen_data_flow_ops.padding_fifo_queue_v2( | |||||
| component_types: dtypes, | |||||
| shapes: shapes, | |||||
| capacity: capacity, | |||||
| shared_name: shared_name, | |||||
| name: name); | |||||
| _name = _queue_ref.op.name.Split('/').Last(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -77,6 +77,22 @@ namespace Tensorflow | |||||
| return _op.output; | return _op.output; | ||||
| } | } | ||||
| public static Tensor priority_queue_v2(TF_DataType[] component_types, TensorShape[] shapes, | |||||
| int capacity = -1, string container = "", string shared_name = "", | |||||
| string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("PriorityQueueV2", name, new | |||||
| { | |||||
| component_types, | |||||
| shapes, | |||||
| capacity, | |||||
| container, | |||||
| shared_name | |||||
| }); | |||||
| return _op.output; | |||||
| } | |||||
| public static Operation queue_enqueue(Tensor handle, Tensor[] components, int timeout_ms = -1, string name = null) | public static Operation queue_enqueue(Tensor handle, Tensor[] components, int timeout_ms = -1, string name = null) | ||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("QueueEnqueue", name, new | var _op = _op_def_lib._apply_op_helper("QueueEnqueue", name, new | ||||
| @@ -10,11 +10,12 @@ namespace Tensorflow.Models.ObjectDetection | |||||
| { | { | ||||
| ImageResizerBuilder _image_resizer_builder; | ImageResizerBuilder _image_resizer_builder; | ||||
| FasterRCNNFeatureExtractor _feature_extractor; | FasterRCNNFeatureExtractor _feature_extractor; | ||||
| AnchorGeneratorBuilder anchor_generator_builder; | |||||
| AnchorGeneratorBuilder _anchor_generator_builder; | |||||
| public ModelBuilder() | public ModelBuilder() | ||||
| { | { | ||||
| _image_resizer_builder = new ImageResizerBuilder(); | _image_resizer_builder = new ImageResizerBuilder(); | ||||
| _anchor_generator_builder = new AnchorGeneratorBuilder(); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -51,7 +52,7 @@ namespace Tensorflow.Models.ObjectDetection | |||||
| inplace_batchnorm_update: frcnn_config.InplaceBatchnormUpdate); | inplace_batchnorm_update: frcnn_config.InplaceBatchnormUpdate); | ||||
| var number_of_stages = frcnn_config.NumberOfStages; | var number_of_stages = frcnn_config.NumberOfStages; | ||||
| var first_stage_anchor_generator = anchor_generator_builder.build(frcnn_config.FirstStageAnchorGenerator); | |||||
| var first_stage_anchor_generator = _anchor_generator_builder.build(frcnn_config.FirstStageAnchorGenerator); | |||||
| var first_stage_atrous_rate = frcnn_config.FirstStageAtrousRate; | var first_stage_atrous_rate = frcnn_config.FirstStageAtrousRate; | ||||
| return new FasterRCNNMetaArch(new FasterRCNNInitArgs | return new FasterRCNNMetaArch(new FasterRCNNInitArgs | ||||
| @@ -1,143 +1,133 @@ | |||||
| { | |||||
| "model": { | |||||
| "fasterRcnn": { | |||||
| "numClasses": 20, | |||||
| "imageResizer": { | |||||
| "keepAspectRatioResizer": { | |||||
| "minDimension": 600, | |||||
| "maxDimension": 1024 | |||||
| # Faster R-CNN with Resnet-101 (v1), configured for Pascal VOC Dataset. | |||||
| # Users should configure the fine_tune_checkpoint field in the train config as | |||||
| # well as the label_map_path and input_path fields in the train_input_reader and | |||||
| # eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that | |||||
| # should be configured. | |||||
| model { | |||||
| faster_rcnn { | |||||
| num_classes: 20 | |||||
| image_resizer { | |||||
| keep_aspect_ratio_resizer { | |||||
| min_dimension: 600 | |||||
| max_dimension: 1024 | |||||
| } | |||||
| } | |||||
| feature_extractor { | |||||
| type: 'faster_rcnn_resnet101' | |||||
| first_stage_features_stride: 16 | |||||
| } | |||||
| first_stage_anchor_generator { | |||||
| grid_anchor_generator { | |||||
| scales: [0.25, 0.5, 1.0, 2.0] | |||||
| aspect_ratios: [0.5, 1.0, 2.0] | |||||
| height_stride: 16 | |||||
| width_stride: 16 | |||||
| } | |||||
| } | |||||
| first_stage_box_predictor_conv_hyperparams { | |||||
| op: CONV | |||||
| regularizer { | |||||
| l2_regularizer { | |||||
| weight: 0.0 | |||||
| } | } | ||||
| }, | |||||
| "featureExtractor": { | |||||
| "type": "faster_rcnn_resnet101", | |||||
| "firstStageFeaturesStride": 16 | |||||
| }, | |||||
| "firstStageAnchorGenerator": { | |||||
| "gridAnchorGenerator": { | |||||
| "heightStride": 16, | |||||
| "widthStride": 16, | |||||
| "scales": [ | |||||
| 0.25, | |||||
| 0.5, | |||||
| 1.0, | |||||
| 2.0 | |||||
| ], | |||||
| "aspectRatios": [ | |||||
| 0.5, | |||||
| 1.0, | |||||
| 2.0 | |||||
| ] | |||||
| } | |||||
| initializer { | |||||
| truncated_normal_initializer { | |||||
| stddev: 0.01 | |||||
| } | } | ||||
| }, | |||||
| "firstStageBoxPredictorConvHyperparams": { | |||||
| "op": "CONV", | |||||
| "regularizer": { | |||||
| "l2Regularizer": { | |||||
| "weight": 0.0 | |||||
| } | |||||
| }, | |||||
| "initializer": { | |||||
| "truncatedNormalInitializer": { | |||||
| "stddev": 0.009999999776482582 | |||||
| } | |||||
| } | |||||
| first_stage_nms_score_threshold: 0.0 | |||||
| first_stage_nms_iou_threshold: 0.7 | |||||
| first_stage_max_proposals: 300 | |||||
| first_stage_localization_loss_weight: 2.0 | |||||
| first_stage_objectness_loss_weight: 1.0 | |||||
| initial_crop_size: 14 | |||||
| maxpool_kernel_size: 2 | |||||
| maxpool_stride: 2 | |||||
| second_stage_box_predictor { | |||||
| mask_rcnn_box_predictor { | |||||
| use_dropout: false | |||||
| dropout_keep_probability: 1.0 | |||||
| fc_hyperparams { | |||||
| op: FC | |||||
| regularizer { | |||||
| l2_regularizer { | |||||
| weight: 0.0 | |||||
| } | |||||
| } | } | ||||
| } | |||||
| }, | |||||
| "firstStageNmsScoreThreshold": 0.0, | |||||
| "firstStageNmsIouThreshold": 0.699999988079071, | |||||
| "firstStageMaxProposals": 300, | |||||
| "firstStageLocalizationLossWeight": 2.0, | |||||
| "firstStageObjectnessLossWeight": 1.0, | |||||
| "initialCropSize": 14, | |||||
| "maxpoolKernelSize": 2, | |||||
| "maxpoolStride": 2, | |||||
| "secondStageBoxPredictor": { | |||||
| "maskRcnnBoxPredictor": { | |||||
| "fcHyperparams": { | |||||
| "op": "FC", | |||||
| "regularizer": { | |||||
| "l2Regularizer": { | |||||
| "weight": 0.0 | |||||
| } | |||||
| }, | |||||
| "initializer": { | |||||
| "varianceScalingInitializer": { | |||||
| "factor": 1.0, | |||||
| "uniform": true, | |||||
| "mode": "FAN_AVG" | |||||
| } | |||||
| initializer { | |||||
| variance_scaling_initializer { | |||||
| factor: 1.0 | |||||
| uniform: true | |||||
| mode: FAN_AVG | |||||
| } | } | ||||
| }, | |||||
| "useDropout": false, | |||||
| "dropoutKeepProbability": 1.0 | |||||
| } | |||||
| } | } | ||||
| }, | |||||
| "secondStagePostProcessing": { | |||||
| "batchNonMaxSuppression": { | |||||
| "scoreThreshold": 0.0, | |||||
| "iouThreshold": 0.6000000238418579, | |||||
| "maxDetectionsPerClass": 100, | |||||
| "maxTotalDetections": 300 | |||||
| }, | |||||
| "scoreConverter": "SOFTMAX" | |||||
| }, | |||||
| "secondStageLocalizationLossWeight": 2.0, | |||||
| "secondStageClassificationLossWeight": 1.0 | |||||
| } | |||||
| } | } | ||||
| }, | |||||
| "trainConfig": { | |||||
| "batchSize": 1, | |||||
| "dataAugmentationOptions": [ | |||||
| { | |||||
| "randomHorizontalFlip": {} | |||||
| second_stage_post_processing { | |||||
| batch_non_max_suppression { | |||||
| score_threshold: 0.0 | |||||
| iou_threshold: 0.6 | |||||
| max_detections_per_class: 100 | |||||
| max_total_detections: 300 | |||||
| } | } | ||||
| ], | |||||
| "optimizer": { | |||||
| "momentumOptimizer": { | |||||
| "learningRate": { | |||||
| "manualStepLearningRate": { | |||||
| "initialLearningRate": 9.999999747378752e-05, | |||||
| "schedule": [ | |||||
| { | |||||
| "step": 500000, | |||||
| "learningRate": 9.999999747378752e-06 | |||||
| }, | |||||
| { | |||||
| "step": 700000, | |||||
| "learningRate": 9.999999974752427e-07 | |||||
| } | |||||
| ] | |||||
| } | |||||
| }, | |||||
| "momentumOptimizerValue": 0.8999999761581421 | |||||
| }, | |||||
| "useMovingAverage": false | |||||
| }, | |||||
| "gradientClippingByNorm": 10.0, | |||||
| "fineTuneCheckpoint": "D:/tmp/faster_rcnn_resnet101_coco/model.ckpt", | |||||
| "fromDetectionCheckpoint": true, | |||||
| "numSteps": 800000 | |||||
| }, | |||||
| "trainInputReader": { | |||||
| "labelMapPath": "D:/Projects/PythonLab/tf-models/research/object_detection/data/pascal_label_map.pbtxt", | |||||
| "tfRecordInputReader": { | |||||
| "inputPath": [ | |||||
| "D:/Projects/PythonLab/tf-models/research/object_detection/data/pascal_train.record" | |||||
| ] | |||||
| score_converter: SOFTMAX | |||||
| } | } | ||||
| }, | |||||
| "evalConfig": { | |||||
| "numExamples": 4952 | |||||
| }, | |||||
| "evalInputReader": [ | |||||
| { | |||||
| "labelMapPath": "D:/Projects/PythonLab/tf-models/research/object_detection/data/pascal_label_map.pbtxt", | |||||
| "shuffle": false, | |||||
| "numReaders": 1, | |||||
| "tfRecordInputReader": { | |||||
| "inputPath": [ | |||||
| "D:/Projects/PythonLab/tf-models/research/object_detection/data/pascal_val.record" | |||||
| ] | |||||
| second_stage_localization_loss_weight: 2.0 | |||||
| second_stage_classification_loss_weight: 1.0 | |||||
| } | |||||
| } | |||||
| train_config: { | |||||
| batch_size: 1 | |||||
| optimizer { | |||||
| momentum_optimizer: { | |||||
| learning_rate: { | |||||
| manual_step_learning_rate { | |||||
| initial_learning_rate: 0.0001 | |||||
| schedule { | |||||
| step: 500000 | |||||
| learning_rate: .00001 | |||||
| } | |||||
| schedule { | |||||
| step: 700000 | |||||
| learning_rate: .000001 | |||||
| } | |||||
| } | |||||
| } | } | ||||
| momentum_optimizer_value: 0.9 | |||||
| } | |||||
| use_moving_average: false | |||||
| } | |||||
| gradient_clipping_by_norm: 10.0 | |||||
| fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt" | |||||
| from_detection_checkpoint: true | |||||
| num_steps: 800000 | |||||
| data_augmentation_options { | |||||
| random_horizontal_flip { | |||||
| } | } | ||||
| ] | |||||
| } | |||||
| } | |||||
| } | |||||
| train_input_reader: { | |||||
| tf_record_input_reader { | |||||
| input_path: "PATH_TO_BE_CONFIGURED/pascal_train.record" | |||||
| } | |||||
| label_map_path: "PATH_TO_BE_CONFIGURED/pascal_label_map.pbtxt" | |||||
| } | |||||
| eval_config: { | |||||
| num_examples: 4952 | |||||
| } | |||||
| eval_input_reader: { | |||||
| tf_record_input_reader { | |||||
| input_path: "PATH_TO_BE_CONFIGURED/pascal_val.record" | |||||
| } | |||||
| label_map_path: "PATH_TO_BE_CONFIGURED/pascal_label_map.pbtxt" | |||||
| shuffle: false | |||||
| num_readers: 1 | |||||
| } | |||||
| @@ -1,4 +1,5 @@ | |||||
| using System; | |||||
| using Protobuf.Text; | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.IO; | using System.IO; | ||||
| using System.Text; | using System.Text; | ||||
| @@ -10,8 +11,8 @@ namespace Tensorflow.Models.ObjectDetection.Utils | |||||
| { | { | ||||
| public static TrainEvalPipelineConfig get_configs_from_pipeline_file(string pipeline_config_path) | public static TrainEvalPipelineConfig get_configs_from_pipeline_file(string pipeline_config_path) | ||||
| { | { | ||||
| var json = File.ReadAllText(pipeline_config_path); | |||||
| var pipeline_config = TrainEvalPipelineConfig.Parser.ParseJson(json); | |||||
| var config = File.ReadAllText(pipeline_config_path); | |||||
| var pipeline_config = TrainEvalPipelineConfig.Parser.ParseText(config); | |||||
| return pipeline_config; | return pipeline_config; | ||||
| } | } | ||||
| @@ -4,6 +4,16 @@ | |||||
| <TargetFramework>netcoreapp2.2</TargetFramework> | <TargetFramework>netcoreapp2.2</TargetFramework> | ||||
| <AssemblyName>TensorFlow.Models</AssemblyName> | <AssemblyName>TensorFlow.Models</AssemblyName> | ||||
| <RootNamespace>Tensorflow.Models</RootNamespace> | <RootNamespace>Tensorflow.Models</RootNamespace> | ||||
| <Version>0.0.1</Version> | |||||
| <Authors>Haiping Chen</Authors> | |||||
| <Company>SciSharp STACK</Company> | |||||
| <PackageProjectUrl>https://github.com/SciSharp/TensorFlow.NET</PackageProjectUrl> | |||||
| <RepositoryUrl>https://github.com/SciSharp/TensorFlow.NET</RepositoryUrl> | |||||
| <RepositoryType>git</RepositoryType> | |||||
| <PackageTags>TensorFlow</PackageTags> | |||||
| <Description>Models and examples built with TensorFlow.</Description> | |||||
| <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | |||||
| <Copyright>Apache2</Copyright> | |||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| @@ -16,6 +26,10 @@ | |||||
| </Content> | </Content> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | |||||
| <PackageReference Include="Protobuf.Text" Version="0.3.1" /> | |||||
| </ItemGroup> | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <ProjectReference Include="..\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | <ProjectReference Include="..\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| @@ -70,5 +70,27 @@ namespace TensorFlowNET.UnitTest | |||||
| // until queue has more element. | // until queue has more element. | ||||
| } | } | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void PriorityQueue() | |||||
| { | |||||
| var queue = tf.PriorityQueue(3, tf.@string); | |||||
| var init = queue.enqueue_many(new[] { 2L, 4L, 3L }, new[] { "p1", "p2", "p3" }); | |||||
| var x = queue.dequeue(); | |||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| init.run(); | |||||
| var result = sess.run(x); | |||||
| Assert.AreEqual(result[0].GetInt64(), 2L); | |||||
| result = sess.run(x); | |||||
| Assert.AreEqual(result[0].GetInt64(), 3L); | |||||
| result = sess.run(x); | |||||
| Assert.AreEqual(result[0].GetInt64(), 4L); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||