Browse Source

add op_def_registry to load all registered op list.

add RuntimeError exception class.
add meta_graph related classes.
tags/v0.8.0
haiping008 6 years ago
parent
commit
c7cf8b6084
11 changed files with 3371 additions and 20 deletions
  1. +19
    -0
      src/TensorFlowNET.Core/Exceptions/RuntimeError.cs
  2. +135
    -0
      src/TensorFlowNET.Core/Framework/meta_graph.py.cs
  3. +27
    -0
      src/TensorFlowNET.Core/Framework/op_def_registry.py.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Graphs/Graph.Export.cs
  5. +264
    -0
      src/TensorFlowNET.Core/Protobuf/CheckpointState.cs
  6. +2679
    -0
      src/TensorFlowNET.Core/Protobuf/MetaGraph.cs
  7. +13
    -15
      src/TensorFlowNET.Core/Protobuf/README.md
  8. +5
    -0
      src/TensorFlowNET.Core/Python.cs
  9. +117
    -3
      src/TensorFlowNET.Core/Train/Saving/Saver.cs
  10. +109
    -0
      src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs
  11. +2
    -1
      test/TensorFlowNET.UnitTest/OperationsTest.cs

+ 19
- 0
src/TensorFlowNET.Core/Exceptions/RuntimeError.cs View File

@@ -0,0 +1,19 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class RuntimeError : Exception
{
public RuntimeError() : base()
{

}

public RuntimeError(string message) : base(message)
{

}
}
}

+ 135
- 0
src/TensorFlowNET.Core/Framework/meta_graph.py.cs View File

@@ -0,0 +1,135 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using static Tensorflow.MetaGraphDef.Types;

namespace Tensorflow
{
public class meta_graph
{
/// <summary>
/// Returns `MetaGraphDef` proto. Optionally writes it to filename.
/// </summary>
/// <param name="filename"></param>
/// <param name="graph_def"></param>
/// <param name="as_text"></param>
/// <param name="unbound_inputs_col_name"></param>
/// <param name="clear_devices"></param>
/// <param name="saver_def"></param>
/// <param name="clear_extraneous_savers"></param>
/// <param name="strip_default_attrs"></param>
/// <param name="meta_info_def"></param>
/// <returns></returns>
public static MetaGraphDef export_scoped_meta_graph(string filename = "",
GraphDef graph_def = null,
bool as_text = false,
string unbound_inputs_col_name = "unbound_inputs",
bool clear_devices = false,
SaverDef saver_def = null,
bool clear_extraneous_savers= false,
bool strip_default_attrs= false,
byte[] meta_info_def = null)
{
var graph = ops.get_default_graph();

var var_list = new Dictionary<string, RefVariable>();
var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES);

foreach(var v in variables as RefVariable[])
{
var_list[v.name] = v;
}

var scoped_meta_graph_def = create_meta_graph_def(
graph_def: graph_def,
export_scope: "",
exclude_nodes: "",
clear_extraneous_savers: clear_extraneous_savers,
saver_def: saver_def,
strip_default_attrs: strip_default_attrs);

throw new NotImplementedException("meta_graph.export_scoped_meta_graph");
}

private static bool _should_include_node()
{
return true;
}

private static byte[] create_meta_graph_def(MetaInfoDef meta_info_def = null,
GraphDef graph_def = null,
string export_scope = "",
string exclude_nodes = "",
SaverDef saver_def = null,
bool clear_extraneous_savers = false,
bool strip_default_attrs = false)
{
// Sets graph to default graph if it's not passed in.
var graph = ops.get_default_graph();
// Creates a MetaGraphDef proto.
var meta_graph_def = new MetaGraphDef();
if (meta_info_def == null)
meta_info_def = new MetaInfoDef();

// Set the tf version strings to the current tf build.
meta_info_def.TensorflowVersion = tf.VERSION;
meta_info_def.TensorflowGitVersion = "unknown";
meta_graph_def.MetaInfoDef = meta_info_def;

// Adds graph_def or the default.
if (graph_def == null)
meta_graph_def.GraphDef = graph._as_graph_def(add_shapes: true);
else
meta_graph_def.GraphDef = graph_def;

// Fills in meta_info_def.stripped_op_list using the ops from graph_def.
if (meta_graph_def.MetaInfoDef.StrippedOpList.Op.Count == 0)
meta_graph_def.MetaInfoDef.StrippedOpList = stripped_op_list_for_graph(meta_graph_def.GraphDef);

throw new NotImplementedException("create_meta_graph_def");
}

private static OpList stripped_op_list_for_graph(GraphDef graph_def)
{
var used_ops = ops_used_by_graph_def(graph_def);

// Verify that all used ops are registered.
// var registered_ops = op_def_registry.get_registered_ops();

var op_list = new OpList();
/*used_ops.OrderBy(x => x).Select(x => {

}).ToArray();*/

return op_list;
}

/// <summary>
/// Collect the list of ops used by a graph.
/// </summary>
/// <param name="graph_def"></param>
/// <returns></returns>
private static string[] ops_used_by_graph_def(GraphDef graph_def)
{
var used_ops = new List<string>();

Action<string> mark_op_as_used = (op) =>
{
if (!used_ops.Contains(op))
{

}

used_ops.Add(op);
};

foreach (var node in graph_def.Node)
{
mark_op_as_used(node.Op);
}

return used_ops.ToArray();
}
}
}

