Browse Source

input_fn #387

tags/v0.12
Oceania2018 6 years ago
parent
commit
41432600c8
13 changed files with 259 additions and 12 deletions
  1. +2
    -1
      src/TensorFlowNET.Core/APIs/tf.estimator.cs
  2. +27
    -3
      src/TensorFlowNET.Core/Estimators/Estimator.cs
  3. +4
    -3
      src/TensorFlowNET.Core/Estimators/TrainSpec.cs
  4. +3
    -0
      src/TensorFlowNET.Core/Graphs/Graph.cs
  5. +40
    -2
      src/TensorFlowNET.Core/Train/TrainingUtil.cs
  6. +1
    -0
      src/TensorFlowNET.Core/ops.GraphKeys.cs
  7. +25
    -0
      src/TensorFlowNET.Models/ObjectDetection/Builders/DatasetBuilder.cs
  8. +40
    -0
      src/TensorFlowNET.Models/ObjectDetection/Builders/ImageResizerBuilder.cs
  9. +56
    -0
      src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs
  10. +2
    -1
      src/TensorFlowNET.Models/ObjectDetection/Entities/TrainAndEvalDict.cs
  11. +50
    -0
      src/TensorFlowNET.Models/ObjectDetection/Inputs.cs
  12. +4
    -0
      src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNMetaArch.cs
  13. +5
    -2
      src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs

+ 2
- 1
src/TensorFlowNET.Core/APIs/tf.estimator.cs View File

@@ -17,6 +17,7 @@
using System; using System;
using static Tensorflow.Binding; using static Tensorflow.Binding;
using Tensorflow.Estimators; using Tensorflow.Estimators;
using Tensorflow.Data;


namespace Tensorflow namespace Tensorflow
{ {
@@ -35,7 +36,7 @@ namespace Tensorflow
public void train_and_evaluate(Estimator estimator, TrainSpec train_spec, EvalSpec eval_spec) public void train_and_evaluate(Estimator estimator, TrainSpec train_spec, EvalSpec eval_spec)
=> Training.train_and_evaluate(estimator: estimator, train_spec: train_spec, eval_spec: eval_spec); => Training.train_and_evaluate(estimator: estimator, train_spec: train_spec, eval_spec: eval_spec);


public TrainSpec TrainSpec(Action input_fn, int max_steps)
public TrainSpec TrainSpec(Func<DatasetV1Adapter> input_fn, int max_steps)
=> new TrainSpec(input_fn: input_fn, max_steps: max_steps); => new TrainSpec(input_fn: input_fn, max_steps: max_steps);


/// <summary> /// <summary>


+ 27
- 3
src/TensorFlowNET.Core/Estimators/Estimator.cs View File

@@ -3,6 +3,8 @@ using System.Collections.Generic;
using System.Diagnostics; using System.Diagnostics;
using System.IO; using System.IO;
using System.Text; using System.Text;
using Tensorflow.Data;
using Tensorflow.Train;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow.Estimators namespace Tensorflow.Estimators
@@ -30,7 +32,7 @@ namespace Tensorflow.Estimators
_model_fn = model_fn; _model_fn = model_fn;
} }


public Estimator train(Action input_fn, int max_steps = 1, Action[] hooks = null,
public Estimator train(Func<DatasetV1Adapter> input_fn, int max_steps = 1, Action[] hooks = null,
_NewCheckpointListenerForEvaluate[] saving_listeners = null) _NewCheckpointListenerForEvaluate[] saving_listeners = null)
{ {
if(max_steps > 0) if(max_steps > 0)
@@ -56,19 +58,41 @@ namespace Tensorflow.Estimators
return cp.AllModelCheckpointPaths.Count - 1; return cp.AllModelCheckpointPaths.Count - 1;
} }


private void _train_model(Action input_fn)
private void _train_model(Func<DatasetV1Adapter> input_fn)
{ {
_train_model_default(input_fn); _train_model_default(input_fn);
} }


