| @@ -0,0 +1,15 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Models.ObjectDetection | |||||
| { | |||||
| public class GridAnchorGenerator : Core.AnchorGenerator | |||||
| { | |||||
| public GridAnchorGenerator(float[] scales = null) | |||||
| { | |||||
| if (scales == null) | |||||
| scales = new[] { 0.5f, 1.0f, 2.0f }; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,27 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using Tensorflow.Models.ObjectDetection.Protos; | |||||
| using static Tensorflow.Models.ObjectDetection.Protos.AnchorGenerator; | |||||
| namespace Tensorflow.Models.ObjectDetection | |||||
| { | |||||
| public class AnchorGeneratorBuilder | |||||
| { | |||||
| public AnchorGeneratorBuilder() | |||||
| { | |||||
| } | |||||
| public GridAnchorGenerator build(AnchorGenerator anchor_generator_config) | |||||
| { | |||||
| if(anchor_generator_config.AnchorGeneratorOneofCase == AnchorGeneratorOneofOneofCase.GridAnchorGenerator) | |||||
| { | |||||
| var grid_anchor_generator_config = anchor_generator_config.GridAnchorGenerator; | |||||
| return new GridAnchorGenerator(scales: grid_anchor_generator_config.Scales.Select(x => float.Parse(x.ToString())).ToArray()); | |||||
| } | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -10,6 +10,7 @@ namespace Tensorflow.Models.ObjectDetection | |||||
| { | { | ||||
| ImageResizerBuilder _image_resizer_builder; | ImageResizerBuilder _image_resizer_builder; | ||||
| FasterRCNNFeatureExtractor _feature_extractor; | FasterRCNNFeatureExtractor _feature_extractor; | ||||
| AnchorGeneratorBuilder anchor_generator_builder; | |||||
| public ModelBuilder() | public ModelBuilder() | ||||
| { | { | ||||
| @@ -46,8 +47,12 @@ namespace Tensorflow.Models.ObjectDetection | |||||
| var num_classes = frcnn_config.NumClasses; | var num_classes = frcnn_config.NumClasses; | ||||
| var image_resizer_fn = _image_resizer_builder.build(frcnn_config.ImageResizer); | var image_resizer_fn = _image_resizer_builder.build(frcnn_config.ImageResizer); | ||||
| var first_stage_atrous_rate = frcnn_config.FirstStageAtrousRate; | |||||
| var feature_extractor = _build_faster_rcnn_feature_extractor(frcnn_config.FeatureExtractor, is_training, | |||||
| inplace_batchnorm_update: frcnn_config.InplaceBatchnormUpdate); | |||||
| var number_of_stages = frcnn_config.NumberOfStages; | var number_of_stages = frcnn_config.NumberOfStages; | ||||
| var first_stage_anchor_generator = anchor_generator_builder.build(frcnn_config.FirstStageAnchorGenerator); | |||||
| var first_stage_atrous_rate = frcnn_config.FirstStageAtrousRate; | |||||
| return new FasterRCNNMetaArch(new FasterRCNNInitArgs | return new FasterRCNNMetaArch(new FasterRCNNInitArgs | ||||
| { | { | ||||
| @@ -65,5 +70,19 @@ namespace Tensorflow.Models.ObjectDetection | |||||
| { | { | ||||
| throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
| } | } | ||||
| private FasterRCNNFeatureExtractor _build_faster_rcnn_feature_extractor(FasterRcnnFeatureExtractor feature_extractor_config, | |||||
| bool is_training, bool reuse_weights = false, bool inplace_batchnorm_update = false) | |||||
| { | |||||
| if (inplace_batchnorm_update) | |||||
| throw new ValueError("inplace batchnorm updates not supported."); | |||||
| var feature_type = feature_extractor_config.Type; | |||||
| var first_stage_features_stride = feature_extractor_config.FirstStageFeaturesStride; | |||||
| var batch_norm_trainable = feature_extractor_config.BatchNormTrainable; | |||||
| return new FasterRCNNResnet101FeatureExtractor(is_training, first_stage_features_stride, | |||||
| batch_norm_trainable: batch_norm_trainable, | |||||
| reuse_weights: reuse_weights); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,10 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Models.ObjectDetection.Core | |||||
| { | |||||
| public class AnchorGenerator | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -4,7 +4,28 @@ using System.Text; | |||||
| namespace Tensorflow.Models.ObjectDetection | namespace Tensorflow.Models.ObjectDetection | ||||
| { | { | ||||
| /// <summary> | |||||
| /// Faster R-CNN Feature Extractor definition. | |||||
| /// </summary> | |||||
| public class FasterRCNNFeatureExtractor | public class FasterRCNNFeatureExtractor | ||||
| { | { | ||||
| bool _is_training; | |||||
| int _first_stage_features_stride; | |||||
| bool _reuse_weights = false; | |||||
| float _weight_decay = 0.0f; | |||||
| bool _train_batch_norm; | |||||
| public FasterRCNNFeatureExtractor(bool is_training, | |||||
| int first_stage_features_stride, | |||||
| bool batch_norm_trainable = false, | |||||
| bool reuse_weights = false, | |||||
| float weight_decay = 0.0f) | |||||
| { | |||||
| _is_training = is_training; | |||||
| _first_stage_features_stride = first_stage_features_stride; | |||||
| _train_batch_norm = (batch_norm_trainable && is_training); | |||||
| _reuse_weights = reuse_weights; | |||||
| _weight_decay = weight_decay; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,32 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using static Tensorflow.Binding; | |||||
| using Tensorflow.Operations.Activation; | |||||
| using Tensorflow.Models.Slim.Nets; | |||||
| namespace Tensorflow.Models.ObjectDetection | |||||
| { | |||||
| /// <summary> | |||||
| /// Faster R-CNN Resnet 101 feature extractor implementation. | |||||
| /// </summary> | |||||
| public class FasterRCNNResnet101FeatureExtractor : FasterRCNNResnetV1FeatureExtractor | |||||
| { | |||||
| public FasterRCNNResnet101FeatureExtractor(bool is_training, | |||||
| int first_stage_features_stride, | |||||
| bool batch_norm_trainable = false, | |||||
| bool reuse_weights = false, | |||||
| float weight_decay = 0.0f, | |||||
| IActivation activation_fn = null) : base("resnet_v1_101", | |||||
| ResNetV1.resnet_v1_101, | |||||
| is_training, | |||||
| first_stage_features_stride, | |||||
| batch_norm_trainable: batch_norm_trainable, | |||||
| reuse_weights: reuse_weights, | |||||
| weight_decay: weight_decay, | |||||
| activation_fn: activation_fn) | |||||
| { | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,28 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using static Tensorflow.Binding; | |||||
| using Tensorflow.Operations.Activation; | |||||
| namespace Tensorflow.Models.ObjectDetection | |||||
| { | |||||
| public class FasterRCNNResnetV1FeatureExtractor : FasterRCNNFeatureExtractor | |||||
| { | |||||
| public FasterRCNNResnetV1FeatureExtractor(string architecture, | |||||
| Action resnet_model, | |||||
| bool is_training, | |||||
| int first_stage_features_stride, | |||||
| bool batch_norm_trainable = false, | |||||
| bool reuse_weights = false, | |||||
| float weight_decay = 0.0f, | |||||
| IActivation activation_fn = null) : base(is_training, | |||||
| first_stage_features_stride, | |||||
| batch_norm_trainable: batch_norm_trainable, | |||||
| reuse_weights: reuse_weights, | |||||
| weight_decay: weight_decay) | |||||
| { | |||||
| if (activation_fn == null) | |||||
| activation_fn = tf.nn.relu(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,14 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Models.Slim.Nets | |||||
| { | |||||
| public class ResNetV1 | |||||
| { | |||||
| public static void resnet_v1_101() | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -73,7 +73,6 @@ namespace TensorFlowNET.Examples.ImageProcessing.ObjectDetection | |||||
| public void PrepareData() | public void PrepareData() | ||||
| { | { | ||||
| throw new NotImplementedException(); | |||||
| } | } | ||||
| public void Train(Session sess) | public void Train(Session sess) | ||||