|
- using System;
- using System.Collections.Generic;
- using System.Text;
- using static Tensorflow.Binding;
- using Tensorflow.Estimators;
- using System.Linq;
- using Tensorflow.Contrib.Train;
- using Tensorflow.Models.ObjectDetection.Utils;
-
namespace Tensorflow.Models.ObjectDetection
{
    /// <summary>
    /// Helpers that build the estimator and the train/eval specs for an object-detection pipeline.
    /// </summary>
    public class ModelLib
    {
        /// <summary>
        /// Creates an estimator from <paramref name="run_config"/> and reads the pipeline config at
        /// <paramref name="pipeline_config_path"/> to size the eval-input slots.
        /// </summary>
        /// <param name="run_config">Run configuration handed to the estimator.</param>
        /// <param name="hparams">Currently unused by this method.</param>
        /// <param name="pipeline_config_path">Path of the pipeline config file to load; required.</param>
        /// <param name="train_steps">Copied unchanged into the returned dictionary.</param>
        /// <param name="sample_1_of_n_eval_examples">Currently unused by this method.</param>
        /// <param name="sample_1_of_n_eval_on_train_examples">Currently unused by this method.</param>
        /// <returns>
        /// A <see cref="TrainAndEvalDict"/> with the estimator, <paramref name="train_steps"/>, and one
        /// eval-input slot per <c>EvalInputReader</c> entry of the config. The slots are not populated
        /// here, so every element is <c>null</c>.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="pipeline_config_path"/> is <c>null</c>.</exception>
        public TrainAndEvalDict create_estimator_and_inputs(RunConfig run_config,
            HParams hparams = null,
            string pipeline_config_path = null,
            int train_steps = 0,
            int sample_1_of_n_eval_examples = 0,
            int sample_1_of_n_eval_on_train_examples = 1)
        {
            // Reject a missing path before any work is done, instead of handing null to the config loader.
            if (pipeline_config_path == null)
                throw new ArgumentNullException(nameof(pipeline_config_path), "A pipeline config path is required.");

            var estimator = tf.estimator.Estimator(config: run_config);

            var config = ConfigUtil.get_configs_from_pipeline_file(pipeline_config_path);
            var eval_input_configs = config.EvalInputReader;

            // NOTE(review): input functions are not built yet; this only reserves one (null) slot
            // per eval input reader.
            var eval_input_fns = new Action[eval_input_configs.Count];

            return new TrainAndEvalDict
            {
                estimator = estimator,
                train_steps = train_steps,
                eval_input_fns = eval_input_fns
            };
        }

        /// <summary>
        /// Builds the training spec and one eval spec (with a final exporter) per eval input function.
        /// </summary>
        /// <param name="train_input_fn">Input function used by the training spec.</param>
        /// <param name="eval_input_fns">Eval input functions; one eval spec is built for each entry.</param>
        /// <param name="eval_on_train_input_fn">Currently unused by this method.</param>
        /// <param name="predict_input_fn">Serving input receiver passed to every exporter.</param>
        /// <param name="train_steps">Maximum number of training steps for the training spec.</param>
        /// <param name="eval_on_train_data">Must be <c>false</c>; evaluating on training data is not implemented.</param>
        /// <param name="final_exporter_name">
        /// Exporter name for the first eval spec; later specs use <c>{final_exporter_name}_{eval_spec_name}</c>.
        /// </param>
        /// <param name="eval_spec_names">
        /// Names of the eval specs. When <c>null</c>, the names "0", "1", ... are generated; otherwise
        /// the array must have the same length as <paramref name="eval_input_fns"/>.
        /// </param>
        /// <returns>
        /// The training spec and an array of eval specs. Index 0 is a placeholder spec; the specs
        /// built from <paramref name="eval_input_fns"/> follow it in order.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="eval_input_fns"/> is <c>null</c>.</exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="eval_spec_names"/> length differs from <paramref name="eval_input_fns"/> length.
        /// </exception>
        /// <exception cref="NotImplementedException"><paramref name="eval_on_train_data"/> is <c>true</c>.</exception>
        public (TrainSpec, EvalSpec[]) create_train_and_eval_specs(Action train_input_fn, Action[] eval_input_fns, Action eval_on_train_input_fn,
            Action predict_input_fn, int train_steps, bool eval_on_train_data = false,
            string final_exporter_name = "Servo", string[] eval_spec_names = null)
        {
            if (eval_input_fns == null)
                throw new ArgumentNullException(nameof(eval_input_fns));

            // Checked first so no specs or exporters are built for a call that cannot complete.
            if (eval_on_train_data)
                throw new NotImplementedException("Evaluation on training data (eval_on_train_data) is not supported yet.");

            var train_spec = tf.estimator.TrainSpec(input_fn: train_input_fn, max_steps: train_steps);

            if (eval_spec_names == null)
            {
                eval_spec_names = Enumerable.Range(0, eval_input_fns.Length)
                    .Select(i => i.ToString())
                    .ToArray();
            }
            else if (eval_spec_names.Length != eval_input_fns.Length)
            {
                // Pairing mismatched arrays would silently drop eval inputs (or names).
                throw new ArgumentException(
                    $"Expected {eval_input_fns.Length} eval spec names but got {eval_spec_names.Length}.",
                    nameof(eval_spec_names));
            }

            var eval_specs = new List<EvalSpec>
            {
                // NOTE(review): placeholder spec (originally marked "for test") that precedes the real
                // specs, so eval_specs[0] is this empty spec rather than the first real one. Kept for
                // compatibility with existing callers; confirm whether it can be removed.
                new EvalSpec("", null, null)
            };

            for (int index = 0; index < eval_input_fns.Length; index++)
            {
                var eval_spec_name = eval_spec_names[index];
                // The first spec's exporter keeps the bare final name; later ones get a name suffix.
                var exporter_name = index == 0 ? final_exporter_name : $"{final_exporter_name}_{eval_spec_name}";
                var exporter = tf.estimator.FinalExporter(name: exporter_name, serving_input_receiver_fn: predict_input_fn);
                eval_specs.Add(tf.estimator.EvalSpec(name: eval_spec_name,
                    input_fn: eval_input_fns[index],
                    exporters: exporter));
            }

            return (train_spec, eval_specs.ToArray());
        }
    }
}
|