+ 27
- 0
src/TensorFlowNET.Core/Framework/op_def_registry.py.cs View File

@@ -0,0 +1,27 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class op_def_registry
{
private static Dictionary<string, OpDef> _registered_ops;

public static Dictionary<string, OpDef> get_registered_ops()
{
if(_registered_ops == null)
{
_registered_ops = new Dictionary<string, OpDef>();
var handle = c_api.TF_GetAllOpList();
var buffer = new Buffer(handle);
var op_list = OpList.Parser.ParseFrom(buffer);

foreach (var op_def in op_list.Op)
_registered_ops[op_def.Name] = op_def;
}

return _registered_ops;
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Graphs/Graph.Export.cs View File

@@ -18,7 +18,7 @@ namespace Tensorflow
return buffer; return buffer;
} }


public GraphDef _as_graph_def()
public GraphDef _as_graph_def(bool add_shapes = false)
{ {
var buffer = ToGraphDef(Status); var buffer = ToGraphDef(Status);
Status.Check(); Status.Check();


+ 264
- 0
src/TensorFlowNET.Core/Protobuf/CheckpointState.cs View File

@@ -0,0 +1,264 @@
// <auto-generated>
// Generated by the protocol buffer compiler. DO NOT EDIT!
// source: checkpoint_state.proto
// </auto-generated>
#pragma warning disable 1591, 0612, 3021
#region Designer generated code

using pb = global::Google.Protobuf;
using pbc = global::Google.Protobuf.Collections;
using pbr = global::Google.Protobuf.Reflection;
using scg = global::System.Collections.Generic;
namespace Tensorflow {

/// <summary>Holder for reflection information generated from checkpoint_state.proto</summary>
public static partial class CheckpointStateReflection {

#region Descriptor
/// <summary>File descriptor for checkpoint_state.proto</summary>
public static pbr::FileDescriptor Descriptor {
get { return descriptor; }
}
private static pbr::FileDescriptor descriptor;

static CheckpointStateReflection() {
byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"ChZjaGVja3BvaW50X3N0YXRlLnByb3RvEgp0ZW5zb3JmbG93Ip8BCg9DaGVj",
"a3BvaW50U3RhdGUSHQoVbW9kZWxfY2hlY2twb2ludF9wYXRoGAEgASgJEiIK",
"GmFsbF9tb2RlbF9jaGVja3BvaW50X3BhdGhzGAIgAygJEicKH2FsbF9tb2Rl",
"bF9jaGVja3BvaW50X3RpbWVzdGFtcHMYAyADKAESIAoYbGFzdF9wcmVzZXJ2",
"ZWRfdGltZXN0YW1wGAQgASgBQgP4AQFiBnByb3RvMw=="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CheckpointState), global::Tensorflow.CheckpointState.Parser, new[]{ "ModelCheckpointPath", "AllModelCheckpointPaths", "AllModelCheckpointTimestamps", "LastPreservedTimestamp" }, null, null, null)
}));
}
#endregion

}
#region Messages
/// <summary>
/// Protocol buffer representing the checkpoint state.
/// </summary>
public sealed partial class CheckpointState : pb::IMessage<CheckpointState> {
private static readonly pb::MessageParser<CheckpointState> _parser = new pb::MessageParser<CheckpointState>(() => new CheckpointState());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pb::MessageParser<CheckpointState> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.CheckpointStateReflection.Descriptor.MessageTypes[0]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public CheckpointState() {
OnConstruction();
}

partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public CheckpointState(CheckpointState other) : this() {
modelCheckpointPath_ = other.modelCheckpointPath_;
allModelCheckpointPaths_ = other.allModelCheckpointPaths_.Clone();
allModelCheckpointTimestamps_ = other.allModelCheckpointTimestamps_.Clone();
lastPreservedTimestamp_ = other.lastPreservedTimestamp_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public CheckpointState Clone() {
return new CheckpointState(this);
}

/// <summary>Field number for the "model_checkpoint_path" field.</summary>
public const int ModelCheckpointPathFieldNumber = 1;
private string modelCheckpointPath_ = "";
/// <summary>
/// Path to the most-recent model checkpoint.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public string ModelCheckpointPath {
get { return modelCheckpointPath_; }
set {
modelCheckpointPath_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}

/// <summary>Field number for the "all_model_checkpoint_paths" field.</summary>
public const int AllModelCheckpointPathsFieldNumber = 2;
private static readonly pb::FieldCodec<string> _repeated_allModelCheckpointPaths_codec
= pb::FieldCodec.ForString(18);
private readonly pbc::RepeatedField<string> allModelCheckpointPaths_ = new pbc::RepeatedField<string>();
/// <summary>
/// Paths to all not-yet-deleted model checkpoints, sorted from oldest to
/// newest.
/// Note that the value of model_checkpoint_path should be the last item in
/// this list.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<string> AllModelCheckpointPaths {
get { return allModelCheckpointPaths_; }
}

/// <summary>Field number for the "all_model_checkpoint_timestamps" field.</summary>
public const int AllModelCheckpointTimestampsFieldNumber = 3;
private static readonly pb::FieldCodec<double> _repeated_allModelCheckpointTimestamps_codec
= pb::FieldCodec.ForDouble(26);
private readonly pbc::RepeatedField<double> allModelCheckpointTimestamps_ = new pbc::RepeatedField<double>();
/// <summary>
/// Unix timestamps corresponding to all_model_checkpoint_paths, indicating
/// when each checkpoint was created.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<double> AllModelCheckpointTimestamps {
get { return allModelCheckpointTimestamps_; }
}

