| @@ -17,6 +17,7 @@ | |||
| using System; | |||
| using static Tensorflow.Binding; | |||
| using Tensorflow.Estimators; | |||
| using Tensorflow.Data; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -35,7 +36,7 @@ namespace Tensorflow | |||
| public void train_and_evaluate(Estimator estimator, TrainSpec train_spec, EvalSpec eval_spec) | |||
| => Training.train_and_evaluate(estimator: estimator, train_spec: train_spec, eval_spec: eval_spec); | |||
| public TrainSpec TrainSpec(Action input_fn, int max_steps) | |||
| public TrainSpec TrainSpec(Func<DatasetV1Adapter> input_fn, int max_steps) | |||
| => new TrainSpec(input_fn: input_fn, max_steps: max_steps); | |||
| /// <summary> | |||
| @@ -3,6 +3,8 @@ using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.IO; | |||
| using System.Text; | |||
| using Tensorflow.Data; | |||
| using Tensorflow.Train; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Estimators | |||
| @@ -30,7 +32,7 @@ namespace Tensorflow.Estimators | |||
| _model_fn = model_fn; | |||
| } | |||
| public Estimator train(Action input_fn, int max_steps = 1, Action[] hooks = null, | |||
| public Estimator train(Func<DatasetV1Adapter> input_fn, int max_steps = 1, Action[] hooks = null, | |||
| _NewCheckpointListenerForEvaluate[] saving_listeners = null) | |||
| { | |||
| if(max_steps > 0) | |||
| @@ -56,19 +58,41 @@ namespace Tensorflow.Estimators | |||
| return cp.AllModelCheckpointPaths.Count - 1; | |||
| } | |||
| private void _train_model(Action input_fn) | |||
| private void _train_model(Func<DatasetV1Adapter> input_fn) | |||
| { | |||
| _train_model_default(input_fn); | |||
| } | |||
| private void _train_model_default(Action input_fn) | |||
| private void _train_model_default(Func<DatasetV1Adapter> input_fn) | |||
| { | |||
| using (var g = tf.Graph().as_default()) | |||
| { | |||
| var global_step_tensor = _create_and_assert_global_step(g); | |||
| // Skip creating a read variable if _create_and_assert_global_step | |||
| // returns None (e.g. tf.contrib.estimator.SavedModelEstimator). | |||
| if (global_step_tensor != null) | |||
| TrainingUtil._get_or_create_global_step_read(g); | |||
| _get_features_and_labels_from_input_fn(input_fn, "train"); | |||
| } | |||
| } | |||
| private void _get_features_and_labels_from_input_fn(Func<DatasetV1Adapter> input_fn, string mode) | |||
| { | |||
| _call_input_fn(input_fn, mode); | |||
| } | |||
| /// <summary> | |||
| /// Calls the input function. | |||
| /// </summary> | |||
| /// <param name="input_fn"></param> | |||
| /// <param name="mode"></param> | |||
| private void _call_input_fn(Func<DatasetV1Adapter> input_fn, string mode) | |||
| { | |||
| input_fn(); | |||
| } | |||
| private Tensor _create_and_assert_global_step(Graph graph) | |||
| { | |||
| var step = _create_global_step(graph); | |||
| @@ -1,6 +1,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Data; | |||
| namespace Tensorflow.Estimators | |||
| { | |||
| @@ -9,10 +10,10 @@ namespace Tensorflow.Estimators | |||
| int _max_steps; | |||
| public int max_steps => _max_steps; | |||
| Action _input_fn; | |||
| public Action input_fn => _input_fn; | |||
| Func<DatasetV1Adapter> _input_fn; | |||
| public Func<DatasetV1Adapter> input_fn => _input_fn; | |||
| public TrainSpec(Action input_fn, int max_steps) | |||
| public TrainSpec(Func<DatasetV1Adapter> input_fn, int max_steps) | |||
| { | |||
| _max_steps = max_steps; | |||
| _input_fn = input_fn; | |||
| @@ -434,6 +434,9 @@ namespace Tensorflow | |||
| case List<RefVariable> list: | |||
| t = list.Select(x => (T)(object)x).ToList(); | |||
| break; | |||
| case List<Tensor> list: | |||
| t = list.Select(x => (T)(object)x).ToList(); | |||
| break; | |||
| default: | |||
| throw new NotImplementedException($"get_collection<{typeof(T).FullName}>"); | |||
| } | |||
| @@ -7,7 +7,7 @@ namespace Tensorflow.Train | |||
| { | |||
| public class TrainingUtil | |||
| { | |||
| public static RefVariable create_global_step(Graph graph) | |||
| public static RefVariable create_global_step(Graph graph = null) | |||
| { | |||
| graph = graph ?? ops.get_default_graph(); | |||
| if (get_global_step(graph) != null) | |||
| @@ -24,7 +24,7 @@ namespace Tensorflow.Train | |||
| return v; | |||
| } | |||
| public static RefVariable get_global_step(Graph graph) | |||
| public static RefVariable get_global_step(Graph graph = null) | |||
| { | |||
| graph = graph ?? ops.get_default_graph(); | |||
| RefVariable global_step_tensor = null; | |||
| @@ -47,5 +47,43 @@ namespace Tensorflow.Train | |||
| return global_step_tensor; | |||
| } | |||
| public static Tensor _get_or_create_global_step_read(Graph graph = null) | |||
| { | |||
| graph = graph ?? ops.get_default_graph(); | |||
| var global_step_read_tensor = _get_global_step_read(graph); | |||
| if (global_step_read_tensor != null) | |||
| return global_step_read_tensor; | |||
| var global_step_tensor = get_global_step(graph); | |||
| if (global_step_tensor == null) | |||
| return null; | |||
| var g = graph.as_default(); | |||
| g.name_scope(null); | |||
| g.name_scope(global_step_tensor.op.name + "/"); | |||
| // using initialized_value to ensure that global_step is initialized before | |||
| // this run. This is needed for example Estimator makes all model_fn build | |||
| // under global_step_read_tensor dependency. | |||
| var global_step_value = global_step_tensor.initialized_value(); | |||
| ops.add_to_collection(tf.GraphKeys.GLOBAL_STEP_READ_KEY, global_step_value + 0); | |||
| return _get_global_step_read(graph); | |||
| } | |||
| private static Tensor _get_global_step_read(Graph graph = null) | |||
| { | |||
| graph = graph ?? ops.get_default_graph(); | |||
| var global_step_read_tensors = graph.get_collection<Tensor>(tf.GraphKeys.GLOBAL_STEP_READ_KEY); | |||
| if (global_step_read_tensors.Count > 1) | |||
| throw new RuntimeError($"There are multiple items in collection {tf.GraphKeys.GLOBAL_STEP_READ_KEY}. " + | |||
| "There should be only one."); | |||
| if (global_step_read_tensors.Count == 1) | |||
| return global_step_read_tensors[0]; | |||
| return null; | |||
| } | |||
| } | |||
| } | |||
| @@ -122,6 +122,7 @@ namespace Tensorflow | |||
| public string TRAIN_OP => TRAIN_OP_; | |||
| public string GLOBAL_STEP => GLOBAL_STEP_; | |||
| public string GLOBAL_STEP_READ_KEY = "global_step_read_op_cache"; | |||
| public string[] _VARIABLE_COLLECTIONS => _VARIABLE_COLLECTIONS_; | |||
| /// <summary> | |||
| @@ -0,0 +1,25 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Data; | |||
| using Tensorflow.Models.ObjectDetection.Protos; | |||
| namespace Tensorflow.Models.ObjectDetection | |||
| { | |||
| public class DatasetBuilder | |||
| { | |||
| public static DatasetV1Adapter build(InputReader input_reader_config, | |||
| int batch_size = 0, | |||
| Action transform_input_data_fn = null) | |||
| { | |||
| Func<Dictionary<string, Tensor>, (Dictionary<string, Tensor>, Dictionary<string, Tensor>)> transform_and_pad_input_data_fn = (tensor_dict) => | |||
| { | |||
| return (null, null); | |||
| }; | |||
| var config = input_reader_config.TfRecordInputReader; | |||
| throw new NotImplementedException(""); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,40 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Models.ObjectDetection.Protos; | |||
| using static Tensorflow.Models.ObjectDetection.Protos.ImageResizer; | |||
| namespace Tensorflow.Models.ObjectDetection | |||
| { | |||
| public class ImageResizerBuilder | |||
| { | |||
| public ImageResizerBuilder() | |||
| { | |||
| } | |||
| public Action build(ImageResizer image_resizer_config) | |||
| { | |||
| var image_resizer_oneof = image_resizer_config.ImageResizerOneofCase; | |||
| if (image_resizer_oneof == ImageResizerOneofOneofCase.KeepAspectRatioResizer) | |||
| { | |||
| var keep_aspect_ratio_config = image_resizer_config.KeepAspectRatioResizer; | |||
| var method = _tf_resize_method(keep_aspect_ratio_config.ResizeMethod); | |||
| var per_channel_pad_value = new[] { 0, 0, 0 }; | |||
| //if (keep_aspect_ratio_config.PerChannelPadValue != null) | |||
| //per_channel_pad_value = new[] { keep_aspect_ratio_config.PerChannelPadValue }; | |||
| } | |||
| else | |||
| { | |||
| throw new NotImplementedException(""); | |||
| } | |||
| return null; | |||
| } | |||
| private ResizeType _tf_resize_method(ResizeType resize_method) | |||
| { | |||
| return resize_method; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,56 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Models.ObjectDetection.MetaArchitectures; | |||
| using Tensorflow.Models.ObjectDetection.Protos; | |||
| using static Tensorflow.Models.ObjectDetection.Protos.DetectionModel; | |||
| namespace Tensorflow.Models.ObjectDetection | |||
| { | |||
| public class ModelBuilder | |||
| { | |||
| ImageResizerBuilder image_resizer_builder; | |||
| public ModelBuilder() | |||
| { | |||
| image_resizer_builder = new ImageResizerBuilder(); | |||
| } | |||
| /// <summary> | |||
| /// Builds a DetectionModel based on the model config. | |||
| /// </summary> | |||
| /// <param name="model_config">A model.proto object containing the config for the desired DetectionModel.</param> | |||
| /// <param name="is_training">True if this model is being built for training purposes.</param> | |||
| /// <param name="add_summaries">Whether to add tensorflow summaries in the model graph.</param> | |||
| /// <returns>DetectionModel based on the config.</returns> | |||
| public FasterRCNNMetaArch build(DetectionModel model_config, bool is_training, bool add_summaries = true) | |||
| { | |||
| var meta_architecture = model_config.ModelCase; | |||
| if (meta_architecture == ModelOneofCase.Ssd) | |||
| throw new NotImplementedException(""); | |||
| else if (meta_architecture == ModelOneofCase.FasterRcnn) | |||
| return _build_faster_rcnn_model(model_config.FasterRcnn, is_training, add_summaries); | |||
| throw new ValueError($"Unknown meta architecture: {meta_architecture}"); | |||
| } | |||
| /// <summary> | |||
| /// Builds a Faster R-CNN or R-FCN detection model based on the model config. | |||
| /// </summary> | |||
| /// <param name="frcnn_config"></param> | |||
| /// <param name="is_training"></param> | |||
| /// <param name="add_summaries"></param> | |||
| /// <returns>FasterRCNNMetaArch based on the config.</returns> | |||
| private FasterRCNNMetaArch _build_faster_rcnn_model(FasterRcnn frcnn_config, bool is_training, bool add_summaries) | |||
| { | |||
| var num_classes = frcnn_config.NumClasses; | |||
| var image_resizer_fn = image_resizer_builder.build(frcnn_config.ImageResizer); | |||
| throw new NotImplementedException(""); | |||
| } | |||
| public Action preprocess() | |||
| { | |||
| throw new NotImplementedException(""); | |||
| } | |||
| } | |||
| } | |||
| @@ -1,6 +1,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Data; | |||
| using Tensorflow.Estimators; | |||
| namespace Tensorflow.Models.ObjectDetection | |||
| @@ -8,7 +9,7 @@ namespace Tensorflow.Models.ObjectDetection | |||
| public class TrainAndEvalDict | |||
| { | |||
| public Estimator estimator { get; set; } | |||
| public Action train_input_fn { get; set; } | |||
| public Func<DatasetV1Adapter> train_input_fn { get; set; } | |||
| public Action[] eval_input_fns { get; set; } | |||
| public string[] eval_input_names { get; set; } | |||
| public Action eval_on_train_input_fn { get; set; } | |||
| @@ -0,0 +1,50 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Data; | |||
| using Tensorflow.Estimators; | |||
| using Tensorflow.Models.ObjectDetection.MetaArchitectures; | |||
| using Tensorflow.Models.ObjectDetection.Protos; | |||
| namespace Tensorflow.Models.ObjectDetection | |||
| { | |||
| public class Inputs | |||
| { | |||
| ModelBuilder modelBuilder; | |||
| Dictionary<string, Func<DetectionModel, bool, bool, FasterRCNNMetaArch>> INPUT_BUILDER_UTIL_MAP; | |||
| public Inputs() | |||
| { | |||
| modelBuilder = new ModelBuilder(); | |||
| INPUT_BUILDER_UTIL_MAP = new Dictionary<string, Func<DetectionModel, bool, bool, FasterRCNNMetaArch>>(); | |||
| INPUT_BUILDER_UTIL_MAP["model_build"] = modelBuilder.build; | |||
| } | |||
| public Func<DatasetV1Adapter> create_train_input_fn(TrainConfig train_config, InputReader train_input_config, DetectionModel model_config) | |||
| { | |||
| Func<DatasetV1Adapter> _train_input_fn = () => | |||
| { | |||
| return train_input(train_config, train_input_config, model_config); | |||
| }; | |||
| return _train_input_fn; | |||
| } | |||
| /// <summary> | |||
| /// Returns `features` and `labels` tensor dictionaries for training. | |||
| /// </summary> | |||
| /// <param name="train_config"></param> | |||
| /// <param name="train_input_config"></param> | |||
| /// <param name="model_config"></param> | |||
| /// <returns></returns> | |||
| public DatasetV1Adapter train_input(TrainConfig train_config, InputReader train_input_config, DetectionModel model_config) | |||
| { | |||
| var arch = INPUT_BUILDER_UTIL_MAP["model_build"](model_config, true, true); | |||
| Func<Tensor, (Tensor, Tensor)> model_preprocess_fn = arch.preprocess; | |||
| var dataset = DatasetBuilder.build(train_input_config); | |||
| return dataset; | |||
| } | |||
| } | |||
| } | |||
| @@ -6,5 +6,9 @@ namespace Tensorflow.Models.ObjectDetection.MetaArchitectures | |||
| { | |||
| public class FasterRCNNMetaArch | |||
| { | |||
| public (Tensor, Tensor) preprocess(Tensor tensor) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| } | |||
| } | |||
| } | |||
| @@ -6,11 +6,14 @@ using Tensorflow.Estimators; | |||
| using System.Linq; | |||
| using Tensorflow.Contrib.Train; | |||
| using Tensorflow.Models.ObjectDetection.Utils; | |||
| using Tensorflow.Data; | |||
| namespace Tensorflow.Models.ObjectDetection | |||
| { | |||
| public class ModelLib | |||
| { | |||
| Inputs inputs = new Inputs(); | |||
| public TrainAndEvalDict create_estimator_and_inputs(RunConfig run_config, | |||
| HParams hparams = null, | |||
| string pipeline_config_path = null, | |||
| @@ -21,7 +24,7 @@ namespace Tensorflow.Models.ObjectDetection | |||
| var config = ConfigUtil.get_configs_from_pipeline_file(pipeline_config_path); | |||
| // Create the input functions for TRAIN/EVAL/PREDICT. | |||
| Action train_input_fn = () => { }; | |||
| Func<DatasetV1Adapter> train_input_fn = inputs.create_train_input_fn(config.TrainConfig, config.TrainInputReader, config.Model); | |||
| var eval_input_configs = config.EvalInputReader; | |||
| @@ -44,7 +47,7 @@ namespace Tensorflow.Models.ObjectDetection | |||
| }; | |||
| } | |||
| public (TrainSpec, EvalSpec[]) create_train_and_eval_specs(Action train_input_fn, Action[] eval_input_fns, Action eval_on_train_input_fn, | |||
| public (TrainSpec, EvalSpec[]) create_train_and_eval_specs(Func<DatasetV1Adapter> train_input_fn, Action[] eval_input_fns, Action eval_on_train_input_fn, | |||
| Action predict_input_fn, int train_steps, bool eval_on_train_data = false, | |||
| string final_exporter_name = "Servo", string[] eval_spec_names = null) | |||
| { | |||