| @@ -17,6 +17,7 @@ | |||||
| using System; | using System; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using Tensorflow.Estimators; | using Tensorflow.Estimators; | ||||
| using Tensorflow.Data; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -35,7 +36,7 @@ namespace Tensorflow | |||||
| public void train_and_evaluate(Estimator estimator, TrainSpec train_spec, EvalSpec eval_spec) | public void train_and_evaluate(Estimator estimator, TrainSpec train_spec, EvalSpec eval_spec) | ||||
| => Training.train_and_evaluate(estimator: estimator, train_spec: train_spec, eval_spec: eval_spec); | => Training.train_and_evaluate(estimator: estimator, train_spec: train_spec, eval_spec: eval_spec); | ||||
| public TrainSpec TrainSpec(Action input_fn, int max_steps) | |||||
| public TrainSpec TrainSpec(Func<DatasetV1Adapter> input_fn, int max_steps) | |||||
| => new TrainSpec(input_fn: input_fn, max_steps: max_steps); | => new TrainSpec(input_fn: input_fn, max_steps: max_steps); | ||||
| /// <summary> | /// <summary> | ||||
| @@ -3,6 +3,8 @@ using System.Collections.Generic; | |||||
| using System.Diagnostics; | using System.Diagnostics; | ||||
| using System.IO; | using System.IO; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Data; | |||||
| using Tensorflow.Train; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Estimators | namespace Tensorflow.Estimators | ||||
| @@ -30,7 +32,7 @@ namespace Tensorflow.Estimators | |||||
| _model_fn = model_fn; | _model_fn = model_fn; | ||||
| } | } | ||||
| public Estimator train(Action input_fn, int max_steps = 1, Action[] hooks = null, | |||||
| public Estimator train(Func<DatasetV1Adapter> input_fn, int max_steps = 1, Action[] hooks = null, | |||||
| _NewCheckpointListenerForEvaluate[] saving_listeners = null) | _NewCheckpointListenerForEvaluate[] saving_listeners = null) | ||||
| { | { | ||||
| if(max_steps > 0) | if(max_steps > 0) | ||||
| @@ -56,19 +58,41 @@ namespace Tensorflow.Estimators | |||||
| return cp.AllModelCheckpointPaths.Count - 1; | return cp.AllModelCheckpointPaths.Count - 1; | ||||
| } | } | ||||
| private void _train_model(Action input_fn) | |||||
| private void _train_model(Func<DatasetV1Adapter> input_fn) | |||||
| { | { | ||||
| _train_model_default(input_fn); | _train_model_default(input_fn); | ||||
| } | } | ||||
| private void _train_model_default(Action input_fn) | |||||
| private void _train_model_default(Func<DatasetV1Adapter> input_fn) | |||||
| { | { | ||||
| using (var g = tf.Graph().as_default()) | using (var g = tf.Graph().as_default()) | ||||
| { | { | ||||
| var global_step_tensor = _create_and_assert_global_step(g); | var global_step_tensor = _create_and_assert_global_step(g); | ||||
| // Skip creating a read variable if _create_and_assert_global_step | |||||
| // returns None (e.g. tf.contrib.estimator.SavedModelEstimator). | |||||
| if (global_step_tensor != null) | |||||
| TrainingUtil._get_or_create_global_step_read(g); | |||||
| _get_features_and_labels_from_input_fn(input_fn, "train"); | |||||
| } | } | ||||
| } | } | ||||
| private void _get_features_and_labels_from_input_fn(Func<DatasetV1Adapter> input_fn, string mode) | |||||
| { | |||||
| _call_input_fn(input_fn, mode); | |||||
| } | |||||
| /// <summary> | |||||
| /// Calls the input function. | |||||
| /// </summary> | |||||
| /// <param name="input_fn"></param> | |||||
| /// <param name="mode"></param> | |||||
| private void _call_input_fn(Func<DatasetV1Adapter> input_fn, string mode) | |||||
| { | |||||
| input_fn(); | |||||
| } | |||||
| private Tensor _create_and_assert_global_step(Graph graph) | private Tensor _create_and_assert_global_step(Graph graph) | ||||
| { | { | ||||
| var step = _create_global_step(graph); | var step = _create_global_step(graph); | ||||
| @@ -1,6 +1,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Data; | |||||
| namespace Tensorflow.Estimators | namespace Tensorflow.Estimators | ||||
| { | { | ||||
| @@ -9,10 +10,10 @@ namespace Tensorflow.Estimators | |||||
| int _max_steps; | int _max_steps; | ||||
| public int max_steps => _max_steps; | public int max_steps => _max_steps; | ||||
| Action _input_fn; | |||||
| public Action input_fn => _input_fn; | |||||
| Func<DatasetV1Adapter> _input_fn; | |||||
| public Func<DatasetV1Adapter> input_fn => _input_fn; | |||||
| public TrainSpec(Action input_fn, int max_steps) | |||||
| public TrainSpec(Func<DatasetV1Adapter> input_fn, int max_steps) | |||||
| { | { | ||||
| _max_steps = max_steps; | _max_steps = max_steps; | ||||
| _input_fn = input_fn; | _input_fn = input_fn; | ||||
| @@ -434,6 +434,9 @@ namespace Tensorflow | |||||
| case List<RefVariable> list: | case List<RefVariable> list: | ||||
| t = list.Select(x => (T)(object)x).ToList(); | t = list.Select(x => (T)(object)x).ToList(); | ||||
| break; | break; | ||||
| case List<Tensor> list: | |||||
| t = list.Select(x => (T)(object)x).ToList(); | |||||
| break; | |||||
| default: | default: | ||||
| throw new NotImplementedException($"get_collection<{typeof(T).FullName}>"); | throw new NotImplementedException($"get_collection<{typeof(T).FullName}>"); | ||||
| } | } | ||||
| @@ -7,7 +7,7 @@ namespace Tensorflow.Train | |||||
| { | { | ||||
| public class TrainingUtil | public class TrainingUtil | ||||
| { | { | ||||
| public static RefVariable create_global_step(Graph graph) | |||||
| public static RefVariable create_global_step(Graph graph = null) | |||||
| { | { | ||||
| graph = graph ?? ops.get_default_graph(); | graph = graph ?? ops.get_default_graph(); | ||||
| if (get_global_step(graph) != null) | if (get_global_step(graph) != null) | ||||
| @@ -24,7 +24,7 @@ namespace Tensorflow.Train | |||||
| return v; | return v; | ||||
| } | } | ||||
| public static RefVariable get_global_step(Graph graph) | |||||
| public static RefVariable get_global_step(Graph graph = null) | |||||
| { | { | ||||
| graph = graph ?? ops.get_default_graph(); | graph = graph ?? ops.get_default_graph(); | ||||
| RefVariable global_step_tensor = null; | RefVariable global_step_tensor = null; | ||||
| @@ -47,5 +47,43 @@ namespace Tensorflow.Train | |||||
| return global_step_tensor; | return global_step_tensor; | ||||
| } | } | ||||
| public static Tensor _get_or_create_global_step_read(Graph graph = null) | |||||
| { | |||||
| graph = graph ?? ops.get_default_graph(); | |||||
| var global_step_read_tensor = _get_global_step_read(graph); | |||||
| if (global_step_read_tensor != null) | |||||
| return global_step_read_tensor; | |||||
| var global_step_tensor = get_global_step(graph); | |||||
| if (global_step_tensor == null) | |||||
| return null; | |||||
| var g = graph.as_default(); | |||||
| g.name_scope(null); | |||||
| g.name_scope(global_step_tensor.op.name + "/"); | |||||
| // using initialized_value to ensure that global_step is initialized before | |||||
| // this run. This is needed for example Estimator makes all model_fn build | |||||
| // under global_step_read_tensor dependency. | |||||
| var global_step_value = global_step_tensor.initialized_value(); | |||||
| ops.add_to_collection(tf.GraphKeys.GLOBAL_STEP_READ_KEY, global_step_value + 0); | |||||
| return _get_global_step_read(graph); | |||||
| } | |||||
| private static Tensor _get_global_step_read(Graph graph = null) | |||||
| { | |||||
| graph = graph ?? ops.get_default_graph(); | |||||
| var global_step_read_tensors = graph.get_collection<Tensor>(tf.GraphKeys.GLOBAL_STEP_READ_KEY); | |||||
| if (global_step_read_tensors.Count > 1) | |||||
| throw new RuntimeError($"There are multiple items in collection {tf.GraphKeys.GLOBAL_STEP_READ_KEY}. " + | |||||
| "There should be only one."); | |||||
| if (global_step_read_tensors.Count == 1) | |||||
| return global_step_read_tensors[0]; | |||||
| return null; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -122,6 +122,7 @@ namespace Tensorflow | |||||
| public string TRAIN_OP => TRAIN_OP_; | public string TRAIN_OP => TRAIN_OP_; | ||||
| public string GLOBAL_STEP => GLOBAL_STEP_; | public string GLOBAL_STEP => GLOBAL_STEP_; | ||||
| public string GLOBAL_STEP_READ_KEY = "global_step_read_op_cache"; | |||||
| public string[] _VARIABLE_COLLECTIONS => _VARIABLE_COLLECTIONS_; | public string[] _VARIABLE_COLLECTIONS => _VARIABLE_COLLECTIONS_; | ||||
| /// <summary> | /// <summary> | ||||
| @@ -0,0 +1,25 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Data; | |||||
| using Tensorflow.Models.ObjectDetection.Protos; | |||||
| namespace Tensorflow.Models.ObjectDetection | |||||
| { | |||||
| public class DatasetBuilder | |||||
| { | |||||
| public static DatasetV1Adapter build(InputReader input_reader_config, | |||||
| int batch_size = 0, | |||||
| Action transform_input_data_fn = null) | |||||
| { | |||||
| Func<Dictionary<string, Tensor>, (Dictionary<string, Tensor>, Dictionary<string, Tensor>)> transform_and_pad_input_data_fn = (tensor_dict) => | |||||
| { | |||||
| return (null, null); | |||||
| }; | |||||
| var config = input_reader_config.TfRecordInputReader; | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,40 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Models.ObjectDetection.Protos; | |||||
| using static Tensorflow.Models.ObjectDetection.Protos.ImageResizer; | |||||
| namespace Tensorflow.Models.ObjectDetection | |||||
| { | |||||
| public class ImageResizerBuilder | |||||
| { | |||||
| public ImageResizerBuilder() | |||||
| { | |||||
| } | |||||
| public Action build(ImageResizer image_resizer_config) | |||||
| { | |||||
| var image_resizer_oneof = image_resizer_config.ImageResizerOneofCase; | |||||
| if (image_resizer_oneof == ImageResizerOneofOneofCase.KeepAspectRatioResizer) | |||||
| { | |||||
| var keep_aspect_ratio_config = image_resizer_config.KeepAspectRatioResizer; | |||||
| var method = _tf_resize_method(keep_aspect_ratio_config.ResizeMethod); | |||||
| var per_channel_pad_value = new[] { 0, 0, 0 }; | |||||
| //if (keep_aspect_ratio_config.PerChannelPadValue != null) | |||||
| //per_channel_pad_value = new[] { keep_aspect_ratio_config.PerChannelPadValue }; | |||||
| } | |||||
| else | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| return null; | |||||
| } | |||||
| private ResizeType _tf_resize_method(ResizeType resize_method) | |||||
| { | |||||
| return resize_method; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,56 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Models.ObjectDetection.MetaArchitectures; | |||||
| using Tensorflow.Models.ObjectDetection.Protos; | |||||
| using static Tensorflow.Models.ObjectDetection.Protos.DetectionModel; | |||||
| namespace Tensorflow.Models.ObjectDetection | |||||
| { | |||||
| public class ModelBuilder | |||||
| { | |||||
| ImageResizerBuilder image_resizer_builder; | |||||
| public ModelBuilder() | |||||
| { | |||||
| image_resizer_builder = new ImageResizerBuilder(); | |||||
| } | |||||
| /// <summary> | |||||
| /// Builds a DetectionModel based on the model config. | |||||
| /// </summary> | |||||
| /// <param name="model_config">A model.proto object containing the config for the desired DetectionModel.</param> | |||||
| /// <param name="is_training">True if this model is being built for training purposes.</param> | |||||
| /// <param name="add_summaries">Whether to add tensorflow summaries in the model graph.</param> | |||||
| /// <returns>DetectionModel based on the config.</returns> | |||||
| public FasterRCNNMetaArch build(DetectionModel model_config, bool is_training, bool add_summaries = true) | |||||
| { | |||||
| var meta_architecture = model_config.ModelCase; | |||||
| if (meta_architecture == ModelOneofCase.Ssd) | |||||
| throw new NotImplementedException(""); | |||||
| else if (meta_architecture == ModelOneofCase.FasterRcnn) | |||||
| return _build_faster_rcnn_model(model_config.FasterRcnn, is_training, add_summaries); | |||||
| throw new ValueError($"Unknown meta architecture: {meta_architecture}"); | |||||
| } | |||||
| /// <summary> | |||||
| /// Builds a Faster R-CNN or R-FCN detection model based on the model config. | |||||
| /// </summary> | |||||
| /// <param name="frcnn_config"></param> | |||||
| /// <param name="is_training"></param> | |||||
| /// <param name="add_summaries"></param> | |||||
| /// <returns>FasterRCNNMetaArch based on the config.</returns> | |||||
| private FasterRCNNMetaArch _build_faster_rcnn_model(FasterRcnn frcnn_config, bool is_training, bool add_summaries) | |||||
| { | |||||
| var num_classes = frcnn_config.NumClasses; | |||||
| var image_resizer_fn = image_resizer_builder.build(frcnn_config.ImageResizer); | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| public Action preprocess() | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,6 +1,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Data; | |||||
| using Tensorflow.Estimators; | using Tensorflow.Estimators; | ||||
| namespace Tensorflow.Models.ObjectDetection | namespace Tensorflow.Models.ObjectDetection | ||||
| @@ -8,7 +9,7 @@ namespace Tensorflow.Models.ObjectDetection | |||||
| public class TrainAndEvalDict | public class TrainAndEvalDict | ||||
| { | { | ||||
| public Estimator estimator { get; set; } | public Estimator estimator { get; set; } | ||||
| public Action train_input_fn { get; set; } | |||||
| public Func<DatasetV1Adapter> train_input_fn { get; set; } | |||||
| public Action[] eval_input_fns { get; set; } | public Action[] eval_input_fns { get; set; } | ||||
| public string[] eval_input_names { get; set; } | public string[] eval_input_names { get; set; } | ||||
| public Action eval_on_train_input_fn { get; set; } | public Action eval_on_train_input_fn { get; set; } | ||||
| @@ -0,0 +1,50 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Data; | |||||
| using Tensorflow.Estimators; | |||||
| using Tensorflow.Models.ObjectDetection.MetaArchitectures; | |||||
| using Tensorflow.Models.ObjectDetection.Protos; | |||||
| namespace Tensorflow.Models.ObjectDetection | |||||
| { | |||||
| public class Inputs | |||||
| { | |||||
| ModelBuilder modelBuilder; | |||||
| Dictionary<string, Func<DetectionModel, bool, bool, FasterRCNNMetaArch>> INPUT_BUILDER_UTIL_MAP; | |||||
| public Inputs() | |||||
| { | |||||
| modelBuilder = new ModelBuilder(); | |||||
| INPUT_BUILDER_UTIL_MAP = new Dictionary<string, Func<DetectionModel, bool, bool, FasterRCNNMetaArch>>(); | |||||
| INPUT_BUILDER_UTIL_MAP["model_build"] = modelBuilder.build; | |||||
| } | |||||
| public Func<DatasetV1Adapter> create_train_input_fn(TrainConfig train_config, InputReader train_input_config, DetectionModel model_config) | |||||
| { | |||||
| Func<DatasetV1Adapter> _train_input_fn = () => | |||||
| { | |||||
| return train_input(train_config, train_input_config, model_config); | |||||
| }; | |||||
| return _train_input_fn; | |||||
| } | |||||
| /// <summary> | |||||
| /// Returns `features` and `labels` tensor dictionaries for training. | |||||
| /// </summary> | |||||
| /// <param name="train_config"></param> | |||||
| /// <param name="train_input_config"></param> | |||||
| /// <param name="model_config"></param> | |||||
| /// <returns></returns> | |||||
| public DatasetV1Adapter train_input(TrainConfig train_config, InputReader train_input_config, DetectionModel model_config) | |||||
| { | |||||
| var arch = INPUT_BUILDER_UTIL_MAP["model_build"](model_config, true, true); | |||||
| Func<Tensor, (Tensor, Tensor)> model_preprocess_fn = arch.preprocess; | |||||
| var dataset = DatasetBuilder.build(train_input_config); | |||||
| return dataset; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -6,5 +6,9 @@ namespace Tensorflow.Models.ObjectDetection.MetaArchitectures | |||||
| { | { | ||||
| public class FasterRCNNMetaArch | public class FasterRCNNMetaArch | ||||
| { | { | ||||
| public (Tensor, Tensor) preprocess(Tensor tensor) | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -6,11 +6,14 @@ using Tensorflow.Estimators; | |||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Contrib.Train; | using Tensorflow.Contrib.Train; | ||||
| using Tensorflow.Models.ObjectDetection.Utils; | using Tensorflow.Models.ObjectDetection.Utils; | ||||
| using Tensorflow.Data; | |||||
| namespace Tensorflow.Models.ObjectDetection | namespace Tensorflow.Models.ObjectDetection | ||||
| { | { | ||||
| public class ModelLib | public class ModelLib | ||||
| { | { | ||||
| Inputs inputs = new Inputs(); | |||||
| public TrainAndEvalDict create_estimator_and_inputs(RunConfig run_config, | public TrainAndEvalDict create_estimator_and_inputs(RunConfig run_config, | ||||
| HParams hparams = null, | HParams hparams = null, | ||||
| string pipeline_config_path = null, | string pipeline_config_path = null, | ||||
| @@ -21,7 +24,7 @@ namespace Tensorflow.Models.ObjectDetection | |||||
| var config = ConfigUtil.get_configs_from_pipeline_file(pipeline_config_path); | var config = ConfigUtil.get_configs_from_pipeline_file(pipeline_config_path); | ||||
| // Create the input functions for TRAIN/EVAL/PREDICT. | // Create the input functions for TRAIN/EVAL/PREDICT. | ||||
| Action train_input_fn = () => { }; | |||||
| Func<DatasetV1Adapter> train_input_fn = inputs.create_train_input_fn(config.TrainConfig, config.TrainInputReader, config.Model); | |||||
| var eval_input_configs = config.EvalInputReader; | var eval_input_configs = config.EvalInputReader; | ||||
| @@ -44,7 +47,7 @@ namespace Tensorflow.Models.ObjectDetection | |||||
| }; | }; | ||||
| } | } | ||||
| public (TrainSpec, EvalSpec[]) create_train_and_eval_specs(Action train_input_fn, Action[] eval_input_fns, Action eval_on_train_input_fn, | |||||
| public (TrainSpec, EvalSpec[]) create_train_and_eval_specs(Func<DatasetV1Adapter> train_input_fn, Action[] eval_input_fns, Action eval_on_train_input_fn, | |||||
| Action predict_input_fn, int train_steps, bool eval_on_train_data = false, | Action predict_input_fn, int train_steps, bool eval_on_train_data = false, | ||||
| string final_exporter_name = "Servo", string[] eval_spec_names = null) | string final_exporter_name = "Servo", string[] eval_spec_names = null) | ||||
| { | { | ||||