|
|
|
@@ -45,8 +45,9 @@ namespace Tensorflow.Estimators |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
_train_model(input_fn); |
|
|
|
throw new NotImplementedException(""); |
|
|
|
var loss = _train_model(input_fn); |
|
|
|
print($"Loss for final step: {loss}."); |
|
|
|
return this; |
|
|
|
} |
|
|
|
|
|
|
|
private int _load_global_step_from_checkpoint_dir(string checkpoint_dir) |
|
|
|
@@ -58,12 +59,12 @@ namespace Tensorflow.Estimators |
|
|
|
return cp.AllModelCheckpointPaths.Count - 1; |
|
|
|
} |
|
|
|
|
|
|
|
private void _train_model(Func<DatasetV1Adapter> input_fn) |
|
|
|
private Tensor _train_model(Func<DatasetV1Adapter> input_fn) |
|
|
|
{ |
|
|
|
_train_model_default(input_fn); |
|
|
|
return _train_model_default(input_fn); |
|
|
|
} |
|
|
|
|
|
|
|
private void _train_model_default(Func<DatasetV1Adapter> input_fn) |
|
|
|
private Tensor _train_model_default(Func<DatasetV1Adapter> input_fn) |
|
|
|
{ |
|
|
|
using (var g = tf.Graph().as_default()) |
|
|
|
{ |
|
|
|
@@ -74,13 +75,16 @@ namespace Tensorflow.Estimators |
|
|
|
if (global_step_tensor != null) |
|
|
|
TrainingUtil._get_or_create_global_step_read(g); |
|
|
|
|
|
|
|
_get_features_and_labels_from_input_fn(input_fn, "train"); |
|
|
|
var (features, labels) = _get_features_and_labels_from_input_fn(input_fn, "train"); |
|
|
|
} |
|
|
|
|
|
|
|
throw new NotImplementedException(""); |
|
|
|
} |
|
|
|
|
|
|
|
private void _get_features_and_labels_from_input_fn(Func<DatasetV1Adapter> input_fn, string mode) |
|
|
|
private (Dictionary<string, Tensor>, Dictionary<string, Tensor>) _get_features_and_labels_from_input_fn(Func<DatasetV1Adapter> input_fn, string mode) |
|
|
|
{ |
|
|
|
_call_input_fn(input_fn, mode); |
|
|
|
var result = _call_input_fn(input_fn, mode); |
|
|
|
return EstimatorUtil.parse_input_fn_result(result); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
@@ -88,9 +92,9 @@ namespace Tensorflow.Estimators |
|
|
|
/// </summary> |
|
|
|
/// <param name="input_fn"></param> |
|
|
|
/// <param name="mode"></param> |
|
|
|
private void _call_input_fn(Func<DatasetV1Adapter> input_fn, string mode) |
|
|
|
private DatasetV1Adapter _call_input_fn(Func<DatasetV1Adapter> input_fn, string mode) |
|
|
|
{ |
|
|
|
input_fn(); |
|
|
|
return input_fn(); |
|
|
|
} |
|
|
|
|
|
|
|
private Tensor _create_and_assert_global_step(Graph graph) |
|
|
|
|