diff --git a/src/TensorFlowNET.Core/Estimators/Estimator.cs b/src/TensorFlowNET.Core/Estimators/Estimator.cs index 0df59dcc..5ba7a9c3 100644 --- a/src/TensorFlowNET.Core/Estimators/Estimator.cs +++ b/src/TensorFlowNET.Core/Estimators/Estimator.cs @@ -45,8 +45,9 @@ namespace Tensorflow.Estimators } } - _train_model(input_fn); - throw new NotImplementedException(""); + var loss = _train_model(input_fn); + print($"Loss for final step: {loss}."); + return this; } private int _load_global_step_from_checkpoint_dir(string checkpoint_dir) @@ -58,12 +59,12 @@ namespace Tensorflow.Estimators return cp.AllModelCheckpointPaths.Count - 1; } - private void _train_model(Func input_fn) + private Tensor _train_model(Func input_fn) { - _train_model_default(input_fn); + return _train_model_default(input_fn); } - private void _train_model_default(Func input_fn) + private Tensor _train_model_default(Func input_fn) { using (var g = tf.Graph().as_default()) { @@ -74,13 +75,16 @@ namespace Tensorflow.Estimators if (global_step_tensor != null) TrainingUtil._get_or_create_global_step_read(g); - _get_features_and_labels_from_input_fn(input_fn, "train"); + var (features, labels) = _get_features_and_labels_from_input_fn(input_fn, "train"); } + + throw new NotImplementedException(""); } - private void _get_features_and_labels_from_input_fn(Func input_fn, string mode) + private (Dictionary, Dictionary) _get_features_and_labels_from_input_fn(Func input_fn, string mode) { - _call_input_fn(input_fn, mode); + var result = _call_input_fn(input_fn, mode); + return EstimatorUtil.parse_input_fn_result(result); } /// @@ -88,9 +92,9 @@ namespace Tensorflow.Estimators /// /// /// - private void _call_input_fn(Func input_fn, string mode) + private DatasetV1Adapter _call_input_fn(Func input_fn, string mode) { - input_fn(); + return input_fn(); } private Tensor _create_and_assert_global_step(Graph graph) diff --git a/src/TensorFlowNET.Core/Estimators/EstimatorUtil.cs b/src/TensorFlowNET.Core/Estimators/EstimatorUtil.cs new file mode 100644 index 00000000..df1fb38b --- /dev/null +++ b/src/TensorFlowNET.Core/Estimators/EstimatorUtil.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Data; + +namespace Tensorflow.Estimators +{ + public class EstimatorUtil + { + public static (Dictionary, Dictionary) parse_input_fn_result(DatasetV1Adapter result) + { + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/Builders/BoxPredictorBuilder.cs b/src/TensorFlowNET.Models/ObjectDetection/Builders/BoxPredictorBuilder.cs new file mode 100644 index 00000000..7ff2be25 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Builders/BoxPredictorBuilder.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Models.ObjectDetection +{ + public class BoxPredictorBuilder + { + ConvolutionalBoxPredictor _first_stage_box_predictor; + public ConvolutionalBoxPredictor build_convolutional_box_predictor() + { + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/Builders/ImageResizerBuilder.cs b/src/TensorFlowNET.Models/ObjectDetection/Builders/ImageResizerBuilder.cs index eab1e8a5..81c169b3 100644 --- a/src/TensorFlowNET.Models/ObjectDetection/Builders/ImageResizerBuilder.cs +++ b/src/TensorFlowNET.Models/ObjectDetection/Builders/ImageResizerBuilder.cs @@ -13,6 +13,11 @@ namespace Tensorflow.Models.ObjectDetection } + /// + /// Builds callable for image resizing operations. + /// + /// + /// public Action build(ImageResizer image_resizer_config) { var image_resizer_oneof = image_resizer_config.ImageResizerOneofCase; @@ -21,8 +26,13 @@ namespace Tensorflow.Models.ObjectDetection var keep_aspect_ratio_config = image_resizer_config.KeepAspectRatioResizer; var method = _tf_resize_method(keep_aspect_ratio_config.ResizeMethod); var per_channel_pad_value = new[] { 0, 0, 0 }; - //if (keep_aspect_ratio_config.PerChannelPadValue != null) - //per_channel_pad_value = new[] { keep_aspect_ratio_config.PerChannelPadValue }; + if (keep_aspect_ratio_config.PerChannelPadValue.Count > 0) + throw new NotImplementedException(""); + // per_channel_pad_value = new[] { keep_aspect_ratio_config.PerChannelPadValue. }; + return () => + { + + }; } else { diff --git a/src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs b/src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs index 1493352a..0ff80561 100644 --- a/src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs +++ b/src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs @@ -1,7 +1,6 @@ using System; using System.Collections.Generic; using System.Text; -using Tensorflow.Models.ObjectDetection.MetaArchitectures; using Tensorflow.Models.ObjectDetection.Protos; using static Tensorflow.Models.ObjectDetection.Protos.DetectionModel; @@ -45,7 +44,20 @@ namespace Tensorflow.Models.ObjectDetection { var num_classes = frcnn_config.NumClasses; var image_resizer_fn = image_resizer_builder.build(frcnn_config.ImageResizer); - throw new NotImplementedException(""); + + var first_stage_atrous_rate = frcnn_config.FirstStageAtrousRate; + var number_of_stages = frcnn_config.NumberOfStages; + + return new FasterRCNNMetaArch(new FasterRCNNInitArgs + { + is_training = is_training, + num_classes = num_classes, + image_resizer_fn = image_resizer_fn, + feature_extractor = () => { throw new NotImplementedException(""); }, + number_of_stage = number_of_stages, + first_stage_anchor_generator = null, + first_stage_atrous_rate = first_stage_atrous_rate + }); } public Action preprocess() diff --git a/src/TensorFlowNET.Models/ObjectDetection/Core/DetectionModel.cs b/src/TensorFlowNET.Models/ObjectDetection/Core/DetectionModel.cs new file mode 100644 index 00000000..24578a5b --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Core/DetectionModel.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Models.ObjectDetection.Core +{ + public abstract class DetectionModel + { + } +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/Inputs.cs b/src/TensorFlowNET.Models/ObjectDetection/Inputs.cs index eb300752..0388b78a 100644 --- a/src/TensorFlowNET.Models/ObjectDetection/Inputs.cs +++ b/src/TensorFlowNET.Models/ObjectDetection/Inputs.cs @@ -2,8 +2,6 @@ using System.Collections.Generic; using System.Text; using Tensorflow.Data; -using Tensorflow.Estimators; -using Tensorflow.Models.ObjectDetection.MetaArchitectures; using Tensorflow.Models.ObjectDetection.Protos; namespace Tensorflow.Models.ObjectDetection @@ -23,9 +21,7 @@ namespace Tensorflow.Models.ObjectDetection public Func create_train_input_fn(TrainConfig train_config, InputReader train_input_config, DetectionModel model_config) { Func _train_input_fn = () => - { - return train_input(train_config, train_input_config, model_config); - }; + train_input(train_config, train_input_config, model_config); return _train_input_fn; } diff --git a/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNInitArgs.cs b/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNInitArgs.cs new file mode 100644 index 00000000..991ffff4 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNInitArgs.cs @@ -0,0 +1,18 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Models.ObjectDetection +{ + public class FasterRCNNInitArgs + { + public bool is_training { get; set; } + public int num_classes { get; set; } + public Action image_resizer_fn { get; set; } + public Action feature_extractor { get; set; } + public int number_of_stage { get; set; } + public object first_stage_anchor_generator { get; set; } + public object first_stage_target_assigner { get; set; } + public int first_stage_atrous_rate { get; set; } + } +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNMetaArch.cs b/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNMetaArch.cs index 5a4c4e5a..083d4e4f 100644 --- a/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNMetaArch.cs +++ b/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNMetaArch.cs @@ -2,10 +2,17 @@ using System.Collections.Generic; using System.Text; -namespace Tensorflow.Models.ObjectDetection.MetaArchitectures +namespace Tensorflow.Models.ObjectDetection { - public class FasterRCNNMetaArch + public class FasterRCNNMetaArch : Core.DetectionModel { + FasterRCNNInitArgs _args; + + public FasterRCNNMetaArch(FasterRCNNInitArgs args) + { + _args = args; + } + public (Tensor, Tensor) preprocess(Tensor tensor) { throw new NotImplementedException(""); diff --git a/src/TensorFlowNET.Models/ObjectDetection/Predictors/ConvolutionalBoxPredictor.cs b/src/TensorFlowNET.Models/ObjectDetection/Predictors/ConvolutionalBoxPredictor.cs new file mode 100644 index 00000000..bd2f4114 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Predictors/ConvolutionalBoxPredictor.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Models.ObjectDetection +{ + public class ConvolutionalBoxPredictor + { + } +}