/// <summary>Field number for the "last_preserved_timestamp" field.</summary>
public const int LastPreservedTimestampFieldNumber = 4;
private double lastPreservedTimestamp_;
/// <summary>
/// Unix timestamp indicating the creation time for the last preserved
/// checkpoint.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public double LastPreservedTimestamp {
get { return lastPreservedTimestamp_; }
set {
lastPreservedTimestamp_ = value;
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as CheckpointState);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool Equals(CheckpointState other) {
if (ReferenceEquals(other, null)) {
return false;
}
if (ReferenceEquals(other, this)) {
return true;
}
if (ModelCheckpointPath != other.ModelCheckpointPath) return false;
if(!allModelCheckpointPaths_.Equals(other.allModelCheckpointPaths_)) return false;
if(!allModelCheckpointTimestamps_.Equals(other.allModelCheckpointTimestamps_)) return false;
if (!pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.Equals(LastPreservedTimestamp, other.LastPreservedTimestamp)) return false;
return Equals(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
if (ModelCheckpointPath.Length != 0) hash ^= ModelCheckpointPath.GetHashCode();
hash ^= allModelCheckpointPaths_.GetHashCode();
hash ^= allModelCheckpointTimestamps_.GetHashCode();
if (LastPreservedTimestamp != 0D) hash ^= pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.GetHashCode(LastPreservedTimestamp);
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
return hash;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
if (ModelCheckpointPath.Length != 0) {
output.WriteRawTag(10);
output.WriteString(ModelCheckpointPath);
}
allModelCheckpointPaths_.WriteTo(output, _repeated_allModelCheckpointPaths_codec);
allModelCheckpointTimestamps_.WriteTo(output, _repeated_allModelCheckpointTimestamps_codec);
if (LastPreservedTimestamp != 0D) {
output.WriteRawTag(33);
output.WriteDouble(LastPreservedTimestamp);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int CalculateSize() {
int size = 0;
if (ModelCheckpointPath.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(ModelCheckpointPath);
}
size += allModelCheckpointPaths_.CalculateSize(_repeated_allModelCheckpointPaths_codec);
size += allModelCheckpointTimestamps_.CalculateSize(_repeated_allModelCheckpointTimestamps_codec);
if (LastPreservedTimestamp != 0D) {
size += 1 + 8;
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
return size;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(CheckpointState other) {
if (other == null) {
return;
}
if (other.ModelCheckpointPath.Length != 0) {
ModelCheckpointPath = other.ModelCheckpointPath;
}
allModelCheckpointPaths_.Add(other.allModelCheckpointPaths_);
allModelCheckpointTimestamps_.Add(other.allModelCheckpointTimestamps_);
if (other.LastPreservedTimestamp != 0D) {
LastPreservedTimestamp = other.LastPreservedTimestamp;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(pb::CodedInputStream input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
break;
case 10: {
ModelCheckpointPath = input.ReadString();
break;
}
case 18: {
allModelCheckpointPaths_.AddEntriesFrom(input, _repeated_allModelCheckpointPaths_codec);
break;
}
case 26:
case 25: {
allModelCheckpointTimestamps_.AddEntriesFrom(input, _repeated_allModelCheckpointTimestamps_codec);
break;
}
case 33: {
LastPreservedTimestamp = input.ReadDouble();
break;
}
}
}
}

}

#endregion

}

#endregion Designer generated code

+ 2679
- 0
src/TensorFlowNET.Core/Protobuf/MetaGraph.cs
File diff suppressed because it is too large
View File


+ 13
- 15
src/TensorFlowNET.Core/Protobuf/README.md View File

@@ -1,20 +1,18 @@
### Download compiler from https://github.com/protocolbuffers/protobuf/releases ### Download compiler from https://github.com/protocolbuffers/protobuf/releases
```shell ```shell
set SRC_DIR=D:\Projects\tensorflow\tensorflow\core\framework
set SRC_DIR=D:\Projects\tensorflow
set DST_DIR=D:\Projects\TensorFlow.NET\src\TensorFlowNET.Core\Protobuf set DST_DIR=D:\Projects\TensorFlow.NET\src\TensorFlowNET.Core\Protobuf


protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% resource_handle.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor_shape.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% types.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% attr_value.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% node_def.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% versions.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% function.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% graph.proto
```

```shell
set SRC_DIR=D:\Projects\tensorflow\tensorflow\core\protobuf
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% saver.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\resource_handle.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\tensor_shape.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\types.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\tensor.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\attr_value.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\node_def.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\versions.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\function.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\graph.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\protobuf\saver.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\protobuf\meta_graph.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\python\training\checkpoint_state.proto
``` ```

+ 5
- 0
src/TensorFlowNET.Core/Python.cs View File

@@ -81,6 +81,11 @@ namespace Tensorflow
} }
} }


