From 41432600c8263cf972b13930c31479ffb412fa65 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 14 Sep 2019 15:15:42 -0500 Subject: [PATCH] input_fn #387 --- src/TensorFlowNET.Core/APIs/tf.estimator.cs | 3 +- .../Estimators/Estimator.cs | 30 +++++++++- .../Estimators/TrainSpec.cs | 7 ++- src/TensorFlowNET.Core/Graphs/Graph.cs | 3 + src/TensorFlowNET.Core/Train/TrainingUtil.cs | 42 +++++++++++++- src/TensorFlowNET.Core/ops.GraphKeys.cs | 1 + .../Builders/DatasetBuilder.cs | 25 +++++++++ .../Builders/ImageResizerBuilder.cs | 40 +++++++++++++ .../ObjectDetection/Builders/ModelBuilder.cs | 56 +++++++++++++++++++ .../Entities/TrainAndEvalDict.cs | 3 +- .../ObjectDetection/Inputs.cs | 50 +++++++++++++++++ .../MetaArchitectures/FasterRCNNMetaArch.cs | 4 ++ .../ObjectDetection/ModelLib.cs | 7 ++- 13 files changed, 259 insertions(+), 12 deletions(-) create mode 100644 src/TensorFlowNET.Models/ObjectDetection/Builders/DatasetBuilder.cs create mode 100644 src/TensorFlowNET.Models/ObjectDetection/Builders/ImageResizerBuilder.cs create mode 100644 src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs create mode 100644 src/TensorFlowNET.Models/ObjectDetection/Inputs.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.estimator.cs b/src/TensorFlowNET.Core/APIs/tf.estimator.cs index 9789e11f..3cabfdf4 100644 --- a/src/TensorFlowNET.Core/APIs/tf.estimator.cs +++ b/src/TensorFlowNET.Core/APIs/tf.estimator.cs @@ -17,6 +17,7 @@ using System; using static Tensorflow.Binding; using Tensorflow.Estimators; +using Tensorflow.Data; namespace Tensorflow { @@ -35,7 +36,7 @@ namespace Tensorflow public void train_and_evaluate(Estimator estimator, TrainSpec train_spec, EvalSpec eval_spec) => Training.train_and_evaluate(estimator: estimator, train_spec: train_spec, eval_spec: eval_spec); - public TrainSpec TrainSpec(Action input_fn, int max_steps) + public TrainSpec TrainSpec(Func input_fn, int max_steps) => new TrainSpec(input_fn: input_fn, max_steps: max_steps); /// diff --git a/src/TensorFlowNET.Core/Estimators/Estimator.cs b/src/TensorFlowNET.Core/Estimators/Estimator.cs index 2570206f..0df59dcc 100644 --- a/src/TensorFlowNET.Core/Estimators/Estimator.cs +++ b/src/TensorFlowNET.Core/Estimators/Estimator.cs @@ -3,6 +3,8 @@ using System.Collections.Generic; using System.Diagnostics; using System.IO; using System.Text; +using Tensorflow.Data; +using Tensorflow.Train; using static Tensorflow.Binding; namespace Tensorflow.Estimators @@ -30,7 +32,7 @@ namespace Tensorflow.Estimators _model_fn = model_fn; } - public Estimator train(Action input_fn, int max_steps = 1, Action[] hooks = null, + public Estimator train(Func input_fn, int max_steps = 1, Action[] hooks = null, _NewCheckpointListenerForEvaluate[] saving_listeners = null) { if(max_steps > 0) @@ -56,19 +58,41 @@ namespace Tensorflow.Estimators return cp.AllModelCheckpointPaths.Count - 1; } - private void _train_model(Action input_fn) + private void _train_model(Func input_fn) { _train_model_default(input_fn); } - private void _train_model_default(Action input_fn) + private void _train_model_default(Func input_fn) { using (var g = tf.Graph().as_default()) { var global_step_tensor = _create_and_assert_global_step(g); + + // Skip creating a read variable if _create_and_assert_global_step + // returns None (e.g. tf.contrib.estimator.SavedModelEstimator). + if (global_step_tensor != null) + TrainingUtil._get_or_create_global_step_read(g); + + _get_features_and_labels_from_input_fn(input_fn, "train"); } } + private void _get_features_and_labels_from_input_fn(Func input_fn, string mode) + { + _call_input_fn(input_fn, mode); + } + + /// + /// Calls the input function. + /// + /// + /// + private void _call_input_fn(Func input_fn, string mode) + { + input_fn(); + } + private Tensor _create_and_assert_global_step(Graph graph) { var step = _create_global_step(graph); diff --git a/src/TensorFlowNET.Core/Estimators/TrainSpec.cs b/src/TensorFlowNET.Core/Estimators/TrainSpec.cs index 64b3a829..c2993684 100644 --- a/src/TensorFlowNET.Core/Estimators/TrainSpec.cs +++ b/src/TensorFlowNET.Core/Estimators/TrainSpec.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Data; namespace Tensorflow.Estimators { @@ -9,10 +10,10 @@ namespace Tensorflow.Estimators int _max_steps; public int max_steps => _max_steps; - Action _input_fn; - public Action input_fn => _input_fn; + Func _input_fn; + public Func input_fn => _input_fn; - public TrainSpec(Action input_fn, int max_steps) + public TrainSpec(Func input_fn, int max_steps) { _max_steps = max_steps; _input_fn = input_fn; diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 013365d2..4063453c 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -434,6 +434,9 @@ namespace Tensorflow case List list: t = list.Select(x => (T)(object)x).ToList(); break; + case List list: + t = list.Select(x => (T)(object)x).ToList(); + break; default: throw new NotImplementedException($"get_collection<{typeof(T).FullName}>"); } diff --git a/src/TensorFlowNET.Core/Train/TrainingUtil.cs b/src/TensorFlowNET.Core/Train/TrainingUtil.cs index 1b6f7f81..63227733 100644 --- a/src/TensorFlowNET.Core/Train/TrainingUtil.cs +++ b/src/TensorFlowNET.Core/Train/TrainingUtil.cs @@ -7,7 +7,7 @@ namespace Tensorflow.Train { public class TrainingUtil { - public static RefVariable create_global_step(Graph graph) + public static RefVariable create_global_step(Graph graph = null) { graph = graph ?? ops.get_default_graph(); if (get_global_step(graph) != null) @@ -24,7 +24,7 @@ namespace Tensorflow.Train return v; } - public static RefVariable get_global_step(Graph graph) + public static RefVariable get_global_step(Graph graph = null) { graph = graph ?? ops.get_default_graph(); RefVariable global_step_tensor = null; @@ -47,5 +47,43 @@ namespace Tensorflow.Train return global_step_tensor; } + + public static Tensor _get_or_create_global_step_read(Graph graph = null) + { + graph = graph ?? ops.get_default_graph(); + var global_step_read_tensor = _get_global_step_read(graph); + if (global_step_read_tensor != null) + return global_step_read_tensor; + + var global_step_tensor = get_global_step(graph); + + if (global_step_tensor == null) + return null; + + var g = graph.as_default(); + g.name_scope(null); + g.name_scope(global_step_tensor.op.name + "/"); + // using initialized_value to ensure that global_step is initialized before + // this run. This is needed for example Estimator makes all model_fn build + // under global_step_read_tensor dependency. + var global_step_value = global_step_tensor.initialized_value(); + ops.add_to_collection(tf.GraphKeys.GLOBAL_STEP_READ_KEY, global_step_value + 0); + + return _get_global_step_read(graph); + } + + private static Tensor _get_global_step_read(Graph graph = null) + { + graph = graph ?? ops.get_default_graph(); + var global_step_read_tensors = graph.get_collection(tf.GraphKeys.GLOBAL_STEP_READ_KEY); + if (global_step_read_tensors.Count > 1) + throw new RuntimeError($"There are multiple items in collection {tf.GraphKeys.GLOBAL_STEP_READ_KEY}. " + + "There should be only one."); + + if (global_step_read_tensors.Count == 1) + return global_step_read_tensors[0]; + + return null; + } } } diff --git a/src/TensorFlowNET.Core/ops.GraphKeys.cs b/src/TensorFlowNET.Core/ops.GraphKeys.cs index 453b9d43..4e7235bc 100644 --- a/src/TensorFlowNET.Core/ops.GraphKeys.cs +++ b/src/TensorFlowNET.Core/ops.GraphKeys.cs @@ -122,6 +122,7 @@ namespace Tensorflow public string TRAIN_OP => TRAIN_OP_; public string GLOBAL_STEP => GLOBAL_STEP_; + public string GLOBAL_STEP_READ_KEY = "global_step_read_op_cache"; public string[] _VARIABLE_COLLECTIONS => _VARIABLE_COLLECTIONS_; /// diff --git a/src/TensorFlowNET.Models/ObjectDetection/Builders/DatasetBuilder.cs b/src/TensorFlowNET.Models/ObjectDetection/Builders/DatasetBuilder.cs new file mode 100644 index 00000000..ded41f0e --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Builders/DatasetBuilder.cs @@ -0,0 +1,25 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Data; +using Tensorflow.Models.ObjectDetection.Protos; + +namespace Tensorflow.Models.ObjectDetection +{ + public class DatasetBuilder + { + public static DatasetV1Adapter build(InputReader input_reader_config, + int batch_size = 0, + Action transform_input_data_fn = null) + { + Func, (Dictionary, Dictionary)> transform_and_pad_input_data_fn = (tensor_dict) => + { + return (null, null); + }; + + var config = input_reader_config.TfRecordInputReader; + + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/Builders/ImageResizerBuilder.cs b/src/TensorFlowNET.Models/ObjectDetection/Builders/ImageResizerBuilder.cs new file mode 100644 index 00000000..eab1e8a5 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Builders/ImageResizerBuilder.cs @@ -0,0 +1,40 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Models.ObjectDetection.Protos; +using static Tensorflow.Models.ObjectDetection.Protos.ImageResizer; + +namespace Tensorflow.Models.ObjectDetection +{ + public class ImageResizerBuilder + { + public ImageResizerBuilder() + { + + } + + public Action build(ImageResizer image_resizer_config) + { + var image_resizer_oneof = image_resizer_config.ImageResizerOneofCase; + if (image_resizer_oneof == ImageResizerOneofOneofCase.KeepAspectRatioResizer) + { + var keep_aspect_ratio_config = image_resizer_config.KeepAspectRatioResizer; + var method = _tf_resize_method(keep_aspect_ratio_config.ResizeMethod); + var per_channel_pad_value = new[] { 0, 0, 0 }; + //if (keep_aspect_ratio_config.PerChannelPadValue != null) + //per_channel_pad_value = new[] { keep_aspect_ratio_config.PerChannelPadValue }; + } + else + { + throw new NotImplementedException(""); + } + + return null; + } + + private ResizeType _tf_resize_method(ResizeType resize_method) + { + return resize_method; + } + } +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs b/src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs new file mode 100644 index 00000000..1493352a --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs @@ -0,0 +1,56 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Models.ObjectDetection.MetaArchitectures; +using Tensorflow.Models.ObjectDetection.Protos; +using static Tensorflow.Models.ObjectDetection.Protos.DetectionModel; + +namespace Tensorflow.Models.ObjectDetection +{ + public class ModelBuilder + { + ImageResizerBuilder image_resizer_builder; + + public ModelBuilder() + { + image_resizer_builder = new ImageResizerBuilder(); + } + + /// + /// Builds a DetectionModel based on the model config. + /// + /// A model.proto object containing the config for the desired DetectionModel. + /// True if this model is being built for training purposes. + /// Whether to add tensorflow summaries in the model graph. + /// DetectionModel based on the config. + public FasterRCNNMetaArch build(DetectionModel model_config, bool is_training, bool add_summaries = true) + { + var meta_architecture = model_config.ModelCase; + if (meta_architecture == ModelOneofCase.Ssd) + throw new NotImplementedException(""); + else if (meta_architecture == ModelOneofCase.FasterRcnn) + return _build_faster_rcnn_model(model_config.FasterRcnn, is_training, add_summaries); + + throw new ValueError($"Unknown meta architecture: {meta_architecture}"); + } + + /// + /// Builds a Faster R-CNN or R-FCN detection model based on the model config. + /// + /// + /// + /// + /// FasterRCNNMetaArch based on the config. + private FasterRCNNMetaArch _build_faster_rcnn_model(FasterRcnn frcnn_config, bool is_training, bool add_summaries) + { + var num_classes = frcnn_config.NumClasses; + var image_resizer_fn = image_resizer_builder.build(frcnn_config.ImageResizer); + throw new NotImplementedException(""); + } + + public Action preprocess() + { + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/Entities/TrainAndEvalDict.cs b/src/TensorFlowNET.Models/ObjectDetection/Entities/TrainAndEvalDict.cs index 2f519c0c..aa6bb502 100644 --- a/src/TensorFlowNET.Models/ObjectDetection/Entities/TrainAndEvalDict.cs +++ b/src/TensorFlowNET.Models/ObjectDetection/Entities/TrainAndEvalDict.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Data; using Tensorflow.Estimators; namespace Tensorflow.Models.ObjectDetection @@ -8,7 +9,7 @@ namespace Tensorflow.Models.ObjectDetection public class TrainAndEvalDict { public Estimator estimator { get; set; } - public Action train_input_fn { get; set; } + public Func train_input_fn { get; set; } public Action[] eval_input_fns { get; set; } public string[] eval_input_names { get; set; } public Action eval_on_train_input_fn { get; set; } diff --git a/src/TensorFlowNET.Models/ObjectDetection/Inputs.cs b/src/TensorFlowNET.Models/ObjectDetection/Inputs.cs new file mode 100644 index 00000000..eb300752 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Inputs.cs @@ -0,0 +1,50 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Data; +using Tensorflow.Estimators; +using Tensorflow.Models.ObjectDetection.MetaArchitectures; +using Tensorflow.Models.ObjectDetection.Protos; + +namespace Tensorflow.Models.ObjectDetection +{ + public class Inputs + { + ModelBuilder modelBuilder; + Dictionary> INPUT_BUILDER_UTIL_MAP; + + public Inputs() + { + modelBuilder = new ModelBuilder(); + INPUT_BUILDER_UTIL_MAP = new Dictionary>(); + INPUT_BUILDER_UTIL_MAP["model_build"] = modelBuilder.build; + } + + public Func create_train_input_fn(TrainConfig train_config, InputReader train_input_config, DetectionModel model_config) + { + Func _train_input_fn = () => + { + return train_input(train_config, train_input_config, model_config); + }; + + return _train_input_fn; + } + + /// + /// Returns `features` and `labels` tensor dictionaries for training. + /// + /// + /// + /// + /// + public DatasetV1Adapter train_input(TrainConfig train_config, InputReader train_input_config, DetectionModel model_config) + { + var arch = INPUT_BUILDER_UTIL_MAP["model_build"](model_config, true, true); + Func model_preprocess_fn = arch.preprocess; + + var dataset = DatasetBuilder.build(train_input_config); + + return dataset; + } + } +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNMetaArch.cs b/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNMetaArch.cs index 2b501a46..5a4c4e5a 100644 --- a/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNMetaArch.cs +++ b/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNMetaArch.cs @@ -6,5 +6,9 @@ namespace Tensorflow.Models.ObjectDetection.MetaArchitectures { public class FasterRCNNMetaArch { + public (Tensor, Tensor) preprocess(Tensor tensor) + { + throw new NotImplementedException(""); + } } } diff --git a/src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs b/src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs index 6759e03e..5611356b 100644 --- a/src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs +++ b/src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs @@ -6,11 +6,14 @@ using Tensorflow.Estimators; using System.Linq; using Tensorflow.Contrib.Train; using Tensorflow.Models.ObjectDetection.Utils; +using Tensorflow.Data; namespace Tensorflow.Models.ObjectDetection { public class ModelLib { + Inputs inputs = new Inputs(); + public TrainAndEvalDict create_estimator_and_inputs(RunConfig run_config, HParams hparams = null, string pipeline_config_path = null, @@ -21,7 +24,7 @@ namespace Tensorflow.Models.ObjectDetection var config = ConfigUtil.get_configs_from_pipeline_file(pipeline_config_path); // Create the input functions for TRAIN/EVAL/PREDICT. - Action train_input_fn = () => { }; + Func train_input_fn = inputs.create_train_input_fn(config.TrainConfig, config.TrainInputReader, config.Model); var eval_input_configs = config.EvalInputReader; @@ -44,7 +47,7 @@ namespace Tensorflow.Models.ObjectDetection }; } - public (TrainSpec, EvalSpec[]) create_train_and_eval_specs(Action train_input_fn, Action[] eval_input_fns, Action eval_on_train_input_fn, + public (TrainSpec, EvalSpec[]) create_train_and_eval_specs(Func train_input_fn, Action[] eval_input_fns, Action eval_on_train_input_fn, Action predict_input_fn, int train_steps, bool eval_on_train_data = false, string final_exporter_name = "Servo", string[] eval_spec_names = null) {