
input_fn #387

tags/v0.12
Oceania2018 6 years ago
parent · commit 41432600c8
13 changed files with 259 additions and 12 deletions
 1. +2  -1   src/TensorFlowNET.Core/APIs/tf.estimator.cs
 2. +27 -3   src/TensorFlowNET.Core/Estimators/Estimator.cs
 3. +4  -3   src/TensorFlowNET.Core/Estimators/TrainSpec.cs
 4. +3  -0   src/TensorFlowNET.Core/Graphs/Graph.cs
 5. +40 -2   src/TensorFlowNET.Core/Train/TrainingUtil.cs
 6. +1  -0   src/TensorFlowNET.Core/ops.GraphKeys.cs
 7. +25 -0   src/TensorFlowNET.Models/ObjectDetection/Builders/DatasetBuilder.cs
 8. +40 -0   src/TensorFlowNET.Models/ObjectDetection/Builders/ImageResizerBuilder.cs
 9. +56 -0   src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs
10. +2  -1   src/TensorFlowNET.Models/ObjectDetection/Entities/TrainAndEvalDict.cs
11. +50 -0   src/TensorFlowNET.Models/ObjectDetection/Inputs.cs
12. +4  -0   src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNMetaArch.cs
13. +5  -2   src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs

+2 -1   src/TensorFlowNET.Core/APIs/tf.estimator.cs

@@ -17,6 +17,7 @@
using System;
using static Tensorflow.Binding;
using Tensorflow.Estimators;
using Tensorflow.Data;

namespace Tensorflow
{
@@ -35,7 +36,7 @@ namespace Tensorflow
public void train_and_evaluate(Estimator estimator, TrainSpec train_spec, EvalSpec eval_spec)
=> Training.train_and_evaluate(estimator: estimator, train_spec: train_spec, eval_spec: eval_spec);

- public TrainSpec TrainSpec(Action input_fn, int max_steps)
+ public TrainSpec TrainSpec(Func<DatasetV1Adapter> input_fn, int max_steps)
=> new TrainSpec(input_fn: input_fn, max_steps: max_steps);

/// <summary>
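
The API-level change in tf.estimator.cs is purely the delegate type of input_fn: callers now pass a function that produces a DatasetV1Adapter rather than a bare Action. A minimal sketch, assuming these members are exposed through tf.estimator as this file suggests, with BuildDataset() standing in for any code that returns a DatasetV1Adapter:

Func<DatasetV1Adapter> input_fn = () => BuildDataset();              // BuildDataset() is a hypothetical placeholder
var train_spec = tf.estimator.TrainSpec(input_fn, max_steps: 1000);  // new signature from this hunk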


+27 -3   src/TensorFlowNET.Core/Estimators/Estimator.cs

@@ -3,6 +3,8 @@ using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Text;
using Tensorflow.Data;
using Tensorflow.Train;
using static Tensorflow.Binding;

namespace Tensorflow.Estimators
@@ -30,7 +32,7 @@ namespace Tensorflow.Estimators
_model_fn = model_fn;
}

- public Estimator train(Action input_fn, int max_steps = 1, Action[] hooks = null,
+ public Estimator train(Func<DatasetV1Adapter> input_fn, int max_steps = 1, Action[] hooks = null,
_NewCheckpointListenerForEvaluate[] saving_listeners = null)
{
if(max_steps > 0)
@@ -56,19 +58,41 @@ namespace Tensorflow.Estimators
return cp.AllModelCheckpointPaths.Count - 1;
}

- private void _train_model(Action input_fn)
+ private void _train_model(Func<DatasetV1Adapter> input_fn)
{
_train_model_default(input_fn);
}

- private void _train_model_default(Action input_fn)
+ private void _train_model_default(Func<DatasetV1Adapter> input_fn)
{
using (var g = tf.Graph().as_default())
{
var global_step_tensor = _create_and_assert_global_step(g);

// Skip creating a read variable if _create_and_assert_global_step
// returns None (e.g. tf.contrib.estimator.SavedModelEstimator).
if (global_step_tensor != null)
TrainingUtil._get_or_create_global_step_read(g);

_get_features_and_labels_from_input_fn(input_fn, "train");
}
}

private void _get_features_and_labels_from_input_fn(Func<DatasetV1Adapter> input_fn, string mode)
{
_call_input_fn(input_fn, mode);
}

/// <summary>
/// Calls the input function.
/// </summary>
/// <param name="input_fn"></param>
/// <param name="mode"></param>
private void _call_input_fn(Func<DatasetV1Adapter> input_fn, string mode)
{
input_fn();
}

private Tensor _create_and_assert_global_step(Graph graph)
{
var step = _create_global_step(graph);
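
The Estimator side mirrors that change: train, _train_model and _train_model_default now accept a Func<DatasetV1Adapter>, and the new _get_features_and_labels_from_input_fn/_call_input_fn pair simply invoke it. A minimal sketch of the resulting call path, assuming an Estimator instance has already been constructed elsewhere and BuildTrainingDataset() is a hypothetical dataset builder:

Func<DatasetV1Adapter> input_fn = () => BuildTrainingDataset();   // hypothetical
estimator.train(input_fn: input_fn, max_steps: 1000);
// train -> _train_model -> _train_model_default
//       -> _get_features_and_labels_from_input_fn("train") -> _call_input_fn -> input_fn()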


+4 -3   src/TensorFlowNET.Core/Estimators/TrainSpec.cs

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Data;

namespace Tensorflow.Estimators
{
@@ -9,10 +10,10 @@ namespace Tensorflow.Estimators
int _max_steps;
public int max_steps => _max_steps;

- Action _input_fn;
- public Action input_fn => _input_fn;
+ Func<DatasetV1Adapter> _input_fn;
+ public Func<DatasetV1Adapter> input_fn => _input_fn;

- public TrainSpec(Action input_fn, int max_steps)
+ public TrainSpec(Func<DatasetV1Adapter> input_fn, int max_steps)
{
_max_steps = max_steps;
_input_fn = input_fn;


+3 -0   src/TensorFlowNET.Core/Graphs/Graph.cs

@@ -434,6 +434,9 @@ namespace Tensorflow
case List<RefVariable> list:
t = list.Select(x => (T)(object)x).ToList();
break;
case List<Tensor> list:
t = list.Select(x => (T)(object)x).ToList();
break;
default:
throw new NotImplementedException($"get_collection<{typeof(T).FullName}>");
}


+40 -2   src/TensorFlowNET.Core/Train/TrainingUtil.cs

@@ -7,7 +7,7 @@ namespace Tensorflow.Train
{
public class TrainingUtil
{
- public static RefVariable create_global_step(Graph graph)
+ public static RefVariable create_global_step(Graph graph = null)
{
graph = graph ?? ops.get_default_graph();
if (get_global_step(graph) != null)
@@ -24,7 +24,7 @@ namespace Tensorflow.Train
return v;
}

- public static RefVariable get_global_step(Graph graph)
+ public static RefVariable get_global_step(Graph graph = null)
{
graph = graph ?? ops.get_default_graph();
RefVariable global_step_tensor = null;
@@ -47,5 +47,43 @@ namespace Tensorflow.Train
return global_step_tensor;
}

public static Tensor _get_or_create_global_step_read(Graph graph = null)
{
graph = graph ?? ops.get_default_graph();
var global_step_read_tensor = _get_global_step_read(graph);
if (global_step_read_tensor != null)
return global_step_read_tensor;

var global_step_tensor = get_global_step(graph);

if (global_step_tensor == null)
return null;

var g = graph.as_default();
g.name_scope(null);
g.name_scope(global_step_tensor.op.name + "/");
// Using initialized_value ensures that global_step is initialized before
// this run. This is needed because, for example, Estimator makes every model_fn
// build depend on global_step_read_tensor.
var global_step_value = global_step_tensor.initialized_value();
ops.add_to_collection(tf.GraphKeys.GLOBAL_STEP_READ_KEY, global_step_value + 0);

return _get_global_step_read(graph);
}

private static Tensor _get_global_step_read(Graph graph = null)
{
graph = graph ?? ops.get_default_graph();
var global_step_read_tensors = graph.get_collection<Tensor>(tf.GraphKeys.GLOBAL_STEP_READ_KEY);
if (global_step_read_tensors.Count > 1)
throw new RuntimeError($"There are multiple items in collection {tf.GraphKeys.GLOBAL_STEP_READ_KEY}. " +
"There should be only one.");

if (global_step_read_tensors.Count == 1)
return global_step_read_tensors[0];

return null;
}
}
}
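
The new helpers cache a read tensor for the global step: _get_or_create_global_step_read first looks in the GLOBAL_STEP_READ_KEY collection, and only if nothing is cached does it build global_step.initialized_value() + 0 (forcing a Tensor copy of the variable) and add it to that collection. A small usage sketch using only members from this hunk, assuming a fresh default graph:

var graph = ops.get_default_graph();
var global_step = TrainingUtil.create_global_step(graph);           // RefVariable global step
var read1 = TrainingUtil._get_or_create_global_step_read(graph);    // creates and caches the read tensor
var read2 = TrainingUtil._get_or_create_global_step_read(graph);    // returns the cached tensor from GLOBAL_STEP_READ_KEY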

+1 -0   src/TensorFlowNET.Core/ops.GraphKeys.cs

@@ -122,6 +122,7 @@ namespace Tensorflow
public string TRAIN_OP => TRAIN_OP_;

public string GLOBAL_STEP => GLOBAL_STEP_;
public string GLOBAL_STEP_READ_KEY = "global_step_read_op_cache";

public string[] _VARIABLE_COLLECTIONS => _VARIABLE_COLLECTIONS_;
/// <summary>


+25 -0   src/TensorFlowNET.Models/ObjectDetection/Builders/DatasetBuilder.cs

@@ -0,0 +1,25 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Data;
using Tensorflow.Models.ObjectDetection.Protos;

namespace Tensorflow.Models.ObjectDetection
{
public class DatasetBuilder
{
public static DatasetV1Adapter build(InputReader input_reader_config,
int batch_size = 0,
Action transform_input_data_fn = null)
{
Func<Dictionary<string, Tensor>, (Dictionary<string, Tensor>, Dictionary<string, Tensor>)> transform_and_pad_input_data_fn = (tensor_dict) =>
{
return (null, null);
};

var config = input_reader_config.TfRecordInputReader;

throw new NotImplementedException("");
}
}
}

+40 -0   src/TensorFlowNET.Models/ObjectDetection/Builders/ImageResizerBuilder.cs

@@ -0,0 +1,40 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Models.ObjectDetection.Protos;
using static Tensorflow.Models.ObjectDetection.Protos.ImageResizer;

namespace Tensorflow.Models.ObjectDetection
{
public class ImageResizerBuilder
{
public ImageResizerBuilder()
{

}

public Action build(ImageResizer image_resizer_config)
{
var image_resizer_oneof = image_resizer_config.ImageResizerOneofCase;
if (image_resizer_oneof == ImageResizerOneofOneofCase.KeepAspectRatioResizer)
{
var keep_aspect_ratio_config = image_resizer_config.KeepAspectRatioResizer;
var method = _tf_resize_method(keep_aspect_ratio_config.ResizeMethod);
var per_channel_pad_value = new[] { 0, 0, 0 };
//if (keep_aspect_ratio_config.PerChannelPadValue != null)
//per_channel_pad_value = new[] { keep_aspect_ratio_config.PerChannelPadValue };
}
else
{
throw new NotImplementedException("");
}

return null;
}

private ResizeType _tf_resize_method(ResizeType resize_method)
{
return resize_method;
}
}
}

+56 -0   src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs

@@ -0,0 +1,56 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Models.ObjectDetection.MetaArchitectures;
using Tensorflow.Models.ObjectDetection.Protos;
using static Tensorflow.Models.ObjectDetection.Protos.DetectionModel;

namespace Tensorflow.Models.ObjectDetection
{
public class ModelBuilder
{
ImageResizerBuilder image_resizer_builder;

public ModelBuilder()
{
image_resizer_builder = new ImageResizerBuilder();
}

/// <summary>
/// Builds a DetectionModel based on the model config.
/// </summary>
/// <param name="model_config">A model.proto object containing the config for the desired DetectionModel.</param>
/// <param name="is_training">True if this model is being built for training purposes.</param>
/// <param name="add_summaries">Whether to add tensorflow summaries in the model graph.</param>
/// <returns>DetectionModel based on the config.</returns>
public FasterRCNNMetaArch build(DetectionModel model_config, bool is_training, bool add_summaries = true)
{
var meta_architecture = model_config.ModelCase;
if (meta_architecture == ModelOneofCase.Ssd)
throw new NotImplementedException("");
else if (meta_architecture == ModelOneofCase.FasterRcnn)
return _build_faster_rcnn_model(model_config.FasterRcnn, is_training, add_summaries);

throw new ValueError($"Unknown meta architecture: {meta_architecture}");
}

/// <summary>
/// Builds a Faster R-CNN or R-FCN detection model based on the model config.
/// </summary>
/// <param name="frcnn_config"></param>
/// <param name="is_training"></param>
/// <param name="add_summaries"></param>
/// <returns>FasterRCNNMetaArch based on the config.</returns>
private FasterRCNNMetaArch _build_faster_rcnn_model(FasterRcnn frcnn_config, bool is_training, bool add_summaries)
{
var num_classes = frcnn_config.NumClasses;
var image_resizer_fn = image_resizer_builder.build(frcnn_config.ImageResizer);
throw new NotImplementedException("");
}

public Action preprocess()
{
throw new NotImplementedException("");
}
}
}
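
ModelBuilder.build dispatches on the DetectionModel oneof: SSD configs are not implemented yet, FasterRcnn configs go through _build_faster_rcnn_model, and anything else raises ValueError. A usage sketch, assuming model_config is a DetectionModel proto whose faster_rcnn oneof is set (e.g. taken from a parsed pipeline config):

var builder = new ModelBuilder();
FasterRCNNMetaArch model = builder.build(model_config, is_training: true, add_summaries: true);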

+2 -1   src/TensorFlowNET.Models/ObjectDetection/Entities/TrainAndEvalDict.cs

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Data;
using Tensorflow.Estimators;

namespace Tensorflow.Models.ObjectDetection
@@ -8,7 +9,7 @@ namespace Tensorflow.Models.ObjectDetection
public class TrainAndEvalDict
{
public Estimator estimator { get; set; }
- public Action train_input_fn { get; set; }
+ public Func<DatasetV1Adapter> train_input_fn { get; set; }
public Action[] eval_input_fns { get; set; }
public string[] eval_input_names { get; set; }
public Action eval_on_train_input_fn { get; set; }


+50 -0   src/TensorFlowNET.Models/ObjectDetection/Inputs.cs

@@ -0,0 +1,50 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Data;
using Tensorflow.Estimators;
using Tensorflow.Models.ObjectDetection.MetaArchitectures;
using Tensorflow.Models.ObjectDetection.Protos;

namespace Tensorflow.Models.ObjectDetection
{
public class Inputs
{
ModelBuilder modelBuilder;
Dictionary<string, Func<DetectionModel, bool, bool, FasterRCNNMetaArch>> INPUT_BUILDER_UTIL_MAP;

public Inputs()
{
modelBuilder = new ModelBuilder();
INPUT_BUILDER_UTIL_MAP = new Dictionary<string, Func<DetectionModel, bool, bool, FasterRCNNMetaArch>>();
INPUT_BUILDER_UTIL_MAP["model_build"] = modelBuilder.build;
}

public Func<DatasetV1Adapter> create_train_input_fn(TrainConfig train_config, InputReader train_input_config, DetectionModel model_config)
{
Func<DatasetV1Adapter> _train_input_fn = () =>
{
return train_input(train_config, train_input_config, model_config);
};

return _train_input_fn;
}

/// <summary>
/// Returns `features` and `labels` tensor dictionaries for training.
/// </summary>
/// <param name="train_config"></param>
/// <param name="train_input_config"></param>
/// <param name="model_config"></param>
/// <returns></returns>
public DatasetV1Adapter train_input(TrainConfig train_config, InputReader train_input_config, DetectionModel model_config)
{
var arch = INPUT_BUILDER_UTIL_MAP["model_build"](model_config, true, true);
Func<Tensor, (Tensor, Tensor)> model_preprocess_fn = arch.preprocess;

var dataset = DatasetBuilder.build(train_input_config);

return dataset;
}
}
}
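
create_train_input_fn only wraps train_input in a closure, so dataset construction is deferred until the Estimator actually calls the delegate. A sketch, assuming config is the result of ConfigUtil.get_configs_from_pipeline_file, exactly as used in ModelLib below:

var inputs = new Inputs();
Func<DatasetV1Adapter> train_input_fn =
    inputs.create_train_input_fn(config.TrainConfig, config.TrainInputReader, config.Model);
DatasetV1Adapter dataset = train_input_fn();   // invokes train_input(), which builds the model and the dataset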

+4 -0   src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNMetaArch.cs

@@ -6,5 +6,9 @@ namespace Tensorflow.Models.ObjectDetection.MetaArchitectures
{
public class FasterRCNNMetaArch
{
public (Tensor, Tensor) preprocess(Tensor tensor)
{
throw new NotImplementedException("");
}
}
}

+5 -2   src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs

@@ -6,11 +6,14 @@ using Tensorflow.Estimators;
using System.Linq;
using Tensorflow.Contrib.Train;
using Tensorflow.Models.ObjectDetection.Utils;
using Tensorflow.Data;

namespace Tensorflow.Models.ObjectDetection
{
public class ModelLib
{
Inputs inputs = new Inputs();

public TrainAndEvalDict create_estimator_and_inputs(RunConfig run_config,
HParams hparams = null,
string pipeline_config_path = null,
@@ -21,7 +24,7 @@ namespace Tensorflow.Models.ObjectDetection
var config = ConfigUtil.get_configs_from_pipeline_file(pipeline_config_path);

// Create the input functions for TRAIN/EVAL/PREDICT.
- Action train_input_fn = () => { };
+ Func<DatasetV1Adapter> train_input_fn = inputs.create_train_input_fn(config.TrainConfig, config.TrainInputReader, config.Model);

var eval_input_configs = config.EvalInputReader;

@@ -44,7 +47,7 @@ namespace Tensorflow.Models.ObjectDetection
};
}

- public (TrainSpec, EvalSpec[]) create_train_and_eval_specs(Action train_input_fn, Action[] eval_input_fns, Action eval_on_train_input_fn,
+ public (TrainSpec, EvalSpec[]) create_train_and_eval_specs(Func<DatasetV1Adapter> train_input_fn, Action[] eval_input_fns, Action eval_on_train_input_fn,
Action predict_input_fn, int train_steps, bool eval_on_train_data = false,
string final_exporter_name = "Servo", string[] eval_spec_names = null)
{
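
Taken together, the commit lets the object-detection pipeline hand a real dataset-producing train_input_fn to the estimator. A hedged end-to-end sketch, assuming a RunConfig, HParams and pipeline config path are available; predict_input_fn and train_steps are illustrative placeholders, everything else is named in the hunks above:

var modelLib = new ModelLib();
var dict = modelLib.create_estimator_and_inputs(run_config, hparams, pipeline_config_path);
Action predict_input_fn = () => { };     // placeholder
int train_steps = 10000;                 // placeholder
var (train_spec, eval_specs) = modelLib.create_train_and_eval_specs(
    dict.train_input_fn, dict.eval_input_fns, dict.eval_on_train_input_fn,
    predict_input_fn, train_steps);
tf.estimator.train_and_evaluate(dict.estimator, train_spec, eval_specs[0]);   // first eval spec, for brevity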

