| @@ -45,8 +45,9 @@ namespace Tensorflow.Estimators | |||||
| } | } | ||||
| } | } | ||||
| _train_model(input_fn); | |||||
| throw new NotImplementedException(""); | |||||
| var loss = _train_model(input_fn); | |||||
| print($"Loss for final step: {loss}."); | |||||
| return this; | |||||
| } | } | ||||
| private int _load_global_step_from_checkpoint_dir(string checkpoint_dir) | private int _load_global_step_from_checkpoint_dir(string checkpoint_dir) | ||||
| @@ -58,12 +59,12 @@ namespace Tensorflow.Estimators | |||||
| return cp.AllModelCheckpointPaths.Count - 1; | return cp.AllModelCheckpointPaths.Count - 1; | ||||
| } | } | ||||
| private void _train_model(Func<DatasetV1Adapter> input_fn) | |||||
| private Tensor _train_model(Func<DatasetV1Adapter> input_fn) | |||||
| { | { | ||||
| _train_model_default(input_fn); | |||||
| return _train_model_default(input_fn); | |||||
| } | } | ||||
| private void _train_model_default(Func<DatasetV1Adapter> input_fn) | |||||
| private Tensor _train_model_default(Func<DatasetV1Adapter> input_fn) | |||||
| { | { | ||||
| using (var g = tf.Graph().as_default()) | using (var g = tf.Graph().as_default()) | ||||
| { | { | ||||
| @@ -74,13 +75,16 @@ namespace Tensorflow.Estimators | |||||
| if (global_step_tensor != null) | if (global_step_tensor != null) | ||||
| TrainingUtil._get_or_create_global_step_read(g); | TrainingUtil._get_or_create_global_step_read(g); | ||||
| _get_features_and_labels_from_input_fn(input_fn, "train"); | |||||
| var (features, labels) = _get_features_and_labels_from_input_fn(input_fn, "train"); | |||||
| } | } | ||||
| throw new NotImplementedException(""); | |||||
| } | } | ||||
| private void _get_features_and_labels_from_input_fn(Func<DatasetV1Adapter> input_fn, string mode) | |||||
| private (Dictionary<string, Tensor>, Dictionary<string, Tensor>) _get_features_and_labels_from_input_fn(Func<DatasetV1Adapter> input_fn, string mode) | |||||
| { | { | ||||
| _call_input_fn(input_fn, mode); | |||||
| var result = _call_input_fn(input_fn, mode); | |||||
| return EstimatorUtil.parse_input_fn_result(result); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -88,9 +92,9 @@ namespace Tensorflow.Estimators | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="input_fn"></param> | /// <param name="input_fn"></param> | ||||
| /// <param name="mode"></param> | /// <param name="mode"></param> | ||||
| private void _call_input_fn(Func<DatasetV1Adapter> input_fn, string mode) | |||||
| private DatasetV1Adapter _call_input_fn(Func<DatasetV1Adapter> input_fn, string mode) | |||||
| { | { | ||||
| input_fn(); | |||||
| return input_fn(); | |||||
| } | } | ||||
| private Tensor _create_and_assert_global_step(Graph graph) | private Tensor _create_and_assert_global_step(Graph graph) | ||||
| @@ -0,0 +1,15 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Data; | |||||
| namespace Tensorflow.Estimators | |||||
| { | |||||
| public class EstimatorUtil | |||||
| { | |||||
| public static (Dictionary<string, Tensor>, Dictionary<string, Tensor>) parse_input_fn_result(DatasetV1Adapter result) | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,15 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Models.ObjectDetection | |||||
| { | |||||
| public class BoxPredictorBuilder | |||||
| { | |||||
| ConvolutionalBoxPredictor _first_stage_box_predictor; | |||||
| public ConvolutionalBoxPredictor build_convolutional_box_predictor() | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -13,6 +13,11 @@ namespace Tensorflow.Models.ObjectDetection | |||||
| } | } | ||||
| /// <summary> | |||||
| /// Builds callable for image resizing operations. | |||||
| /// </summary> | |||||
| /// <param name="image_resizer_config"></param> | |||||
| /// <returns></returns> | |||||
| public Action build(ImageResizer image_resizer_config) | public Action build(ImageResizer image_resizer_config) | ||||
| { | { | ||||
| var image_resizer_oneof = image_resizer_config.ImageResizerOneofCase; | var image_resizer_oneof = image_resizer_config.ImageResizerOneofCase; | ||||
| @@ -21,8 +26,13 @@ namespace Tensorflow.Models.ObjectDetection | |||||
| var keep_aspect_ratio_config = image_resizer_config.KeepAspectRatioResizer; | var keep_aspect_ratio_config = image_resizer_config.KeepAspectRatioResizer; | ||||
| var method = _tf_resize_method(keep_aspect_ratio_config.ResizeMethod); | var method = _tf_resize_method(keep_aspect_ratio_config.ResizeMethod); | ||||
| var per_channel_pad_value = new[] { 0, 0, 0 }; | var per_channel_pad_value = new[] { 0, 0, 0 }; | ||||
| //if (keep_aspect_ratio_config.PerChannelPadValue != null) | |||||
| //per_channel_pad_value = new[] { keep_aspect_ratio_config.PerChannelPadValue }; | |||||
| if (keep_aspect_ratio_config.PerChannelPadValue.Count > 0) | |||||
| throw new NotImplementedException(""); | |||||
| // per_channel_pad_value = new[] { keep_aspect_ratio_config.PerChannelPadValue. }; | |||||
| return () => | |||||
| { | |||||
| }; | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| @@ -1,7 +1,6 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Models.ObjectDetection.MetaArchitectures; | |||||
| using Tensorflow.Models.ObjectDetection.Protos; | using Tensorflow.Models.ObjectDetection.Protos; | ||||
| using static Tensorflow.Models.ObjectDetection.Protos.DetectionModel; | using static Tensorflow.Models.ObjectDetection.Protos.DetectionModel; | ||||
| @@ -45,7 +44,20 @@ namespace Tensorflow.Models.ObjectDetection | |||||
| { | { | ||||
| var num_classes = frcnn_config.NumClasses; | var num_classes = frcnn_config.NumClasses; | ||||
| var image_resizer_fn = image_resizer_builder.build(frcnn_config.ImageResizer); | var image_resizer_fn = image_resizer_builder.build(frcnn_config.ImageResizer); | ||||
| throw new NotImplementedException(""); | |||||
| var first_stage_atrous_rate = frcnn_config.FirstStageAtrousRate; | |||||
| var number_of_stages = frcnn_config.NumberOfStages; | |||||
| return new FasterRCNNMetaArch(new FasterRCNNInitArgs | |||||
| { | |||||
| is_training = is_training, | |||||
| num_classes = num_classes, | |||||
| image_resizer_fn = image_resizer_fn, | |||||
| feature_extractor = () => { throw new NotImplementedException(""); }, | |||||
| number_of_stage = number_of_stages, | |||||
| first_stage_anchor_generator = null, | |||||
| first_stage_atrous_rate = first_stage_atrous_rate | |||||
| }); | |||||
| } | } | ||||
| public Action preprocess() | public Action preprocess() | ||||
| @@ -0,0 +1,10 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Models.ObjectDetection.Core | |||||
| { | |||||
| public abstract class DetectionModel | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -2,8 +2,6 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Data; | using Tensorflow.Data; | ||||
| using Tensorflow.Estimators; | |||||
| using Tensorflow.Models.ObjectDetection.MetaArchitectures; | |||||
| using Tensorflow.Models.ObjectDetection.Protos; | using Tensorflow.Models.ObjectDetection.Protos; | ||||
| namespace Tensorflow.Models.ObjectDetection | namespace Tensorflow.Models.ObjectDetection | ||||
| @@ -23,9 +21,7 @@ namespace Tensorflow.Models.ObjectDetection | |||||
| public Func<DatasetV1Adapter> create_train_input_fn(TrainConfig train_config, InputReader train_input_config, DetectionModel model_config) | public Func<DatasetV1Adapter> create_train_input_fn(TrainConfig train_config, InputReader train_input_config, DetectionModel model_config) | ||||
| { | { | ||||
| Func<DatasetV1Adapter> _train_input_fn = () => | Func<DatasetV1Adapter> _train_input_fn = () => | ||||
| { | |||||
| return train_input(train_config, train_input_config, model_config); | |||||
| }; | |||||
| train_input(train_config, train_input_config, model_config); | |||||
| return _train_input_fn; | return _train_input_fn; | ||||
| } | } | ||||
| @@ -0,0 +1,18 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Models.ObjectDetection | |||||
| { | |||||
| public class FasterRCNNInitArgs | |||||
| { | |||||
| public bool is_training { get; set; } | |||||
| public int num_classes { get; set; } | |||||
| public Action image_resizer_fn { get; set; } | |||||
| public Action feature_extractor { get; set; } | |||||
| public int number_of_stage { get; set; } | |||||
| public object first_stage_anchor_generator { get; set; } | |||||
| public object first_stage_target_assigner { get; set; } | |||||
| public int first_stage_atrous_rate { get; set; } | |||||
| } | |||||
| } | |||||
| @@ -2,10 +2,17 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow.Models.ObjectDetection.MetaArchitectures | |||||
| namespace Tensorflow.Models.ObjectDetection | |||||
| { | { | ||||
| public class FasterRCNNMetaArch | |||||
| public class FasterRCNNMetaArch : Core.DetectionModel | |||||
| { | { | ||||
| FasterRCNNInitArgs _args; | |||||
| public FasterRCNNMetaArch(FasterRCNNInitArgs args) | |||||
| { | |||||
| _args = args; | |||||
| } | |||||
| public (Tensor, Tensor) preprocess(Tensor tensor) | public (Tensor, Tensor) preprocess(Tensor tensor) | ||||
| { | { | ||||
| throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
| @@ -0,0 +1,10 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Models.ObjectDetection | |||||
| { | |||||
| public class ConvolutionalBoxPredictor | |||||
| { | |||||
| } | |||||
| } | |||||