add RuntimeError exception class. add meta_graph related classes.tags/v0.8.0
| @@ -0,0 +1,19 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class RuntimeError : Exception | |||
| { | |||
| public RuntimeError() : base() | |||
| { | |||
| } | |||
| public RuntimeError(string message) : base(message) | |||
| { | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,135 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using static Tensorflow.MetaGraphDef.Types; | |||
| namespace Tensorflow | |||
| { | |||
| public class meta_graph | |||
| { | |||
| /// <summary> | |||
| /// Returns `MetaGraphDef` proto. Optionally writes it to filename. | |||
| /// </summary> | |||
| /// <param name="filename"></param> | |||
| /// <param name="graph_def"></param> | |||
| /// <param name="as_text"></param> | |||
| /// <param name="unbound_inputs_col_name"></param> | |||
| /// <param name="clear_devices"></param> | |||
| /// <param name="saver_def"></param> | |||
| /// <param name="clear_extraneous_savers"></param> | |||
| /// <param name="strip_default_attrs"></param> | |||
| /// <param name="meta_info_def"></param> | |||
| /// <returns></returns> | |||
| public static MetaGraphDef export_scoped_meta_graph(string filename = "", | |||
| GraphDef graph_def = null, | |||
| bool as_text = false, | |||
| string unbound_inputs_col_name = "unbound_inputs", | |||
| bool clear_devices = false, | |||
| SaverDef saver_def = null, | |||
| bool clear_extraneous_savers= false, | |||
| bool strip_default_attrs= false, | |||
| byte[] meta_info_def = null) | |||
| { | |||
| var graph = ops.get_default_graph(); | |||
| var var_list = new Dictionary<string, RefVariable>(); | |||
| var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES); | |||
| foreach(var v in variables as RefVariable[]) | |||
| { | |||
| var_list[v.name] = v; | |||
| } | |||
| var scoped_meta_graph_def = create_meta_graph_def( | |||
| graph_def: graph_def, | |||
| export_scope: "", | |||
| exclude_nodes: "", | |||
| clear_extraneous_savers: clear_extraneous_savers, | |||
| saver_def: saver_def, | |||
| strip_default_attrs: strip_default_attrs); | |||
| throw new NotImplementedException("meta_graph.export_scoped_meta_graph"); | |||
| } | |||
| private static bool _should_include_node() | |||
| { | |||
| return true; | |||
| } | |||
| private static byte[] create_meta_graph_def(MetaInfoDef meta_info_def = null, | |||
| GraphDef graph_def = null, | |||
| string export_scope = "", | |||
| string exclude_nodes = "", | |||
| SaverDef saver_def = null, | |||
| bool clear_extraneous_savers = false, | |||
| bool strip_default_attrs = false) | |||
| { | |||
| // Sets graph to default graph if it's not passed in. | |||
| var graph = ops.get_default_graph(); | |||
| // Creates a MetaGraphDef proto. | |||
| var meta_graph_def = new MetaGraphDef(); | |||
| if (meta_info_def == null) | |||
| meta_info_def = new MetaInfoDef(); | |||
| // Set the tf version strings to the current tf build. | |||
| meta_info_def.TensorflowVersion = tf.VERSION; | |||
| meta_info_def.TensorflowGitVersion = "unknown"; | |||
| meta_graph_def.MetaInfoDef = meta_info_def; | |||
| // Adds graph_def or the default. | |||
| if (graph_def == null) | |||
| meta_graph_def.GraphDef = graph._as_graph_def(add_shapes: true); | |||
| else | |||
| meta_graph_def.GraphDef = graph_def; | |||
| // Fills in meta_info_def.stripped_op_list using the ops from graph_def. | |||
| if (meta_graph_def.MetaInfoDef.StrippedOpList.Op.Count == 0) | |||
| meta_graph_def.MetaInfoDef.StrippedOpList = stripped_op_list_for_graph(meta_graph_def.GraphDef); | |||
| throw new NotImplementedException("create_meta_graph_def"); | |||
| } | |||
| private static OpList stripped_op_list_for_graph(GraphDef graph_def) | |||
| { | |||
| var used_ops = ops_used_by_graph_def(graph_def); | |||
| // Verify that all used ops are registered. | |||
| // var registered_ops = op_def_registry.get_registered_ops(); | |||
| var op_list = new OpList(); | |||
| /*used_ops.OrderBy(x => x).Select(x => { | |||
| }).ToArray();*/ | |||
| return op_list; | |||
| } | |||
| /// <summary> | |||
| /// Collect the list of ops used by a graph. | |||
| /// </summary> | |||
| /// <param name="graph_def"></param> | |||
| /// <returns></returns> | |||
| private static string[] ops_used_by_graph_def(GraphDef graph_def) | |||
| { | |||
| var used_ops = new List<string>(); | |||
| Action<string> mark_op_as_used = (op) => | |||
| { | |||
| if (!used_ops.Contains(op)) | |||
| { | |||
| } | |||
| used_ops.Add(op); | |||
| }; | |||
| foreach (var node in graph_def.Node) | |||
| { | |||
| mark_op_as_used(node.Op); | |||
| } | |||
| return used_ops.ToArray(); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,27 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class op_def_registry | |||
| { | |||
| private static Dictionary<string, OpDef> _registered_ops; | |||
| public static Dictionary<string, OpDef> get_registered_ops() | |||
| { | |||
| if(_registered_ops == null) | |||
| { | |||
| _registered_ops = new Dictionary<string, OpDef>(); | |||
| var handle = c_api.TF_GetAllOpList(); | |||
| var buffer = new Buffer(handle); | |||
| var op_list = OpList.Parser.ParseFrom(buffer); | |||
| foreach (var op_def in op_list.Op) | |||
| _registered_ops[op_def.Name] = op_def; | |||
| } | |||
| return _registered_ops; | |||
| } | |||
| } | |||
| } | |||
| @@ -18,7 +18,7 @@ namespace Tensorflow | |||
| return buffer; | |||
| } | |||
| public GraphDef _as_graph_def() | |||
| public GraphDef _as_graph_def(bool add_shapes = false) | |||
| { | |||
| var buffer = ToGraphDef(Status); | |||
| Status.Check(); | |||
| @@ -0,0 +1,264 @@ | |||
| // <auto-generated> | |||
| // Generated by the protocol buffer compiler. DO NOT EDIT! | |||
| // source: checkpoint_state.proto | |||
| // </auto-generated> | |||
| #pragma warning disable 1591, 0612, 3021 | |||
| #region Designer generated code | |||
| using pb = global::Google.Protobuf; | |||
| using pbc = global::Google.Protobuf.Collections; | |||
| using pbr = global::Google.Protobuf.Reflection; | |||
| using scg = global::System.Collections.Generic; | |||
| namespace Tensorflow { | |||
| /// <summary>Holder for reflection information generated from checkpoint_state.proto</summary> | |||
| public static partial class CheckpointStateReflection { | |||
| #region Descriptor | |||
| /// <summary>File descriptor for checkpoint_state.proto</summary> | |||
| public static pbr::FileDescriptor Descriptor { | |||
| get { return descriptor; } | |||
| } | |||
| private static pbr::FileDescriptor descriptor; | |||
| static CheckpointStateReflection() { | |||
| byte[] descriptorData = global::System.Convert.FromBase64String( | |||
| string.Concat( | |||
| "ChZjaGVja3BvaW50X3N0YXRlLnByb3RvEgp0ZW5zb3JmbG93Ip8BCg9DaGVj", | |||
| "a3BvaW50U3RhdGUSHQoVbW9kZWxfY2hlY2twb2ludF9wYXRoGAEgASgJEiIK", | |||
| "GmFsbF9tb2RlbF9jaGVja3BvaW50X3BhdGhzGAIgAygJEicKH2FsbF9tb2Rl", | |||
| "bF9jaGVja3BvaW50X3RpbWVzdGFtcHMYAyADKAESIAoYbGFzdF9wcmVzZXJ2", | |||
| "ZWRfdGltZXN0YW1wGAQgASgBQgP4AQFiBnByb3RvMw==")); | |||
| descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, | |||
| new pbr::FileDescriptor[] { }, | |||
| new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CheckpointState), global::Tensorflow.CheckpointState.Parser, new[]{ "ModelCheckpointPath", "AllModelCheckpointPaths", "AllModelCheckpointTimestamps", "LastPreservedTimestamp" }, null, null, null) | |||
| })); | |||
| } | |||
| #endregion | |||
| } | |||
| #region Messages | |||
| /// <summary> | |||
| /// Protocol buffer representing the checkpoint state. | |||
| /// </summary> | |||
| public sealed partial class CheckpointState : pb::IMessage<CheckpointState> { | |||
| private static readonly pb::MessageParser<CheckpointState> _parser = new pb::MessageParser<CheckpointState>(() => new CheckpointState()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public static pb::MessageParser<CheckpointState> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.CheckpointStateReflection.Descriptor.MessageTypes[0]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public CheckpointState() { | |||
| OnConstruction(); | |||
| } | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public CheckpointState(CheckpointState other) : this() { | |||
| modelCheckpointPath_ = other.modelCheckpointPath_; | |||
| allModelCheckpointPaths_ = other.allModelCheckpointPaths_.Clone(); | |||
| allModelCheckpointTimestamps_ = other.allModelCheckpointTimestamps_.Clone(); | |||
| lastPreservedTimestamp_ = other.lastPreservedTimestamp_; | |||
| _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public CheckpointState Clone() { | |||
| return new CheckpointState(this); | |||
| } | |||
| /// <summary>Field number for the "model_checkpoint_path" field.</summary> | |||
| public const int ModelCheckpointPathFieldNumber = 1; | |||
| private string modelCheckpointPath_ = ""; | |||
| /// <summary> | |||
| /// Path to the most-recent model checkpoint. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public string ModelCheckpointPath { | |||
| get { return modelCheckpointPath_; } | |||
| set { | |||
| modelCheckpointPath_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); | |||
| } | |||
| } | |||
| /// <summary>Field number for the "all_model_checkpoint_paths" field.</summary> | |||
| public const int AllModelCheckpointPathsFieldNumber = 2; | |||
| private static readonly pb::FieldCodec<string> _repeated_allModelCheckpointPaths_codec | |||
| = pb::FieldCodec.ForString(18); | |||
| private readonly pbc::RepeatedField<string> allModelCheckpointPaths_ = new pbc::RepeatedField<string>(); | |||
| /// <summary> | |||
| /// Paths to all not-yet-deleted model checkpoints, sorted from oldest to | |||
| /// newest. | |||
| /// Note that the value of model_checkpoint_path should be the last item in | |||
| /// this list. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public pbc::RepeatedField<string> AllModelCheckpointPaths { | |||
| get { return allModelCheckpointPaths_; } | |||
| } | |||
| /// <summary>Field number for the "all_model_checkpoint_timestamps" field.</summary> | |||
| public const int AllModelCheckpointTimestampsFieldNumber = 3; | |||
| private static readonly pb::FieldCodec<double> _repeated_allModelCheckpointTimestamps_codec | |||
| = pb::FieldCodec.ForDouble(26); | |||
| private readonly pbc::RepeatedField<double> allModelCheckpointTimestamps_ = new pbc::RepeatedField<double>(); | |||
| /// <summary> | |||
| /// Unix timestamps corresponding to all_model_checkpoint_paths, indicating | |||
| /// when each checkpoint was created. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public pbc::RepeatedField<double> AllModelCheckpointTimestamps { | |||
| get { return allModelCheckpointTimestamps_; } | |||
| } | |||
| /// <summary>Field number for the "last_preserved_timestamp" field.</summary> | |||
| public const int LastPreservedTimestampFieldNumber = 4; | |||
| private double lastPreservedTimestamp_; | |||
| /// <summary> | |||
| /// Unix timestamp indicating the creation time for the last preserved | |||
| /// checkpoint. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public double LastPreservedTimestamp { | |||
| get { return lastPreservedTimestamp_; } | |||
| set { | |||
| lastPreservedTimestamp_ = value; | |||
| } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as CheckpointState); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public bool Equals(CheckpointState other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| } | |||
| if (ReferenceEquals(other, this)) { | |||
| return true; | |||
| } | |||
| if (ModelCheckpointPath != other.ModelCheckpointPath) return false; | |||
| if(!allModelCheckpointPaths_.Equals(other.allModelCheckpointPaths_)) return false; | |||
| if(!allModelCheckpointTimestamps_.Equals(other.allModelCheckpointTimestamps_)) return false; | |||
| if (!pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.Equals(LastPreservedTimestamp, other.LastPreservedTimestamp)) return false; | |||
| return Equals(_unknownFields, other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| if (ModelCheckpointPath.Length != 0) hash ^= ModelCheckpointPath.GetHashCode(); | |||
| hash ^= allModelCheckpointPaths_.GetHashCode(); | |||
| hash ^= allModelCheckpointTimestamps_.GetHashCode(); | |||
| if (LastPreservedTimestamp != 0D) hash ^= pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.GetHashCode(LastPreservedTimestamp); | |||
| if (_unknownFields != null) { | |||
| hash ^= _unknownFields.GetHashCode(); | |||
| } | |||
| return hash; | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| if (ModelCheckpointPath.Length != 0) { | |||
| output.WriteRawTag(10); | |||
| output.WriteString(ModelCheckpointPath); | |||
| } | |||
| allModelCheckpointPaths_.WriteTo(output, _repeated_allModelCheckpointPaths_codec); | |||
| allModelCheckpointTimestamps_.WriteTo(output, _repeated_allModelCheckpointTimestamps_codec); | |||
| if (LastPreservedTimestamp != 0D) { | |||
| output.WriteRawTag(33); | |||
| output.WriteDouble(LastPreservedTimestamp); | |||
| } | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| if (ModelCheckpointPath.Length != 0) { | |||
| size += 1 + pb::CodedOutputStream.ComputeStringSize(ModelCheckpointPath); | |||
| } | |||
| size += allModelCheckpointPaths_.CalculateSize(_repeated_allModelCheckpointPaths_codec); | |||
| size += allModelCheckpointTimestamps_.CalculateSize(_repeated_allModelCheckpointTimestamps_codec); | |||
| if (LastPreservedTimestamp != 0D) { | |||
| size += 1 + 8; | |||
| } | |||
| if (_unknownFields != null) { | |||
| size += _unknownFields.CalculateSize(); | |||
| } | |||
| return size; | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public void MergeFrom(CheckpointState other) { | |||
| if (other == null) { | |||
| return; | |||
| } | |||
| if (other.ModelCheckpointPath.Length != 0) { | |||
| ModelCheckpointPath = other.ModelCheckpointPath; | |||
| } | |||
| allModelCheckpointPaths_.Add(other.allModelCheckpointPaths_); | |||
| allModelCheckpointTimestamps_.Add(other.allModelCheckpointTimestamps_); | |||
| if (other.LastPreservedTimestamp != 0D) { | |||
| LastPreservedTimestamp = other.LastPreservedTimestamp; | |||
| } | |||
| _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||
| break; | |||
| case 10: { | |||
| ModelCheckpointPath = input.ReadString(); | |||
| break; | |||
| } | |||
| case 18: { | |||
| allModelCheckpointPaths_.AddEntriesFrom(input, _repeated_allModelCheckpointPaths_codec); | |||
| break; | |||
| } | |||
| case 26: | |||
| case 25: { | |||
| allModelCheckpointTimestamps_.AddEntriesFrom(input, _repeated_allModelCheckpointTimestamps_codec); | |||
| break; | |||
| } | |||
| case 33: { | |||
| LastPreservedTimestamp = input.ReadDouble(); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endregion | |||
| } | |||
| #endregion Designer generated code | |||
| @@ -1,20 +1,18 @@ | |||
| ### Download compiler from https://github.com/protocolbuffers/protobuf/releases | |||
| ```shell | |||
| set SRC_DIR=D:\Projects\tensorflow\tensorflow\core\framework | |||
| set SRC_DIR=D:\Projects\tensorflow | |||
| set DST_DIR=D:\Projects\TensorFlow.NET\src\TensorFlowNET.Core\Protobuf | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% resource_handle.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor_shape.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% types.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% attr_value.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% node_def.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% versions.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% function.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% graph.proto | |||
| ``` | |||
| ```shell | |||
| set SRC_DIR=D:\Projects\tensorflow\tensorflow\core\protobuf | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% saver.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\resource_handle.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\tensor_shape.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\types.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\tensor.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\attr_value.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\node_def.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\versions.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\function.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\graph.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\protobuf\saver.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\protobuf\meta_graph.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\python\training\checkpoint_state.proto | |||
| ``` | |||
| @@ -81,6 +81,11 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| public static float time() | |||
| { | |||
| return (float)(DateTime.UtcNow - new DateTime(1970, 1, 1)).TotalSeconds; | |||
| } | |||
| public static IEnumerable<(T, T)> zip<T>(NDArray t1, NDArray t2) | |||
| { | |||
| int index = 0; | |||
| @@ -1,6 +1,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| @@ -28,6 +29,8 @@ namespace Tensorflow | |||
| private float _next_checkpoint_time; | |||
| private bool _save_relative_paths; | |||
| private bool? _object_restore_saver; | |||
| private Dictionary<string, float> _last_checkpoints; | |||
| private Dictionary<string, float> _checkpoints_to_be_deleted; | |||
| public Saver(RefVariable[] var_list = null, | |||
| bool reshape = false, | |||
| @@ -68,6 +71,9 @@ namespace Tensorflow | |||
| _save_relative_paths = save_relative_paths; | |||
| _object_restore_saver = null; | |||
| _last_checkpoints = new Dictionary<string, float>(); | |||
| _checkpoints_to_be_deleted = new Dictionary<string, float>(); | |||
| } | |||
| public void build() | |||
| @@ -121,7 +127,7 @@ namespace Tensorflow | |||
| _check_saver_def(); | |||
| _next_checkpoint_time = (float)(DateTime.UtcNow - new DateTime(1970, 1, 1)).TotalSeconds + _saver_def.KeepCheckpointEveryNHours * 3600; | |||
| _next_checkpoint_time = Python.time() + _saver_def.KeepCheckpointEveryNHours * 3600; | |||
| } | |||
| private void _check_saver_def() | |||
| @@ -165,11 +171,119 @@ namespace Tensorflow | |||
| model_checkpoint_path = sess.run(_saver_def.SaveTensorName, new FeedItem[] { | |||
| new FeedItem(_saver_def.FilenameTensorName, checkpoint_file) | |||
| }); | |||
| if (write_state) | |||
| { | |||
| _RecordLastCheckpoint(model_checkpoint_path); | |||
| checkpoint_management.update_checkpoint_state_internal( | |||
| save_dir: save_path_parent, | |||
| model_checkpoint_path: model_checkpoint_path, | |||
| all_model_checkpoint_paths: _last_checkpoints.Keys.Select(x => x).ToList(), | |||
| latest_filename: latest_filename, | |||
| save_relative_paths: _save_relative_paths); | |||
| _MaybeDeleteOldCheckpoints(meta_graph_suffix: meta_graph_suffix); | |||
| } | |||
| } | |||
| if (write_meta_graph) | |||
| { | |||
| string meta_graph_filename = checkpoint_management.meta_graph_filename(checkpoint_file, meta_graph_suffix: meta_graph_suffix); | |||
| } | |||
| return _is_empty ? string.Empty : model_checkpoint_path; | |||
| } | |||
| /// <summary> | |||
| /// Writes `MetaGraphDef` to save_path/filename. | |||
| /// </summary> | |||
| /// <param name="filename"></param> | |||
| /// <param name="collection_list"></param> | |||
| /// <param name="as_text"></param> | |||
| /// <param name="export_scope"></param> | |||
| /// <param name="clear_devices"></param> | |||
| /// <param name="clear_extraneous_savers"></param> | |||
| /// <param name="strip_default_attrs"></param> | |||
| public MetaGraphDef export_meta_graph(string filename= "", | |||
| string[] collection_list = null, | |||
| string export_scope = "", | |||
| bool as_text= false, | |||
| bool clear_devices= false, | |||
| bool clear_extraneous_savers= false, | |||
| bool strip_default_attrs= false) | |||
| { | |||
| return export_meta_graph( | |||
| filename: filename, | |||
| graph_def: ops.get_default_graph()._as_graph_def(add_shapes: true), | |||
| saver_def: _saver_def, | |||
| collection_list: collection_list, | |||
| as_text: as_text, | |||
| export_scope: export_scope, | |||
| clear_devices: clear_devices, | |||
| clear_extraneous_savers: clear_extraneous_savers, | |||
| strip_default_attrs: strip_default_attrs); | |||
| } | |||
| public MetaGraphDef export_meta_graph(string filename = "", | |||
| byte[] meta_info_def = null, | |||
| GraphDef graph_def = null, | |||
| SaverDef saver_def = null, | |||
| string[] collection_list = null, | |||
| bool as_text = false, | |||
| bool clear_devices= false, | |||
| bool clear_extraneous_savers= false, | |||
| bool strip_default_attrs= false, | |||
| string export_scope = "") | |||
| { | |||
| var meta_graph_def = meta_graph.export_scoped_meta_graph( | |||
| filename: filename, | |||
| meta_info_def: meta_info_def, | |||
| graph_def: graph_def, | |||
| saver_def: saver_def, | |||
| // collection_list: collection_list, | |||
| as_text: as_text, | |||
| clear_devices: clear_devices, | |||
| clear_extraneous_savers: clear_extraneous_savers, | |||
| strip_default_attrs: strip_default_attrs); | |||
| return meta_graph_def; | |||
| } | |||
| /// <summary> | |||
| /// Manages the list of the latest checkpoints. | |||
| /// </summary> | |||
| /// <param name="latest_save_path"></param> | |||
| private void _RecordLastCheckpoint(string latest_save_path) | |||
| { | |||
| if (_saver_def.MaxToKeep <= 0) return; | |||
| // Remove first from list if the same name was used before. | |||
| foreach (var p in _last_checkpoints) | |||
| if (latest_save_path == _CheckpointFilename((p.Key, p.Value))) | |||
| _last_checkpoints.Remove(p.Key); | |||
| // Append new path to list | |||
| _last_checkpoints.Add(latest_save_path, Python.time()); | |||
| // If more than max_to_keep, remove oldest. | |||
| if(_last_checkpoints.Count > _saver_def.MaxToKeep) | |||
| { | |||
| var first = _last_checkpoints.First(); | |||
| _last_checkpoints.Remove(first.Key); | |||
| _checkpoints_to_be_deleted[first.Key] = first.Value; | |||
| } | |||
| } | |||
| throw new NotImplementedException("Saver.save"); | |||
| private string _CheckpointFilename((string, float) p) | |||
| { | |||
| return p.Item1; | |||
| } | |||
| /// <summary> | |||
| /// Deletes old checkpoints if necessary. | |||
| /// </summary> | |||
| /// <param name="meta_graph_suffix"></param> | |||
| private void _MaybeDeleteOldCheckpoints(string meta_graph_suffix = "meta") | |||
| { | |||
| return model_checkpoint_path; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,109 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class checkpoint_management | |||
| { | |||
| /// <summary> | |||
| /// Updates the content of the 'checkpoint' file. | |||
| /// </summary> | |||
| /// <param name="save_dir">Directory where the model was saved.</param> | |||
| /// <param name="model_checkpoint_path">The checkpoint file.</param> | |||
| /// <param name="all_model_checkpoint_paths">List of strings.</param> | |||
| /// <param name="latest_filename"></param> | |||
| /// <param name="save_relative_paths"></param> | |||
| /// <param name="all_model_checkpoint_timestamps"></param> | |||
| /// <param name="last_preserved_timestamp"></param> | |||
| public static void update_checkpoint_state_internal(string save_dir, | |||
| string model_checkpoint_path, | |||
| List<string> all_model_checkpoint_paths = null, | |||
| string latest_filename = "", | |||
| bool save_relative_paths = false, | |||
| List<float> all_model_checkpoint_timestamps = null, | |||
| float? last_preserved_timestamp = null | |||
| ) | |||
| { | |||
| CheckpointState ckpt = null; | |||
| // Writes the "checkpoint" file for the coordinator for later restoration. | |||
| string coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename); | |||
| if (save_relative_paths) | |||
| { | |||
| throw new NotImplementedException("update_checkpoint_state_internal save_relative_paths"); | |||
| } | |||
| else | |||
| { | |||
| ckpt = generate_checkpoint_state_proto(save_dir, | |||
| model_checkpoint_path, | |||
| all_model_checkpoint_paths, | |||
| all_model_checkpoint_timestamps, | |||
| last_preserved_timestamp); | |||
| } | |||
| if (coord_checkpoint_filename == ckpt.ModelCheckpointPath) | |||
| throw new RuntimeError($"Save path '{model_checkpoint_path}' conflicts with path used for " + | |||
| "checkpoint state. Please use a different save path."); | |||
| File.WriteAllText(coord_checkpoint_filename, ckpt.ToString()); | |||
| } | |||
| /// <summary> | |||
| /// Returns a filename for storing the CheckpointState. | |||
| /// </summary> | |||
| /// <param name="save_dir">The directory for saving and restoring checkpoints.</param> | |||
| /// <param name="latest_filename"> | |||
| /// Name of the file in 'save_dir' that is used | |||
| /// to store the CheckpointState. | |||
| /// </param> | |||
| /// <returns>he path of the file that contains the CheckpointState proto.</returns> | |||
| private static string _GetCheckpointFilename(string save_dir, string latest_filename) | |||
| { | |||
| if (string.IsNullOrEmpty(latest_filename)) | |||
| latest_filename = "checkpoint"; | |||
| return Path.Combine(save_dir, latest_filename); | |||
| } | |||
| private static CheckpointState generate_checkpoint_state_proto(string save_dir, | |||
| string model_checkpoint_path, | |||
| List<string> all_model_checkpoint_paths = null, | |||
| List<float> all_model_checkpoint_timestamps = null, | |||
| float? last_preserved_timestamp = null) | |||
| { | |||
| if (all_model_checkpoint_paths == null) | |||
| all_model_checkpoint_paths = new List<string>(); | |||
| // Relative paths need to be rewritten to be relative to the "save_dir" | |||
| // if model_checkpoint_path already contains "save_dir". | |||
| all_model_checkpoint_paths.Add(model_checkpoint_path); | |||
| var coord_checkpoint_proto = new CheckpointState() | |||
| { | |||
| ModelCheckpointPath = model_checkpoint_path, | |||
| LastPreservedTimestamp = last_preserved_timestamp.Value | |||
| }; | |||
| coord_checkpoint_proto.AllModelCheckpointPaths.AddRange(all_model_checkpoint_paths); | |||
| coord_checkpoint_proto.AllModelCheckpointTimestamps.AddRange(all_model_checkpoint_timestamps.Select(x => (double)x)); | |||
| return coord_checkpoint_proto; | |||
| } | |||
| /// <summary> | |||
| /// Returns the meta graph filename. | |||
| /// </summary> | |||
| /// <param name="checkpoint_filename"></param> | |||
| /// <param name="meta_graph_suffix"></param> | |||
| /// <returns></returns> | |||
| public static string meta_graph_filename(string checkpoint_filename, string meta_graph_suffix= "meta") | |||
| { | |||
| string basename = checkpoint_filename; | |||
| string suffixed_filename = basename + "." + meta_graph_suffix; | |||
| return suffixed_filename; | |||
| } | |||
| } | |||
| } | |||
| @@ -20,7 +20,8 @@ namespace TensorFlowNET.UnitTest | |||
| { | |||
| var handle = c_api.TF_GetAllOpList(); | |||
| var buffer = new Buffer(handle); | |||
| Assert.IsTrue(buffer.Length == buffer.Length); | |||
| var op_list = OpList.Parser.ParseFrom(buffer); | |||
| Assert.IsTrue(op_list.Op.Count > 1000); | |||
| } | |||
| [TestMethod] | |||