public static float time()
{
return (float)(DateTime.UtcNow - new DateTime(1970, 1, 1)).TotalSeconds;
}

public static IEnumerable<(T, T)> zip<T>(NDArray t1, NDArray t2) public static IEnumerable<(T, T)> zip<T>(NDArray t1, NDArray t2)
{ {
int index = 0; int index = 0;


+ 117
- 3
src/TensorFlowNET.Core/Train/Saving/Saver.cs View File

@@ -1,6 +1,7 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
using System.Linq;
using System.Text; using System.Text;


namespace Tensorflow namespace Tensorflow
@@ -28,6 +29,8 @@ namespace Tensorflow
private float _next_checkpoint_time; private float _next_checkpoint_time;
private bool _save_relative_paths; private bool _save_relative_paths;
private bool? _object_restore_saver; private bool? _object_restore_saver;
private Dictionary<string, float> _last_checkpoints;
private Dictionary<string, float> _checkpoints_to_be_deleted;


public Saver(RefVariable[] var_list = null, public Saver(RefVariable[] var_list = null,
bool reshape = false, bool reshape = false,
@@ -68,6 +71,9 @@ namespace Tensorflow


_save_relative_paths = save_relative_paths; _save_relative_paths = save_relative_paths;
_object_restore_saver = null; _object_restore_saver = null;

_last_checkpoints = new Dictionary<string, float>();
_checkpoints_to_be_deleted = new Dictionary<string, float>();
} }


public void build() public void build()
@@ -121,7 +127,7 @@ namespace Tensorflow


_check_saver_def(); _check_saver_def();


_next_checkpoint_time = (float)(DateTime.UtcNow - new DateTime(1970, 1, 1)).TotalSeconds + _saver_def.KeepCheckpointEveryNHours * 3600;
_next_checkpoint_time = Python.time() + _saver_def.KeepCheckpointEveryNHours * 3600;
} }


