add RuntimeError exception class. add meta_graph related classes.tags/v0.8.0
| @@ -0,0 +1,19 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class RuntimeError : Exception | |||||
| { | |||||
| public RuntimeError() : base() | |||||
| { | |||||
| } | |||||
| public RuntimeError(string message) : base(message) | |||||
| { | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,135 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using static Tensorflow.MetaGraphDef.Types; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class meta_graph | |||||
| { | |||||
| /// <summary> | |||||
| /// Returns `MetaGraphDef` proto. Optionally writes it to filename. | |||||
| /// </summary> | |||||
| /// <param name="filename"></param> | |||||
| /// <param name="graph_def"></param> | |||||
| /// <param name="as_text"></param> | |||||
| /// <param name="unbound_inputs_col_name"></param> | |||||
| /// <param name="clear_devices"></param> | |||||
| /// <param name="saver_def"></param> | |||||
| /// <param name="clear_extraneous_savers"></param> | |||||
| /// <param name="strip_default_attrs"></param> | |||||
| /// <param name="meta_info_def"></param> | |||||
| /// <returns></returns> | |||||
| public static MetaGraphDef export_scoped_meta_graph(string filename = "", | |||||
| GraphDef graph_def = null, | |||||
| bool as_text = false, | |||||
| string unbound_inputs_col_name = "unbound_inputs", | |||||
| bool clear_devices = false, | |||||
| SaverDef saver_def = null, | |||||
| bool clear_extraneous_savers= false, | |||||
| bool strip_default_attrs= false, | |||||
| byte[] meta_info_def = null) | |||||
| { | |||||
| var graph = ops.get_default_graph(); | |||||
| var var_list = new Dictionary<string, RefVariable>(); | |||||
| var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES); | |||||
| foreach(var v in variables as RefVariable[]) | |||||
| { | |||||
| var_list[v.name] = v; | |||||
| } | |||||
| var scoped_meta_graph_def = create_meta_graph_def( | |||||
| graph_def: graph_def, | |||||
| export_scope: "", | |||||
| exclude_nodes: "", | |||||
| clear_extraneous_savers: clear_extraneous_savers, | |||||
| saver_def: saver_def, | |||||
| strip_default_attrs: strip_default_attrs); | |||||
| throw new NotImplementedException("meta_graph.export_scoped_meta_graph"); | |||||
| } | |||||
| private static bool _should_include_node() | |||||
| { | |||||
| return true; | |||||
| } | |||||
| private static byte[] create_meta_graph_def(MetaInfoDef meta_info_def = null, | |||||
| GraphDef graph_def = null, | |||||
| string export_scope = "", | |||||
| string exclude_nodes = "", | |||||
| SaverDef saver_def = null, | |||||
| bool clear_extraneous_savers = false, | |||||
| bool strip_default_attrs = false) | |||||
| { | |||||
| // Sets graph to default graph if it's not passed in. | |||||
| var graph = ops.get_default_graph(); | |||||
| // Creates a MetaGraphDef proto. | |||||
| var meta_graph_def = new MetaGraphDef(); | |||||
| if (meta_info_def == null) | |||||
| meta_info_def = new MetaInfoDef(); | |||||
| // Set the tf version strings to the current tf build. | |||||
| meta_info_def.TensorflowVersion = tf.VERSION; | |||||
| meta_info_def.TensorflowGitVersion = "unknown"; | |||||
| meta_graph_def.MetaInfoDef = meta_info_def; | |||||
| // Adds graph_def or the default. | |||||
| if (graph_def == null) | |||||
| meta_graph_def.GraphDef = graph._as_graph_def(add_shapes: true); | |||||
| else | |||||
| meta_graph_def.GraphDef = graph_def; | |||||
| // Fills in meta_info_def.stripped_op_list using the ops from graph_def. | |||||
| if (meta_graph_def.MetaInfoDef.StrippedOpList.Op.Count == 0) | |||||
| meta_graph_def.MetaInfoDef.StrippedOpList = stripped_op_list_for_graph(meta_graph_def.GraphDef); | |||||
| throw new NotImplementedException("create_meta_graph_def"); | |||||
| } | |||||
| private static OpList stripped_op_list_for_graph(GraphDef graph_def) | |||||
| { | |||||
| var used_ops = ops_used_by_graph_def(graph_def); | |||||
| // Verify that all used ops are registered. | |||||
| // var registered_ops = op_def_registry.get_registered_ops(); | |||||
| var op_list = new OpList(); | |||||
| /*used_ops.OrderBy(x => x).Select(x => { | |||||
| }).ToArray();*/ | |||||
| return op_list; | |||||
| } | |||||
| /// <summary> | |||||
| /// Collect the list of ops used by a graph. | |||||
| /// </summary> | |||||
| /// <param name="graph_def"></param> | |||||
| /// <returns></returns> | |||||
| private static string[] ops_used_by_graph_def(GraphDef graph_def) | |||||
| { | |||||
| var used_ops = new List<string>(); | |||||
| Action<string> mark_op_as_used = (op) => | |||||
| { | |||||
| if (!used_ops.Contains(op)) | |||||
| { | |||||
| } | |||||
| used_ops.Add(op); | |||||
| }; | |||||
| foreach (var node in graph_def.Node) | |||||
| { | |||||
| mark_op_as_used(node.Op); | |||||
| } | |||||
| return used_ops.ToArray(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,27 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class op_def_registry | |||||
| { | |||||
| private static Dictionary<string, OpDef> _registered_ops; | |||||
| public static Dictionary<string, OpDef> get_registered_ops() | |||||
| { | |||||
| if(_registered_ops == null) | |||||
| { | |||||
| _registered_ops = new Dictionary<string, OpDef>(); | |||||
| var handle = c_api.TF_GetAllOpList(); | |||||
| var buffer = new Buffer(handle); | |||||
| var op_list = OpList.Parser.ParseFrom(buffer); | |||||
| foreach (var op_def in op_list.Op) | |||||
| _registered_ops[op_def.Name] = op_def; | |||||
| } | |||||
| return _registered_ops; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -18,7 +18,7 @@ namespace Tensorflow | |||||
| return buffer; | return buffer; | ||||
| } | } | ||||
| public GraphDef _as_graph_def() | |||||
| public GraphDef _as_graph_def(bool add_shapes = false) | |||||
| { | { | ||||
| var buffer = ToGraphDef(Status); | var buffer = ToGraphDef(Status); | ||||
| Status.Check(); | Status.Check(); | ||||
| @@ -0,0 +1,264 @@ | |||||
| // <auto-generated> | |||||
| // Generated by the protocol buffer compiler. DO NOT EDIT! | |||||
| // source: checkpoint_state.proto | |||||
| // </auto-generated> | |||||
| #pragma warning disable 1591, 0612, 3021 | |||||
| #region Designer generated code | |||||
| using pb = global::Google.Protobuf; | |||||
| using pbc = global::Google.Protobuf.Collections; | |||||
| using pbr = global::Google.Protobuf.Reflection; | |||||
| using scg = global::System.Collections.Generic; | |||||
| namespace Tensorflow { | |||||
| /// <summary>Holder for reflection information generated from checkpoint_state.proto</summary> | |||||
| public static partial class CheckpointStateReflection { | |||||
| #region Descriptor | |||||
| /// <summary>File descriptor for checkpoint_state.proto</summary> | |||||
| public static pbr::FileDescriptor Descriptor { | |||||
| get { return descriptor; } | |||||
| } | |||||
| private static pbr::FileDescriptor descriptor; | |||||
| static CheckpointStateReflection() { | |||||
| byte[] descriptorData = global::System.Convert.FromBase64String( | |||||
| string.Concat( | |||||
| "ChZjaGVja3BvaW50X3N0YXRlLnByb3RvEgp0ZW5zb3JmbG93Ip8BCg9DaGVj", | |||||
| "a3BvaW50U3RhdGUSHQoVbW9kZWxfY2hlY2twb2ludF9wYXRoGAEgASgJEiIK", | |||||
| "GmFsbF9tb2RlbF9jaGVja3BvaW50X3BhdGhzGAIgAygJEicKH2FsbF9tb2Rl", | |||||
| "bF9jaGVja3BvaW50X3RpbWVzdGFtcHMYAyADKAESIAoYbGFzdF9wcmVzZXJ2", | |||||
| "ZWRfdGltZXN0YW1wGAQgASgBQgP4AQFiBnByb3RvMw==")); | |||||
| descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, | |||||
| new pbr::FileDescriptor[] { }, | |||||
| new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { | |||||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CheckpointState), global::Tensorflow.CheckpointState.Parser, new[]{ "ModelCheckpointPath", "AllModelCheckpointPaths", "AllModelCheckpointTimestamps", "LastPreservedTimestamp" }, null, null, null) | |||||
| })); | |||||
| } | |||||
| #endregion | |||||
| } | |||||
| #region Messages | |||||
| /// <summary> | |||||
| /// Protocol buffer representing the checkpoint state. | |||||
| /// </summary> | |||||
| public sealed partial class CheckpointState : pb::IMessage<CheckpointState> { | |||||
| private static readonly pb::MessageParser<CheckpointState> _parser = new pb::MessageParser<CheckpointState>(() => new CheckpointState()); | |||||
| private pb::UnknownFieldSet _unknownFields; | |||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public static pb::MessageParser<CheckpointState> Parser { get { return _parser; } } | |||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public static pbr::MessageDescriptor Descriptor { | |||||
| get { return global::Tensorflow.CheckpointStateReflection.Descriptor.MessageTypes[0]; } | |||||
| } | |||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||||
| get { return Descriptor; } | |||||
| } | |||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public CheckpointState() { | |||||
| OnConstruction(); | |||||
| } | |||||
| partial void OnConstruction(); | |||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public CheckpointState(CheckpointState other) : this() { | |||||
| modelCheckpointPath_ = other.modelCheckpointPath_; | |||||
| allModelCheckpointPaths_ = other.allModelCheckpointPaths_.Clone(); | |||||
| allModelCheckpointTimestamps_ = other.allModelCheckpointTimestamps_.Clone(); | |||||
| lastPreservedTimestamp_ = other.lastPreservedTimestamp_; | |||||
| _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||||
| } | |||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public CheckpointState Clone() { | |||||
| return new CheckpointState(this); | |||||
| } | |||||
| /// <summary>Field number for the "model_checkpoint_path" field.</summary> | |||||
| public const int ModelCheckpointPathFieldNumber = 1; | |||||
| private string modelCheckpointPath_ = ""; | |||||
| /// <summary> | |||||
| /// Path to the most-recent model checkpoint. | |||||
| /// </summary> | |||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public string ModelCheckpointPath { | |||||
| get { return modelCheckpointPath_; } | |||||
| set { | |||||
| modelCheckpointPath_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); | |||||
| } | |||||
| } | |||||
| /// <summary>Field number for the "all_model_checkpoint_paths" field.</summary> | |||||
| public const int AllModelCheckpointPathsFieldNumber = 2; | |||||
| private static readonly pb::FieldCodec<string> _repeated_allModelCheckpointPaths_codec | |||||
| = pb::FieldCodec.ForString(18); | |||||
| private readonly pbc::RepeatedField<string> allModelCheckpointPaths_ = new pbc::RepeatedField<string>(); | |||||
| /// <summary> | |||||
| /// Paths to all not-yet-deleted model checkpoints, sorted from oldest to | |||||
| /// newest. | |||||
| /// Note that the value of model_checkpoint_path should be the last item in | |||||
| /// this list. | |||||
| /// </summary> | |||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public pbc::RepeatedField<string> AllModelCheckpointPaths { | |||||
| get { return allModelCheckpointPaths_; } | |||||
| } | |||||
| /// <summary>Field number for the "all_model_checkpoint_timestamps" field.</summary> | |||||
| public const int AllModelCheckpointTimestampsFieldNumber = 3; | |||||
| private static readonly pb::FieldCodec<double> _repeated_allModelCheckpointTimestamps_codec | |||||
| = pb::FieldCodec.ForDouble(26); | |||||
| private readonly pbc::RepeatedField<double> allModelCheckpointTimestamps_ = new pbc::RepeatedField<double>(); | |||||
| /// <summary> | |||||
| /// Unix timestamps corresponding to all_model_checkpoint_paths, indicating | |||||
| /// when each checkpoint was created. | |||||
| /// </summary> | |||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public pbc::RepeatedField<double> AllModelCheckpointTimestamps { | |||||
| get { return allModelCheckpointTimestamps_; } | |||||
| } | |||||
| /// <summary>Field number for the "last_preserved_timestamp" field.</summary> | |||||
| public const int LastPreservedTimestampFieldNumber = 4; | |||||
| private double lastPreservedTimestamp_; | |||||
| /// <summary> | |||||
| /// Unix timestamp indicating the creation time for the last preserved | |||||
| /// checkpoint. | |||||
| /// </summary> | |||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public double LastPreservedTimestamp { | |||||
| get { return lastPreservedTimestamp_; } | |||||
| set { | |||||
| lastPreservedTimestamp_ = value; | |||||
| } | |||||
| } | |||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public override bool Equals(object other) { | |||||
| return Equals(other as CheckpointState); | |||||
| } | |||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public bool Equals(CheckpointState other) { | |||||
| if (ReferenceEquals(other, null)) { | |||||
| return false; | |||||
| } | |||||
| if (ReferenceEquals(other, this)) { | |||||
| return true; | |||||
| } | |||||
| if (ModelCheckpointPath != other.ModelCheckpointPath) return false; | |||||
| if(!allModelCheckpointPaths_.Equals(other.allModelCheckpointPaths_)) return false; | |||||
| if(!allModelCheckpointTimestamps_.Equals(other.allModelCheckpointTimestamps_)) return false; | |||||
| if (!pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.Equals(LastPreservedTimestamp, other.LastPreservedTimestamp)) return false; | |||||
| return Equals(_unknownFields, other._unknownFields); | |||||
| } | |||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public override int GetHashCode() { | |||||
| int hash = 1; | |||||
| if (ModelCheckpointPath.Length != 0) hash ^= ModelCheckpointPath.GetHashCode(); | |||||
| hash ^= allModelCheckpointPaths_.GetHashCode(); | |||||
| hash ^= allModelCheckpointTimestamps_.GetHashCode(); | |||||
| if (LastPreservedTimestamp != 0D) hash ^= pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.GetHashCode(LastPreservedTimestamp); | |||||
| if (_unknownFields != null) { | |||||
| hash ^= _unknownFields.GetHashCode(); | |||||
| } | |||||
| return hash; | |||||
| } | |||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public override string ToString() { | |||||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||||
| } | |||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public void WriteTo(pb::CodedOutputStream output) { | |||||
| if (ModelCheckpointPath.Length != 0) { | |||||
| output.WriteRawTag(10); | |||||
| output.WriteString(ModelCheckpointPath); | |||||
| } | |||||
| allModelCheckpointPaths_.WriteTo(output, _repeated_allModelCheckpointPaths_codec); | |||||
| allModelCheckpointTimestamps_.WriteTo(output, _repeated_allModelCheckpointTimestamps_codec); | |||||
| if (LastPreservedTimestamp != 0D) { | |||||
| output.WriteRawTag(33); | |||||
| output.WriteDouble(LastPreservedTimestamp); | |||||
| } | |||||
| if (_unknownFields != null) { | |||||
| _unknownFields.WriteTo(output); | |||||
| } | |||||
| } | |||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public int CalculateSize() { | |||||
| int size = 0; | |||||
| if (ModelCheckpointPath.Length != 0) { | |||||
| size += 1 + pb::CodedOutputStream.ComputeStringSize(ModelCheckpointPath); | |||||
| } | |||||
| size += allModelCheckpointPaths_.CalculateSize(_repeated_allModelCheckpointPaths_codec); | |||||
| size += allModelCheckpointTimestamps_.CalculateSize(_repeated_allModelCheckpointTimestamps_codec); | |||||
| if (LastPreservedTimestamp != 0D) { | |||||
| size += 1 + 8; | |||||
| } | |||||
| if (_unknownFields != null) { | |||||
| size += _unknownFields.CalculateSize(); | |||||
| } | |||||
| return size; | |||||
| } | |||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public void MergeFrom(CheckpointState other) { | |||||
| if (other == null) { | |||||
| return; | |||||
| } | |||||
| if (other.ModelCheckpointPath.Length != 0) { | |||||
| ModelCheckpointPath = other.ModelCheckpointPath; | |||||
| } | |||||
| allModelCheckpointPaths_.Add(other.allModelCheckpointPaths_); | |||||
| allModelCheckpointTimestamps_.Add(other.allModelCheckpointTimestamps_); | |||||
| if (other.LastPreservedTimestamp != 0D) { | |||||
| LastPreservedTimestamp = other.LastPreservedTimestamp; | |||||
| } | |||||
| _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||||
| } | |||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public void MergeFrom(pb::CodedInputStream input) { | |||||
| uint tag; | |||||
| while ((tag = input.ReadTag()) != 0) { | |||||
| switch(tag) { | |||||
| default: | |||||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||||
| break; | |||||
| case 10: { | |||||
| ModelCheckpointPath = input.ReadString(); | |||||
| break; | |||||
| } | |||||
| case 18: { | |||||
| allModelCheckpointPaths_.AddEntriesFrom(input, _repeated_allModelCheckpointPaths_codec); | |||||
| break; | |||||
| } | |||||
| case 26: | |||||
| case 25: { | |||||
| allModelCheckpointTimestamps_.AddEntriesFrom(input, _repeated_allModelCheckpointTimestamps_codec); | |||||
| break; | |||||
| } | |||||
| case 33: { | |||||
| LastPreservedTimestamp = input.ReadDouble(); | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| #endregion | |||||
| } | |||||
| #endregion Designer generated code | |||||
| @@ -1,20 +1,18 @@ | |||||
| ### Download compiler from https://github.com/protocolbuffers/protobuf/releases | ### Download compiler from https://github.com/protocolbuffers/protobuf/releases | ||||
| ```shell | ```shell | ||||
| set SRC_DIR=D:\Projects\tensorflow\tensorflow\core\framework | |||||
| set SRC_DIR=D:\Projects\tensorflow | |||||
| set DST_DIR=D:\Projects\TensorFlow.NET\src\TensorFlowNET.Core\Protobuf | set DST_DIR=D:\Projects\TensorFlow.NET\src\TensorFlowNET.Core\Protobuf | ||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% resource_handle.proto | |||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor_shape.proto | |||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% types.proto | |||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor.proto | |||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% attr_value.proto | |||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% node_def.proto | |||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% versions.proto | |||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% function.proto | |||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% graph.proto | |||||
| ``` | |||||
| ```shell | |||||
| set SRC_DIR=D:\Projects\tensorflow\tensorflow\core\protobuf | |||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% saver.proto | |||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\resource_handle.proto | |||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\tensor_shape.proto | |||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\types.proto | |||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\tensor.proto | |||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\attr_value.proto | |||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\node_def.proto | |||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\versions.proto | |||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\function.proto | |||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\graph.proto | |||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\protobuf\saver.proto | |||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\protobuf\meta_graph.proto | |||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\python\training\checkpoint_state.proto | |||||
| ``` | ``` | ||||
| @@ -81,6 +81,11 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| public static float time() | |||||
| { | |||||
| return (float)(DateTime.UtcNow - new DateTime(1970, 1, 1)).TotalSeconds; | |||||
| } | |||||
| public static IEnumerable<(T, T)> zip<T>(NDArray t1, NDArray t2) | public static IEnumerable<(T, T)> zip<T>(NDArray t1, NDArray t2) | ||||
| { | { | ||||
| int index = 0; | int index = 0; | ||||
| @@ -1,6 +1,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.IO; | using System.IO; | ||||
| using System.Linq; | |||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -28,6 +29,8 @@ namespace Tensorflow | |||||
| private float _next_checkpoint_time; | private float _next_checkpoint_time; | ||||
| private bool _save_relative_paths; | private bool _save_relative_paths; | ||||
| private bool? _object_restore_saver; | private bool? _object_restore_saver; | ||||
| private Dictionary<string, float> _last_checkpoints; | |||||
| private Dictionary<string, float> _checkpoints_to_be_deleted; | |||||
| public Saver(RefVariable[] var_list = null, | public Saver(RefVariable[] var_list = null, | ||||
| bool reshape = false, | bool reshape = false, | ||||
| @@ -68,6 +71,9 @@ namespace Tensorflow | |||||
| _save_relative_paths = save_relative_paths; | _save_relative_paths = save_relative_paths; | ||||
| _object_restore_saver = null; | _object_restore_saver = null; | ||||
| _last_checkpoints = new Dictionary<string, float>(); | |||||
| _checkpoints_to_be_deleted = new Dictionary<string, float>(); | |||||
| } | } | ||||
| public void build() | public void build() | ||||
| @@ -121,7 +127,7 @@ namespace Tensorflow | |||||
| _check_saver_def(); | _check_saver_def(); | ||||
| _next_checkpoint_time = (float)(DateTime.UtcNow - new DateTime(1970, 1, 1)).TotalSeconds + _saver_def.KeepCheckpointEveryNHours * 3600; | |||||
| _next_checkpoint_time = Python.time() + _saver_def.KeepCheckpointEveryNHours * 3600; | |||||
| } | } | ||||
| private void _check_saver_def() | private void _check_saver_def() | ||||
| @@ -165,11 +171,119 @@ namespace Tensorflow | |||||
| model_checkpoint_path = sess.run(_saver_def.SaveTensorName, new FeedItem[] { | model_checkpoint_path = sess.run(_saver_def.SaveTensorName, new FeedItem[] { | ||||
| new FeedItem(_saver_def.FilenameTensorName, checkpoint_file) | new FeedItem(_saver_def.FilenameTensorName, checkpoint_file) | ||||
| }); | }); | ||||
| if (write_state) | |||||
| { | |||||
| _RecordLastCheckpoint(model_checkpoint_path); | |||||
| checkpoint_management.update_checkpoint_state_internal( | |||||
| save_dir: save_path_parent, | |||||
| model_checkpoint_path: model_checkpoint_path, | |||||
| all_model_checkpoint_paths: _last_checkpoints.Keys.Select(x => x).ToList(), | |||||
| latest_filename: latest_filename, | |||||
| save_relative_paths: _save_relative_paths); | |||||
| _MaybeDeleteOldCheckpoints(meta_graph_suffix: meta_graph_suffix); | |||||
| } | |||||
| } | |||||
| if (write_meta_graph) | |||||
| { | |||||
| string meta_graph_filename = checkpoint_management.meta_graph_filename(checkpoint_file, meta_graph_suffix: meta_graph_suffix); | |||||
| } | |||||
| return _is_empty ? string.Empty : model_checkpoint_path; | |||||
| } | |||||
| /// <summary> | |||||
| /// Writes `MetaGraphDef` to save_path/filename. | |||||
| /// </summary> | |||||
| /// <param name="filename"></param> | |||||
| /// <param name="collection_list"></param> | |||||
| /// <param name="as_text"></param> | |||||
| /// <param name="export_scope"></param> | |||||
| /// <param name="clear_devices"></param> | |||||
| /// <param name="clear_extraneous_savers"></param> | |||||
| /// <param name="strip_default_attrs"></param> | |||||
| public MetaGraphDef export_meta_graph(string filename= "", | |||||
| string[] collection_list = null, | |||||
| string export_scope = "", | |||||
| bool as_text= false, | |||||
| bool clear_devices= false, | |||||
| bool clear_extraneous_savers= false, | |||||
| bool strip_default_attrs= false) | |||||
| { | |||||
| return export_meta_graph( | |||||
| filename: filename, | |||||
| graph_def: ops.get_default_graph()._as_graph_def(add_shapes: true), | |||||
| saver_def: _saver_def, | |||||
| collection_list: collection_list, | |||||
| as_text: as_text, | |||||
| export_scope: export_scope, | |||||
| clear_devices: clear_devices, | |||||
| clear_extraneous_savers: clear_extraneous_savers, | |||||
| strip_default_attrs: strip_default_attrs); | |||||
| } | |||||
| public MetaGraphDef export_meta_graph(string filename = "", | |||||
| byte[] meta_info_def = null, | |||||
| GraphDef graph_def = null, | |||||
| SaverDef saver_def = null, | |||||
| string[] collection_list = null, | |||||
| bool as_text = false, | |||||
| bool clear_devices= false, | |||||
| bool clear_extraneous_savers= false, | |||||
| bool strip_default_attrs= false, | |||||
| string export_scope = "") | |||||
| { | |||||
| var meta_graph_def = meta_graph.export_scoped_meta_graph( | |||||
| filename: filename, | |||||
| meta_info_def: meta_info_def, | |||||
| graph_def: graph_def, | |||||
| saver_def: saver_def, | |||||
| // collection_list: collection_list, | |||||
| as_text: as_text, | |||||
| clear_devices: clear_devices, | |||||
| clear_extraneous_savers: clear_extraneous_savers, | |||||
| strip_default_attrs: strip_default_attrs); | |||||
| return meta_graph_def; | |||||
| } | |||||
| /// <summary> | |||||
| /// Manages the list of the latest checkpoints. | |||||
| /// </summary> | |||||
| /// <param name="latest_save_path"></param> | |||||
| private void _RecordLastCheckpoint(string latest_save_path) | |||||
| { | |||||
| if (_saver_def.MaxToKeep <= 0) return; | |||||
| // Remove first from list if the same name was used before. | |||||
| foreach (var p in _last_checkpoints) | |||||
| if (latest_save_path == _CheckpointFilename((p.Key, p.Value))) | |||||
| _last_checkpoints.Remove(p.Key); | |||||
| // Append new path to list | |||||
| _last_checkpoints.Add(latest_save_path, Python.time()); | |||||
| // If more than max_to_keep, remove oldest. | |||||
| if(_last_checkpoints.Count > _saver_def.MaxToKeep) | |||||
| { | |||||
| var first = _last_checkpoints.First(); | |||||
| _last_checkpoints.Remove(first.Key); | |||||
| _checkpoints_to_be_deleted[first.Key] = first.Value; | |||||
| } | } | ||||
| } | |||||
| throw new NotImplementedException("Saver.save"); | |||||
| private string _CheckpointFilename((string, float) p) | |||||
| { | |||||
| return p.Item1; | |||||
| } | |||||
| /// <summary> | |||||
| /// Deletes old checkpoints if necessary. | |||||
| /// </summary> | |||||
| /// <param name="meta_graph_suffix"></param> | |||||
| private void _MaybeDeleteOldCheckpoints(string meta_graph_suffix = "meta") | |||||
| { | |||||
| return model_checkpoint_path; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,109 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.IO; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class checkpoint_management | |||||
| { | |||||
| /// <summary> | |||||
| /// Updates the content of the 'checkpoint' file. | |||||
| /// </summary> | |||||
| /// <param name="save_dir">Directory where the model was saved.</param> | |||||
| /// <param name="model_checkpoint_path">The checkpoint file.</param> | |||||
| /// <param name="all_model_checkpoint_paths">List of strings.</param> | |||||
| /// <param name="latest_filename"></param> | |||||
| /// <param name="save_relative_paths"></param> | |||||
| /// <param name="all_model_checkpoint_timestamps"></param> | |||||
| /// <param name="last_preserved_timestamp"></param> | |||||
| public static void update_checkpoint_state_internal(string save_dir, | |||||
| string model_checkpoint_path, | |||||
| List<string> all_model_checkpoint_paths = null, | |||||
| string latest_filename = "", | |||||
| bool save_relative_paths = false, | |||||
| List<float> all_model_checkpoint_timestamps = null, | |||||
| float? last_preserved_timestamp = null | |||||
| ) | |||||
| { | |||||
| CheckpointState ckpt = null; | |||||
| // Writes the "checkpoint" file for the coordinator for later restoration. | |||||
| string coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename); | |||||
| if (save_relative_paths) | |||||
| { | |||||
| throw new NotImplementedException("update_checkpoint_state_internal save_relative_paths"); | |||||
| } | |||||
| else | |||||
| { | |||||
| ckpt = generate_checkpoint_state_proto(save_dir, | |||||
| model_checkpoint_path, | |||||
| all_model_checkpoint_paths, | |||||
| all_model_checkpoint_timestamps, | |||||
| last_preserved_timestamp); | |||||
| } | |||||
| if (coord_checkpoint_filename == ckpt.ModelCheckpointPath) | |||||
| throw new RuntimeError($"Save path '{model_checkpoint_path}' conflicts with path used for " + | |||||
| "checkpoint state. Please use a different save path."); | |||||
| File.WriteAllText(coord_checkpoint_filename, ckpt.ToString()); | |||||
| } | |||||
| /// <summary> | |||||
| /// Returns a filename for storing the CheckpointState. | |||||
| /// </summary> | |||||
| /// <param name="save_dir">The directory for saving and restoring checkpoints.</param> | |||||
| /// <param name="latest_filename"> | |||||
| /// Name of the file in 'save_dir' that is used | |||||
| /// to store the CheckpointState. | |||||
| /// </param> | |||||
| /// <returns>he path of the file that contains the CheckpointState proto.</returns> | |||||
| private static string _GetCheckpointFilename(string save_dir, string latest_filename) | |||||
| { | |||||
| if (string.IsNullOrEmpty(latest_filename)) | |||||
| latest_filename = "checkpoint"; | |||||
| return Path.Combine(save_dir, latest_filename); | |||||
| } | |||||
| private static CheckpointState generate_checkpoint_state_proto(string save_dir, | |||||
| string model_checkpoint_path, | |||||
| List<string> all_model_checkpoint_paths = null, | |||||
| List<float> all_model_checkpoint_timestamps = null, | |||||
| float? last_preserved_timestamp = null) | |||||
| { | |||||
| if (all_model_checkpoint_paths == null) | |||||
| all_model_checkpoint_paths = new List<string>(); | |||||
| // Relative paths need to be rewritten to be relative to the "save_dir" | |||||
| // if model_checkpoint_path already contains "save_dir". | |||||
| all_model_checkpoint_paths.Add(model_checkpoint_path); | |||||
| var coord_checkpoint_proto = new CheckpointState() | |||||
| { | |||||
| ModelCheckpointPath = model_checkpoint_path, | |||||
| LastPreservedTimestamp = last_preserved_timestamp.Value | |||||
| }; | |||||
| coord_checkpoint_proto.AllModelCheckpointPaths.AddRange(all_model_checkpoint_paths); | |||||
| coord_checkpoint_proto.AllModelCheckpointTimestamps.AddRange(all_model_checkpoint_timestamps.Select(x => (double)x)); | |||||
| return coord_checkpoint_proto; | |||||
| } | |||||
| /// <summary> | |||||
| /// Returns the meta graph filename. | |||||
| /// </summary> | |||||
| /// <param name="checkpoint_filename"></param> | |||||
| /// <param name="meta_graph_suffix"></param> | |||||
| /// <returns></returns> | |||||
| public static string meta_graph_filename(string checkpoint_filename, string meta_graph_suffix= "meta") | |||||
| { | |||||
| string basename = checkpoint_filename; | |||||
| string suffixed_filename = basename + "." + meta_graph_suffix; | |||||
| return suffixed_filename; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -20,7 +20,8 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| var handle = c_api.TF_GetAllOpList(); | var handle = c_api.TF_GetAllOpList(); | ||||
| var buffer = new Buffer(handle); | var buffer = new Buffer(handle); | ||||
| Assert.IsTrue(buffer.Length == buffer.Length); | |||||
| var op_list = OpList.Parser.ParseFrom(buffer); | |||||
| Assert.IsTrue(op_list.Op.Count > 1000); | |||||
| } | } | ||||
| [TestMethod] | [TestMethod] | ||||