| @@ -57,6 +57,12 @@ namespace Tensorflow | |||
| clear_devices: clear_devices, | |||
| clear_extraneous_savers: clear_extraneous_savers, | |||
| strip_default_attrs: strip_default_attrs); | |||
| public string latest_checkpoint(string checkpoint_dir, string latest_filename = null) | |||
| => checkpoint_management.latest_checkpoint(checkpoint_dir, latest_filename: latest_filename); | |||
| public CheckpointState get_checkpoint_state(string checkpoint_dir, string latest_filename = null) | |||
| => checkpoint_management.get_checkpoint_state(checkpoint_dir, latest_filename: latest_filename); | |||
| } | |||
| } | |||
| } | |||
| @@ -18,20 +18,35 @@ namespace Tensorflow.Estimators | |||
| string _model_dir; | |||
| Action _model_fn; | |||
| public Estimator(Action model_fn, RunConfig config) | |||
| { | |||
| _config = config; | |||
| _model_dir = _config.model_dir; | |||
| _session_config = _config.session_config; | |||
| _model_fn = model_fn; | |||
| } | |||
| public Estimator train(Action input_fn, int max_steps = 1, | |||
| public Estimator train(Action input_fn, int max_steps = 1, Action[] hooks = null, | |||
| _NewCheckpointListenerForEvaluate[] saving_listeners = null) | |||
| { | |||
| if(max_steps > 0) | |||
| { | |||
| var start_step = _load_global_step_from_checkpoint_dir(_model_dir); | |||
| } | |||
| _train_model(); | |||
| throw new NotImplementedException(""); | |||
| } | |||
| private int _load_global_step_from_checkpoint_dir(string checkpoint_dir) | |||
| { | |||
| var cp = tf.train.latest_checkpoint(checkpoint_dir); | |||
| return 0; | |||
| } | |||
| private void _train_model() | |||
| { | |||
| _train_model_default(); | |||
| @@ -6,9 +6,11 @@ namespace Tensorflow.Estimators | |||
| { | |||
| public class EvalSpec | |||
| { | |||
| string _name; | |||
| public EvalSpec(string name, Action input_fn, FinalExporter exporters) | |||
| { | |||
| _name = name; | |||
| } | |||
| } | |||
| } | |||
| @@ -6,11 +6,16 @@ namespace Tensorflow.Estimators | |||
| { | |||
| public class TrainSpec | |||
| { | |||
| public int max_steps { get; set; } | |||
| int _max_steps; | |||
| public int max_steps => _max_steps; | |||
| Action _input_fn; | |||
| public Action input_fn => _input_fn; | |||
| public TrainSpec(Action input_fn, int max_steps) | |||
| { | |||
| this.max_steps = max_steps; | |||
| _max_steps = max_steps; | |||
| _input_fn = input_fn; | |||
| } | |||
| } | |||
| } | |||
| @@ -6,5 +6,11 @@ namespace Tensorflow.Estimators | |||
| { | |||
| public class _NewCheckpointListenerForEvaluate | |||
| { | |||
| _Evaluator _evaluator; | |||
| public _NewCheckpointListenerForEvaluate(_Evaluator evaluator, int eval_throttle_secs) | |||
| { | |||
| _evaluator = evaluator; | |||
| } | |||
| } | |||
| } | |||
| @@ -32,15 +32,17 @@ namespace Tensorflow.Estimators | |||
| /// </summary> | |||
| private void run_local() | |||
| { | |||
| var train_hooks = new Action[0]; | |||
| Console.WriteLine("Start train and evaluate loop. The evaluate will happen " + | |||
| "after every checkpoint. Checkpoint frequency is determined " + | |||
| $"based on RunConfig arguments: save_checkpoints_steps {_estimator.config.save_checkpoints_steps} or " + | |||
| $"save_checkpoints_secs {_estimator.config.save_checkpoints_secs}."); | |||
| var evaluator = new _Evaluator(_estimator, _eval_spec, _train_spec.max_steps); | |||
| /*_estimator.train(input_fn: _train_spec.input_fn, | |||
| var saving_listeners = new _NewCheckpointListenerForEvaluate[0]; | |||
| _estimator.train(input_fn: _train_spec.input_fn, | |||
| max_steps: _train_spec.max_steps, | |||
| hooks: train_hooks, | |||
| saving_listeners: saving_listeners);*/ | |||
| saving_listeners: saving_listeners); | |||
| } | |||
| } | |||
| } | |||
| @@ -19,6 +19,7 @@ using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using static Tensorflow.SaverDef.Types; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -144,5 +145,54 @@ namespace Tensorflow | |||
| return prefix + ".index"; | |||
| return prefix; | |||
| } | |||
| /// <summary> | |||
| /// Finds the filename of latest saved checkpoint file. | |||
| /// </summary> | |||
| /// <param name="checkpoint_dir"></param> | |||
| /// <param name="latest_filename"></param> | |||
| /// <returns></returns> | |||
| public static string latest_checkpoint(string checkpoint_dir, string latest_filename = null) | |||
| { | |||
| // Pick the latest checkpoint based on checkpoint state. | |||
| var ckpt = get_checkpoint_state(checkpoint_dir, latest_filename); | |||
| if(ckpt != null && !string.IsNullOrEmpty(ckpt.ModelCheckpointPath)) | |||
| { | |||
| // Look for either a V2 path or a V1 path, with priority for V2. | |||
| var v2_path = _prefix_to_checkpoint_path(ckpt.ModelCheckpointPath, CheckpointFormatVersion.V2); | |||
| var v1_path = _prefix_to_checkpoint_path(ckpt.ModelCheckpointPath, CheckpointFormatVersion.V1); | |||
| if (File.Exists(v2_path) || File.Exists(v1_path)) | |||
| return ckpt.ModelCheckpointPath; | |||
| else | |||
| throw new ValueError($"Couldn't match files for checkpoint {ckpt.ModelCheckpointPath}"); | |||
| } | |||
| return null; | |||
| } | |||
| public static CheckpointState get_checkpoint_state(string checkpoint_dir, string latest_filename = null) | |||
| { | |||
| var coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir, latest_filename); | |||
| if (File.Exists(coord_checkpoint_filename)) | |||
| { | |||
| var file_content = File.ReadAllBytes(coord_checkpoint_filename); | |||
| var ckpt = CheckpointState.Parser.ParseFrom(file_content); | |||
| if (string.IsNullOrEmpty(ckpt.ModelCheckpointPath)) | |||
| throw new ValueError($"Invalid checkpoint state loaded from {checkpoint_dir}"); | |||
| // For relative model_checkpoint_path and all_model_checkpoint_paths, | |||
| // prepend checkpoint_dir. | |||
| if (!Path.IsPathRooted(ckpt.ModelCheckpointPath)) | |||
| ckpt.ModelCheckpointPath = Path.Combine(checkpoint_dir, ckpt.ModelCheckpointPath); | |||
| foreach(var i in range(len(ckpt.AllModelCheckpointPaths))) | |||
| { | |||
| var p = ckpt.AllModelCheckpointPaths[i]; | |||
| if (!Path.IsPathRooted(p)) | |||
| ckpt.AllModelCheckpointPaths[i] = Path.Combine(checkpoint_dir, p); | |||
| } | |||
| return ckpt; | |||
| } | |||
| return null; | |||
| } | |||
| } | |||
| } | |||
| @@ -19,19 +19,28 @@ namespace Tensorflow.Models.ObjectDetection | |||
| int sample_1_of_n_eval_on_train_examples = 1) | |||
| { | |||
| var config = ConfigUtil.get_configs_from_pipeline_file(pipeline_config_path); | |||
| // Create the input functions for TRAIN/EVAL/PREDICT. | |||
| Action train_input_fn = () => { }; | |||
| var eval_input_configs = config.EvalInputReader; | |||
| var eval_input_fns = new Action[eval_input_configs.Count]; | |||
| var eval_input_names = eval_input_configs.Select(eval_input_config => eval_input_config.Name).ToArray(); | |||
| Action eval_on_train_input_fn = () => { }; | |||
| Action predict_input_fn = () => { }; | |||
| Action model_fn = () => { }; | |||
| var estimator = tf.estimator.Estimator(model_fn: model_fn, config: run_config); | |||
| return new TrainAndEvalDict | |||
| { | |||
| estimator = estimator, | |||
| train_steps = train_steps, | |||
| train_input_fn = train_input_fn, | |||
| eval_input_fns = eval_input_fns, | |||
| eval_input_names = eval_input_names | |||
| eval_input_names = eval_input_names, | |||
| eval_on_train_input_fn = eval_on_train_input_fn, | |||
| predict_input_fn = predict_input_fn, | |||
| train_steps = train_steps | |||
| }; | |||
| } | |||
| @@ -46,10 +55,7 @@ namespace Tensorflow.Models.ObjectDetection | |||
| .Select(x => x.ToString()) | |||
| .ToArray(); | |||
| var eval_specs = new List<EvalSpec>() | |||
| { | |||
| new EvalSpec("", null, null) // for test. | |||
| }; | |||
| var eval_specs = new List<EvalSpec>(); | |||
| foreach (var (index, (eval_spec_name, eval_input_fn)) in enumerate(zip(eval_spec_names, eval_input_fns).ToList())) | |||
| { | |||
| var exporter_name = index == 0 ? final_exporter_name : $"{final_exporter_name}_{eval_spec_name}"; | |||
| @@ -21,7 +21,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.ObjectDetection | |||
| string model_dir = "D:/Projects/PythonLab/tf-models/research/object_detection/models/model"; | |||
| string pipeline_config_path = "ObjectDetection/Models/faster_rcnn_resnet101_voc07.config"; | |||
| int num_train_steps = 1; | |||
| int num_train_steps = 50; | |||
| int sample_1_of_n_eval_examples = 1; | |||
| int sample_1_of_n_eval_on_train_examples = 5; | |||