Browse Source

train_input

tags/v0.12
Oceania2018 6 years ago
parent
commit
bb638299ee
10 changed files with 118 additions and 21 deletions
  1. +14
    -10
      src/TensorFlowNET.Core/Estimators/Estimator.cs
  2. +15
    -0
      src/TensorFlowNET.Core/Estimators/EstimatorUtil.cs
  3. +15
    -0
      src/TensorFlowNET.Models/ObjectDetection/Builders/BoxPredictorBuilder.cs
  4. +12
    -2
      src/TensorFlowNET.Models/ObjectDetection/Builders/ImageResizerBuilder.cs
  5. +14
    -2
      src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs
  6. +10
    -0
      src/TensorFlowNET.Models/ObjectDetection/Core/DetectionModel.cs
  7. +1
    -5
      src/TensorFlowNET.Models/ObjectDetection/Inputs.cs
  8. +18
    -0
      src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNInitArgs.cs
  9. +9
    -2
      src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNMetaArch.cs
  10. +10
    -0
      src/TensorFlowNET.Models/ObjectDetection/Predictors/ConvolutionalBoxPredictor.cs

+ 14
- 10
src/TensorFlowNET.Core/Estimators/Estimator.cs View File

@@ -45,8 +45,9 @@ namespace Tensorflow.Estimators
}
}

_train_model(input_fn);
throw new NotImplementedException("");
var loss = _train_model(input_fn);
print($"Loss for final step: {loss}.");
return this;
}

private int _load_global_step_from_checkpoint_dir(string checkpoint_dir)
@@ -58,12 +59,12 @@ namespace Tensorflow.Estimators
return cp.AllModelCheckpointPaths.Count - 1;
}

private void _train_model(Func<DatasetV1Adapter> input_fn)
private Tensor _train_model(Func<DatasetV1Adapter> input_fn)
{
_train_model_default(input_fn);
return _train_model_default(input_fn);
}

private void _train_model_default(Func<DatasetV1Adapter> input_fn)
private Tensor _train_model_default(Func<DatasetV1Adapter> input_fn)
{
using (var g = tf.Graph().as_default())
{
@@ -74,13 +75,16 @@ namespace Tensorflow.Estimators
if (global_step_tensor != null)
TrainingUtil._get_or_create_global_step_read(g);

_get_features_and_labels_from_input_fn(input_fn, "train");
var (features, labels) = _get_features_and_labels_from_input_fn(input_fn, "train");
}

throw new NotImplementedException("");
}

private void _get_features_and_labels_from_input_fn(Func<DatasetV1Adapter> input_fn, string mode)
private (Dictionary<string, Tensor>, Dictionary<string, Tensor>) _get_features_and_labels_from_input_fn(Func<DatasetV1Adapter> input_fn, string mode)
{
_call_input_fn(input_fn, mode);
var result = _call_input_fn(input_fn, mode);
return EstimatorUtil.parse_input_fn_result(result);
}

/// <summary>
@@ -88,9 +92,9 @@ namespace Tensorflow.Estimators
/// </summary>
/// <param name="input_fn"></param>
/// <param name="mode"></param>
private void _call_input_fn(Func<DatasetV1Adapter> input_fn, string mode)
private DatasetV1Adapter _call_input_fn(Func<DatasetV1Adapter> input_fn, string mode)
{
input_fn();
return input_fn();
}

private Tensor _create_and_assert_global_step(Graph graph)


+ 15
- 0
src/TensorFlowNET.Core/Estimators/EstimatorUtil.cs View File

@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Data;

namespace Tensorflow.Estimators
{
public class EstimatorUtil
{
public static (Dictionary<string, Tensor>, Dictionary<string, Tensor>) parse_input_fn_result(DatasetV1Adapter result)
{
throw new NotImplementedException("");
}
}
}

+ 15
- 0
src/TensorFlowNET.Models/ObjectDetection/Builders/BoxPredictorBuilder.cs View File

@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Models.ObjectDetection
{
public class BoxPredictorBuilder
{
ConvolutionalBoxPredictor _first_stage_box_predictor;
public ConvolutionalBoxPredictor build_convolutional_box_predictor()
{
throw new NotImplementedException("");
}
}
}

+ 12
- 2
src/TensorFlowNET.Models/ObjectDetection/Builders/ImageResizerBuilder.cs View File

@@ -13,6 +13,11 @@ namespace Tensorflow.Models.ObjectDetection

}