private void _check_saver_def() private void _check_saver_def()
@@ -165,11 +171,119 @@ namespace Tensorflow
model_checkpoint_path = sess.run(_saver_def.SaveTensorName, new FeedItem[] { model_checkpoint_path = sess.run(_saver_def.SaveTensorName, new FeedItem[] {
new FeedItem(_saver_def.FilenameTensorName, checkpoint_file) new FeedItem(_saver_def.FilenameTensorName, checkpoint_file)
}); });

if (write_state)
{
_RecordLastCheckpoint(model_checkpoint_path);
checkpoint_management.update_checkpoint_state_internal(
save_dir: save_path_parent,
model_checkpoint_path: model_checkpoint_path,
all_model_checkpoint_paths: _last_checkpoints.Keys.Select(x => x).ToList(),
latest_filename: latest_filename,
save_relative_paths: _save_relative_paths);
_MaybeDeleteOldCheckpoints(meta_graph_suffix: meta_graph_suffix);
}
}

if (write_meta_graph)
{
string meta_graph_filename = checkpoint_management.meta_graph_filename(checkpoint_file, meta_graph_suffix: meta_graph_suffix);
}

return _is_empty ? string.Empty : model_checkpoint_path;
}

/// <summary>
/// Writes `MetaGraphDef` to save_path/filename.
/// </summary>
/// <param name="filename"></param>
/// <param name="collection_list"></param>
/// <param name="as_text"></param>
/// <param name="export_scope"></param>
/// <param name="clear_devices"></param>
/// <param name="clear_extraneous_savers"></param>
/// <param name="strip_default_attrs"></param>
public MetaGraphDef export_meta_graph(string filename= "",
string[] collection_list = null,
string export_scope = "",
bool as_text= false,
bool clear_devices= false,
bool clear_extraneous_savers= false,
bool strip_default_attrs= false)
{
return export_meta_graph(
filename: filename,
graph_def: ops.get_default_graph()._as_graph_def(add_shapes: true),
saver_def: _saver_def,
collection_list: collection_list,
as_text: as_text,
export_scope: export_scope,
clear_devices: clear_devices,
clear_extraneous_savers: clear_extraneous_savers,
strip_default_attrs: strip_default_attrs);
}

public MetaGraphDef export_meta_graph(string filename = "",
byte[] meta_info_def = null,
GraphDef graph_def = null,
SaverDef saver_def = null,
string[] collection_list = null,
bool as_text = false,
bool clear_devices= false,
bool clear_extraneous_savers= false,
bool strip_default_attrs= false,
string export_scope = "")
{
var meta_graph_def = meta_graph.export_scoped_meta_graph(
filename: filename,
meta_info_def: meta_info_def,
graph_def: graph_def,
saver_def: saver_def,
// collection_list: collection_list,
as_text: as_text,
clear_devices: clear_devices,
clear_extraneous_savers: clear_extraneous_savers,
strip_default_attrs: strip_default_attrs);
return meta_graph_def;
}

/// <summary>
/// Manages the list of the latest checkpoints.
/// </summary>
/// <param name="latest_save_path"></param>
private void _RecordLastCheckpoint(string latest_save_path)
{
if (_saver_def.MaxToKeep <= 0) return;

// Remove first from list if the same name was used before.
foreach (var p in _last_checkpoints)
if (latest_save_path == _CheckpointFilename((p.Key, p.Value)))
_last_checkpoints.Remove(p.Key);

// Append new path to list
_last_checkpoints.Add(latest_save_path, Python.time());

// If more than max_to_keep, remove oldest.
if(_last_checkpoints.Count > _saver_def.MaxToKeep)
{
var first = _last_checkpoints.First();
_last_checkpoints.Remove(first.Key);
_checkpoints_to_be_deleted[first.Key] = first.Value;
} }
}


throw new NotImplementedException("Saver.save");
private string _CheckpointFilename((string, float) p)
{
return p.Item1;
}

/// <summary>
/// Deletes old checkpoints if necessary.
/// </summary>
/// <param name="meta_graph_suffix"></param>
private void _MaybeDeleteOldCheckpoints(string meta_graph_suffix = "meta")
{


return model_checkpoint_path;
} }
} }
} }