private void _train_model_default(Action input_fn)
private void _train_model_default(Func<DatasetV1Adapter> input_fn)
{ {
using (var g = tf.Graph().as_default()) using (var g = tf.Graph().as_default())
{ {
var global_step_tensor = _create_and_assert_global_step(g); var global_step_tensor = _create_and_assert_global_step(g);

// Skip creating a read variable if _create_and_assert_global_step
// returns None (e.g. tf.contrib.estimator.SavedModelEstimator).
if (global_step_tensor != null)
TrainingUtil._get_or_create_global_step_read(g);

_get_features_and_labels_from_input_fn(input_fn, "train");
} }
} }


private void _get_features_and_labels_from_input_fn(Func<DatasetV1Adapter> input_fn, string mode)
{
_call_input_fn(input_fn, mode);
}

/// <summary>
/// Calls the input function.
/// </summary>
/// <param name="input_fn"></param>
/// <param name="mode"></param>
private void _call_input_fn(Func<DatasetV1Adapter> input_fn, string mode)
{
input_fn();
}

private Tensor _create_and_assert_global_step(Graph graph) private Tensor _create_and_assert_global_step(Graph graph)
{ {
var step = _create_global_step(graph); var step = _create_global_step(graph);


+ 4
- 3
src/TensorFlowNET.Core/Estimators/TrainSpec.cs View File

@@ -1,6 +1,7 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using Tensorflow.Data;


namespace Tensorflow.Estimators namespace Tensorflow.Estimators
{ {
@@ -9,10 +10,10 @@ namespace Tensorflow.Estimators
int _max_steps; int _max_steps;
public int max_steps => _max_steps; public int max_steps => _max_steps;


Action _input_fn;
public Action input_fn => _input_fn;
Func<DatasetV1Adapter> _input_fn;
public Func<DatasetV1Adapter> input_fn => _input_fn;


public TrainSpec(Action input_fn, int max_steps)
public TrainSpec(Func<DatasetV1Adapter> input_fn, int max_steps)
{ {
_max_steps = max_steps; _max_steps = max_steps;
_input_fn = input_fn; _input_fn = input_fn;


+ 3
- 0
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -434,6 +434,9 @@ namespace Tensorflow
case List<RefVariable> list: case List<RefVariable> list:
t = list.Select(x => (T)(object)x).ToList(); t = list.Select(x => (T)(object)x).ToList();
break; break;
case List<Tensor> list:
t = list.Select(x => (T)(object)x).ToList();
break;
default: default:
throw new NotImplementedException($"get_collection<{typeof(T).FullName}>"); throw new NotImplementedException($"get_collection<{typeof(T).FullName}>");
} }


+ 40
- 2
src/TensorFlowNET.Core/Train/TrainingUtil.cs View File

@@ -7,7 +7,7 @@ namespace Tensorflow.Train
{ {
public class TrainingUtil public class TrainingUtil
{ {
public static RefVariable create_global_step(Graph graph)
public static RefVariable create_global_step(Graph graph = null)
{ {
graph = graph ?? ops.get_default_graph(); graph = graph ?? ops.get_default_graph();
if (get_global_step(graph) != null) if (get_global_step(graph) != null)
@@ -24,7 +24,7 @@ namespace Tensorflow.Train
return v; return v;
} }


public static RefVariable get_global_step(Graph graph)
public static RefVariable get_global_step(Graph graph = null)
{ {
graph = graph ?? ops.get_default_graph(); graph = graph ?? ops.get_default_graph();
RefVariable global_step_tensor = null; RefVariable global_step_tensor = null;
@@ -47,5 +47,43 @@ namespace Tensorflow.Train
return global_step_tensor; return global_step_tensor;
} }

public static Tensor _get_or_create_global_step_read(Graph graph = null)
{
graph = graph ?? ops.get_default_graph();
var global_step_read_tensor = _get_global_step_read(graph);
if (global_step_read_tensor != null)
return global_step_read_tensor;

var global_step_tensor = get_global_step(graph);

if (global_step_tensor == null)
return null;

var g = graph.as_default();
g.name_scope(null);
g.name_scope(global_step_tensor.op.name + "/");
// using initialized_value to ensure that global_step is initialized before
// this run. This is needed for example Estimator makes all model_fn build
// under global_step_read_tensor dependency.
var global_step_value = global_step_tensor.initialized_value();
ops.add_to_collection(tf.GraphKeys.GLOBAL_STEP_READ_KEY, global_step_value + 0);

return _get_global_step_read(graph);
}

private static Tensor _get_global_step_read(Graph graph = null)
{
graph = graph ?? ops.get_default_graph();
var global_step_read_tensors = graph.get_collection<Tensor>(tf.GraphKeys.GLOBAL_STEP_READ_KEY);
if (global_step_read_tensors.Count > 1)
throw new RuntimeError($"There are multiple items in collection {tf.GraphKeys.GLOBAL_STEP_READ_KEY}. " +
"There should be only one.");

if (global_step_read_tensors.Count == 1)
return global_step_read_tensors[0];

return null;
}
} }
} }

+ 1
- 0
src/TensorFlowNET.Core/ops.GraphKeys.cs View File

@@ -122,6 +122,7 @@ namespace Tensorflow
public string TRAIN_OP => TRAIN_OP_; public string TRAIN_OP => TRAIN_OP_;


public string GLOBAL_STEP => GLOBAL_STEP_; public string GLOBAL_STEP => GLOBAL_STEP_;
public string GLOBAL_STEP_READ_KEY = "global_step_read_op_cache";


public string[] _VARIABLE_COLLECTIONS => _VARIABLE_COLLECTIONS_; public string[] _VARIABLE_COLLECTIONS => _VARIABLE_COLLECTIONS_;
/// <summary> /// <summary>


+ 25
- 0
src/TensorFlowNET.Models/ObjectDetection/Builders/DatasetBuilder.cs View File

@@ -0,0 +1,25 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Data;
using Tensorflow.Models.ObjectDetection.Protos;

namespace Tensorflow.Models.ObjectDetection
{
public class DatasetBuilder
{
public static DatasetV1Adapter build(InputReader input_reader_config,
int batch_size = 0,
Action transform_input_data_fn = null)
{
Func<Dictionary<string, Tensor>, (Dictionary<string, Tensor>, Dictionary<string, Tensor>)> transform_and_pad_input_data_fn = (tensor_dict) =>
{
return (null, null);
};

var config = input_reader_config.TfRecordInputReader;

throw new NotImplementedException("");
}
}
}

+ 40
- 0
src/TensorFlowNET.Models/ObjectDetection/Builders/ImageResizerBuilder.cs View File

@@ -0,0 +1,40 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Models.ObjectDetection.Protos;
using static Tensorflow.Models.ObjectDetection.Protos.ImageResizer;

namespace Tensorflow.Models.ObjectDetection
{
public class ImageResizerBuilder
{
public ImageResizerBuilder()
{

}

public Action build(ImageResizer image_resizer_config)
{
var image_resizer_oneof = image_resizer_config.ImageResizerOneofCase;
if (image_resizer_oneof == ImageResizerOneofOneofCase.KeepAspectRatioResizer)
{
var keep_aspect_ratio_config = image_resizer_config.KeepAspectRatioResizer;
var method = _tf_resize_method(keep_aspect_ratio_config.ResizeMethod);
var per_channel_pad_value = new[] { 0, 0, 0 };
//if (keep_aspect_ratio_config.PerChannelPadValue != null)
//per_channel_pad_value = new[] { keep_aspect_ratio_config.PerChannelPadValue };
}
else
{
throw new NotImplementedException("");
}

return null;
}

private ResizeType _tf_resize_method(ResizeType resize_method)
{
return resize_method;
}
}
}

+ 56
- 0
src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs View File

@@ -0,0 +1,56 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Models.ObjectDetection.MetaArchitectures;
using Tensorflow.Models.ObjectDetection.Protos;
using static Tensorflow.Models.ObjectDetection.Protos.DetectionModel;

namespace Tensorflow.Models.ObjectDetection
{
public class ModelBuilder
{
ImageResizerBuilder image_resizer_builder;

public ModelBuilder()
{
image_resizer_builder = new ImageResizerBuilder();
}

/// <summary>
/// Builds a DetectionModel based on the model config.
/// </summary>
/// <param name="model_config">A model.proto object containing the config for the desired DetectionModel.</param>
/// <param name="is_training">True if this model is being built for training purposes.</param>
/// <param name="add_summaries">Whether to add tensorflow summaries in the model graph.</param>
/// <returns>DetectionModel based on the config.</returns>
public FasterRCNNMetaArch build(DetectionModel model_config, bool is_training, bool add_summaries = true)
{
var meta_architecture = model_config.ModelCase;
if (meta_architecture == ModelOneofCase.Ssd)
throw new NotImplementedException("");
else if (meta_architecture == ModelOneofCase.FasterRcnn)
return _build_faster_rcnn_model(model_config.FasterRcnn, is_training, add_summaries);

throw new ValueError($"Unknown meta architecture: {meta_architecture}");
}

/// <summary>
/// Builds a Faster R-CNN or R-FCN detection model based on the model config.
/// </summary>
/// <param name="frcnn_config"></param>
/// <param name="is_training"></param>
/// <param name="add_summaries"></param>
/// <returns>FasterRCNNMetaArch based on the config.</returns>
private FasterRCNNMetaArch _build_faster_rcnn_model(FasterRcnn frcnn_config, bool is_training, bool add_summaries)
{
var num_classes = frcnn_config.NumClasses;
var image_resizer_fn = image_resizer_builder.build(frcnn_config.ImageResizer);
throw new NotImplementedException("");
}

public Action preprocess()
{
throw new NotImplementedException("");
}
}
}

+ 2
- 1
src/TensorFlowNET.Models/ObjectDetection/Entities/TrainAndEvalDict.cs View File

@@ -1,6 +1,7 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using Tensorflow.Data;
using Tensorflow.Estimators; using Tensorflow.Estimators;


namespace Tensorflow.Models.ObjectDetection namespace Tensorflow.Models.ObjectDetection
@@ -8,7 +9,7 @@ namespace Tensorflow.Models.ObjectDetection
public class TrainAndEvalDict public class TrainAndEvalDict
{ {
public Estimator estimator { get; set; } public Estimator estimator { get; set; }
public Action train_input_fn { get; set; }
public Func<DatasetV1Adapter> train_input_fn { get; set; }
public Action[] eval_input_fns { get; set; } public Action[] eval_input_fns { get; set; }
public string[] eval_input_names { get; set; } public string[] eval_input_names { get; set; }
public Action eval_on_train_input_fn { get; set; } public Action eval_on_train_input_fn { get; set; }


+ 50
- 0
src/TensorFlowNET.Models/ObjectDetection/Inputs.cs View File

@@ -0,0 +1,50 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Data;
using Tensorflow.Estimators;
using Tensorflow.Models.ObjectDetection.MetaArchitectures;
using Tensorflow.Models.ObjectDetection.Protos;

namespace Tensorflow.Models.ObjectDetection
{
public class Inputs
{
ModelBuilder modelBuilder;
Dictionary<string, Func<DetectionModel, bool, bool, FasterRCNNMetaArch>> INPUT_BUILDER_UTIL_MAP;

public Inputs()
{
modelBuilder = new ModelBuilder();
INPUT_BUILDER_UTIL_MAP = new Dictionary<string, Func<DetectionModel, bool, bool, FasterRCNNMetaArch>>();
INPUT_BUILDER_UTIL_MAP["model_build"] = modelBuilder.build;
}

public Func<DatasetV1Adapter> create_train_input_fn(TrainConfig train_config, InputReader train_input_config, DetectionModel model_config)
{
Func<DatasetV1Adapter> _train_input_fn = () =>
{
return train_input(train_config, train_input_config, model_config);
};

return _train_input_fn;
}

/// <summary>
/// Returns `features` and `labels` tensor dictionaries for training.
/// </summary>
/// <param name="train_config"></param>
/// <param name="train_input_config"></param>
/// <param name="model_config"></param>
/// <returns></returns>
public DatasetV1Adapter train_input(TrainConfig train_config, InputReader train_input_config, DetectionModel model_config)
{
var arch = INPUT_BUILDER_UTIL_MAP["model_build"](model_config, true, true);
Func<Tensor, (Tensor, Tensor)> model_preprocess_fn = arch.preprocess;

var dataset = DatasetBuilder.build(train_input_config);

return dataset;
}
}
}

+ 4
- 0
src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNMetaArch.cs View File

@@ -6,5 +6,9 @@ namespace Tensorflow.Models.ObjectDetection.MetaArchitectures
{ {
public class FasterRCNNMetaArch public class FasterRCNNMetaArch
{ {
public (Tensor, Tensor) preprocess(Tensor tensor)
{
throw new NotImplementedException("");
}
} }
} }

+ 5
- 2
src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs View File

@@ -6,11 +6,14 @@ using Tensorflow.Estimators;
using System.Linq; using System.Linq;
using Tensorflow.Contrib.Train; using Tensorflow.Contrib.Train;
using Tensorflow.Models.ObjectDetection.Utils; using Tensorflow.Models.ObjectDetection.Utils;
using Tensorflow.Data;


namespace Tensorflow.Models.ObjectDetection namespace Tensorflow.Models.ObjectDetection
{ {
public class ModelLib public class ModelLib
{ {
Inputs inputs = new Inputs();

public TrainAndEvalDict create_estimator_and_inputs(RunConfig run_config, public TrainAndEvalDict create_estimator_and_inputs(RunConfig run_config,
HParams hparams = null, HParams hparams = null,
string pipeline_config_path = null, string pipeline_config_path = null,
@@ -21,7 +24,7 @@ namespace Tensorflow.Models.ObjectDetection
var config = ConfigUtil.get_configs_from_pipeline_file(pipeline_config_path); var config = ConfigUtil.get_configs_from_pipeline_file(pipeline_config_path);


// Create the input functions for TRAIN/EVAL/PREDICT. // Create the input functions for TRAIN/EVAL/PREDICT.
Action train_input_fn = () => { };
Func<DatasetV1Adapter> train_input_fn = inputs.create_train_input_fn(config.TrainConfig, config.TrainInputReader, config.Model);


var eval_input_configs = config.EvalInputReader; var eval_input_configs = config.EvalInputReader;


@@ -44,7 +47,7 @@ namespace Tensorflow.Models.ObjectDetection
}; };
} }


public (TrainSpec, EvalSpec[]) create_train_and_eval_specs(Action train_input_fn, Action[] eval_input_fns, Action eval_on_train_input_fn,
public (TrainSpec, EvalSpec[]) create_train_and_eval_specs(Func<DatasetV1Adapter> train_input_fn, Action[] eval_input_fns, Action eval_on_train_input_fn,
Action predict_input_fn, int train_steps, bool eval_on_train_data = false, Action predict_input_fn, int train_steps, bool eval_on_train_data = false,
string final_exporter_name = "Servo", string[] eval_spec_names = null) string final_exporter_name = "Servo", string[] eval_spec_names = null)
{ {


Loading…
Cancel
Save