/// <summary>
/// Builds callable for image resizing operations.
/// </summary>
/// <param name="image_resizer_config"></param>
/// <returns></returns>
public Action build(ImageResizer image_resizer_config)
{
var image_resizer_oneof = image_resizer_config.ImageResizerOneofCase;
@@ -21,8 +26,13 @@ namespace Tensorflow.Models.ObjectDetection
var keep_aspect_ratio_config = image_resizer_config.KeepAspectRatioResizer;
var method = _tf_resize_method(keep_aspect_ratio_config.ResizeMethod);
var per_channel_pad_value = new[] { 0, 0, 0 };
//if (keep_aspect_ratio_config.PerChannelPadValue != null)
//per_channel_pad_value = new[] { keep_aspect_ratio_config.PerChannelPadValue };
if (keep_aspect_ratio_config.PerChannelPadValue.Count > 0)
throw new NotImplementedException("");
// per_channel_pad_value = new[] { keep_aspect_ratio_config.PerChannelPadValue. };
return () =>
{

};
}
else
{


+ 14
- 2
src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs View File

@@ -1,7 +1,6 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Models.ObjectDetection.MetaArchitectures;
using Tensorflow.Models.ObjectDetection.Protos;
using static Tensorflow.Models.ObjectDetection.Protos.DetectionModel;

@@ -45,7 +44,20 @@ namespace Tensorflow.Models.ObjectDetection
{
var num_classes = frcnn_config.NumClasses;
var image_resizer_fn = image_resizer_builder.build(frcnn_config.ImageResizer);
throw new NotImplementedException("");

var first_stage_atrous_rate = frcnn_config.FirstStageAtrousRate;
var number_of_stages = frcnn_config.NumberOfStages;

return new FasterRCNNMetaArch(new FasterRCNNInitArgs
{
is_training = is_training,
num_classes = num_classes,
image_resizer_fn = image_resizer_fn,
feature_extractor = () => { throw new NotImplementedException(""); },
number_of_stage = number_of_stages,
first_stage_anchor_generator = null,
first_stage_atrous_rate = first_stage_atrous_rate
});
}

public Action preprocess()


+ 10
- 0
src/TensorFlowNET.Models/ObjectDetection/Core/DetectionModel.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Models.ObjectDetection.Core
{
public abstract class DetectionModel
{
}
}

+ 1
- 5
src/TensorFlowNET.Models/ObjectDetection/Inputs.cs View File

@@ -2,8 +2,6 @@
using System.Collections.Generic;
using System.Text;
using Tensorflow.Data;
using Tensorflow.Estimators;
using Tensorflow.Models.ObjectDetection.MetaArchitectures;
using Tensorflow.Models.ObjectDetection.Protos;

namespace Tensorflow.Models.ObjectDetection
@@ -23,9 +21,7 @@ namespace Tensorflow.Models.ObjectDetection
public Func<DatasetV1Adapter> create_train_input_fn(TrainConfig train_config, InputReader train_input_config, DetectionModel model_config)
{
Func<DatasetV1Adapter> _train_input_fn = () =>
{
return train_input(train_config, train_input_config, model_config);
};
train_input(train_config, train_input_config, model_config);

return _train_input_fn;
}


+ 18
- 0
src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNInitArgs.cs View File

@@ -0,0 +1,18 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Models.ObjectDetection
{
public class FasterRCNNInitArgs
{
public bool is_training { get; set; }
public int num_classes { get; set; }
public Action image_resizer_fn { get; set; }
public Action feature_extractor { get; set; }
public int number_of_stage { get; set; }
public object first_stage_anchor_generator { get; set; }
public object first_stage_target_assigner { get; set; }
public int first_stage_atrous_rate { get; set; }
}
}

+ 9
- 2
src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNMetaArch.cs View File

@@ -2,10 +2,17 @@
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Models.ObjectDetection.MetaArchitectures
namespace Tensorflow.Models.ObjectDetection
{
public class FasterRCNNMetaArch
public class FasterRCNNMetaArch : Core.DetectionModel
{
FasterRCNNInitArgs _args;

public FasterRCNNMetaArch(FasterRCNNInitArgs args)
{
_args = args;
}

public (Tensor, Tensor) preprocess(Tensor tensor)
{
throw new NotImplementedException("");


+ 10
- 0
src/TensorFlowNET.Models/ObjectDetection/Predictors/ConvolutionalBoxPredictor.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Models.ObjectDetection
{
public class ConvolutionalBoxPredictor
{
}
}

Loading…
Cancel
Save