+ 109
- 0
src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs View File

@@ -0,0 +1,109 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;

namespace Tensorflow
{
public class checkpoint_management
{
/// <summary>
/// Updates the content of the 'checkpoint' file.
/// </summary>
/// <param name="save_dir">Directory where the model was saved.</param>
/// <param name="model_checkpoint_path">The checkpoint file.</param>
/// <param name="all_model_checkpoint_paths">List of strings.</param>
/// <param name="latest_filename"></param>
/// <param name="save_relative_paths"></param>
/// <param name="all_model_checkpoint_timestamps"></param>
/// <param name="last_preserved_timestamp"></param>
public static void update_checkpoint_state_internal(string save_dir,
string model_checkpoint_path,
List<string> all_model_checkpoint_paths = null,
string latest_filename = "",
bool save_relative_paths = false,
List<float> all_model_checkpoint_timestamps = null,
float? last_preserved_timestamp = null
)
{
CheckpointState ckpt = null;

// Writes the "checkpoint" file for the coordinator for later restoration.
string coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename);
if (save_relative_paths)
{
throw new NotImplementedException("update_checkpoint_state_internal save_relative_paths");
}
else
{
ckpt = generate_checkpoint_state_proto(save_dir,
model_checkpoint_path,
all_model_checkpoint_paths,
all_model_checkpoint_timestamps,
last_preserved_timestamp);
}

if (coord_checkpoint_filename == ckpt.ModelCheckpointPath)
throw new RuntimeError($"Save path '{model_checkpoint_path}' conflicts with path used for " +
"checkpoint state. Please use a different save path.");

File.WriteAllText(coord_checkpoint_filename, ckpt.ToString());
}

/// <summary>
/// Returns a filename for storing the CheckpointState.
/// </summary>
/// <param name="save_dir">The directory for saving and restoring checkpoints.</param>
/// <param name="latest_filename">
/// Name of the file in 'save_dir' that is used
/// to store the CheckpointState.
/// </param>
/// <returns>he path of the file that contains the CheckpointState proto.</returns>
private static string _GetCheckpointFilename(string save_dir, string latest_filename)
{
if (string.IsNullOrEmpty(latest_filename))
latest_filename = "checkpoint";

return Path.Combine(save_dir, latest_filename);
}

private static CheckpointState generate_checkpoint_state_proto(string save_dir,
string model_checkpoint_path,
List<string> all_model_checkpoint_paths = null,
List<float> all_model_checkpoint_timestamps = null,
float? last_preserved_timestamp = null)
{
if (all_model_checkpoint_paths == null)
all_model_checkpoint_paths = new List<string>();

// Relative paths need to be rewritten to be relative to the "save_dir"
// if model_checkpoint_path already contains "save_dir".
all_model_checkpoint_paths.Add(model_checkpoint_path);

var coord_checkpoint_proto = new CheckpointState()
{
ModelCheckpointPath = model_checkpoint_path,
LastPreservedTimestamp = last_preserved_timestamp.Value
};

coord_checkpoint_proto.AllModelCheckpointPaths.AddRange(all_model_checkpoint_paths);
coord_checkpoint_proto.AllModelCheckpointTimestamps.AddRange(all_model_checkpoint_timestamps.Select(x => (double)x));

return coord_checkpoint_proto;
}

/// <summary>
/// Returns the meta graph filename.
/// </summary>
/// <param name="checkpoint_filename"></param>
/// <param name="meta_graph_suffix"></param>
/// <returns></returns>
public static string meta_graph_filename(string checkpoint_filename, string meta_graph_suffix= "meta")
{
string basename = checkpoint_filename;
string suffixed_filename = basename + "." + meta_graph_suffix;
return suffixed_filename;
}
}
}

+ 2
- 1
test/TensorFlowNET.UnitTest/OperationsTest.cs View File

@@ -20,7 +20,8 @@ namespace TensorFlowNET.UnitTest
{ {
var handle = c_api.TF_GetAllOpList(); var handle = c_api.TF_GetAllOpList();
var buffer = new Buffer(handle); var buffer = new Buffer(handle);
Assert.IsTrue(buffer.Length == buffer.Length);
var op_list = OpList.Parser.ParseFrom(buffer);
Assert.IsTrue(op_list.Op.Count > 1000);
} }


[TestMethod] [TestMethod]


Loading…
Cancel
Save