You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

ModelLib.cs 2.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using static Tensorflow.Binding;
  5. using Tensorflow.Estimators;
  6. using System.Linq;
  7. using Tensorflow.Contrib.Train;
  8. using Tensorflow.Models.ObjectDetection.Utils;
  9. namespace Tensorflow.Models.ObjectDetection
  10. {
  11. public class ModelLib
  12. {
  13. public TrainAndEvalDict create_estimator_and_inputs(RunConfig run_config,
  14. HParams hparams = null,
  15. string pipeline_config_path = null,
  16. int train_steps = 0,
  17. int sample_1_of_n_eval_examples = 0,
  18. int sample_1_of_n_eval_on_train_examples = 1)
  19. {
  20. var estimator = tf.estimator.Estimator(config: run_config);
  21. var config = ConfigUtil.get_configs_from_pipeline_file(pipeline_config_path);
  22. var eval_input_configs = config.EvalInputReader;
  23. var eval_input_fns = new Action[eval_input_configs.Count];
  24. return new TrainAndEvalDict
  25. {
  26. estimator = estimator,
  27. train_steps = train_steps,
  28. eval_input_fns = eval_input_fns
  29. };
  30. }
  31. public (TrainSpec, EvalSpec[]) create_train_and_eval_specs(Action train_input_fn, Action[] eval_input_fns, Action eval_on_train_input_fn,
  32. Action predict_input_fn, int train_steps, bool eval_on_train_data = false,
  33. string final_exporter_name = "Servo", string[] eval_spec_names = null)
  34. {
  35. var train_spec = tf.estimator.TrainSpec(input_fn: train_input_fn, max_steps: train_steps);
  36. if (eval_spec_names == null)
  37. eval_spec_names = range(len(eval_input_fns))
  38. .Select(x => x.ToString())
  39. .ToArray();
  40. var eval_specs = new List<EvalSpec>()
  41. {
  42. new EvalSpec("", null, null) // for test.
  43. };
  44. foreach (var (index, (eval_spec_name, eval_input_fn)) in enumerate(zip(eval_spec_names, eval_input_fns).ToList()))
  45. {
  46. var exporter_name = index == 0 ? final_exporter_name : $"{final_exporter_name}_{eval_spec_name}";
  47. var exporter = tf.estimator.FinalExporter(name: exporter_name, serving_input_receiver_fn: predict_input_fn);
  48. eval_specs.Add(tf.estimator.EvalSpec(name: eval_spec_name,
  49. input_fn: eval_input_fn,
  50. exporters: exporter));
  51. }
  52. if (eval_on_train_data)
  53. throw new NotImplementedException("");
  54. return (train_spec, eval_specs.ToArray());
  55. }
  56. }
  57. }