From c7cf8b60842e223fa0d23ac4835ad279a32ef5e2 Mon Sep 17 00:00:00 2001 From: haiping008 Date: Tue, 12 Feb 2019 11:53:48 -0600 Subject: [PATCH] add op_def_registry to load all registered op list. add RuntimeError exception class. add meta_graph related classes. --- .../Exceptions/RuntimeError.cs | 19 + .../Framework/meta_graph.py.cs | 135 + .../Framework/op_def_registry.py.cs | 27 + src/TensorFlowNET.Core/Graphs/Graph.Export.cs | 2 +- .../Protobuf/CheckpointState.cs | 264 ++ src/TensorFlowNET.Core/Protobuf/MetaGraph.cs | 2679 +++++++++++++++++ src/TensorFlowNET.Core/Protobuf/README.md | 28 +- src/TensorFlowNET.Core/Python.cs | 5 + src/TensorFlowNET.Core/Train/Saving/Saver.cs | 120 +- .../Train/Saving/checkpoint_management.py.cs | 109 + test/TensorFlowNET.UnitTest/OperationsTest.cs | 3 +- 11 files changed, 3371 insertions(+), 20 deletions(-) create mode 100644 src/TensorFlowNET.Core/Exceptions/RuntimeError.cs create mode 100644 src/TensorFlowNET.Core/Framework/meta_graph.py.cs create mode 100644 src/TensorFlowNET.Core/Framework/op_def_registry.py.cs create mode 100644 src/TensorFlowNET.Core/Protobuf/CheckpointState.cs create mode 100644 src/TensorFlowNET.Core/Protobuf/MetaGraph.cs create mode 100644 src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs diff --git a/src/TensorFlowNET.Core/Exceptions/RuntimeError.cs b/src/TensorFlowNET.Core/Exceptions/RuntimeError.cs new file mode 100644 index 00000000..71fd773c --- /dev/null +++ b/src/TensorFlowNET.Core/Exceptions/RuntimeError.cs @@ -0,0 +1,19 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class RuntimeError : Exception + { + public RuntimeError() : base() + { + + } + + public RuntimeError(string message) : base(message) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs b/src/TensorFlowNET.Core/Framework/meta_graph.py.cs new file mode 100644 index 00000000..8df307d1 --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/meta_graph.py.cs @@ -0,0 +1,135 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using static Tensorflow.MetaGraphDef.Types; + +namespace Tensorflow +{ + public class meta_graph + { + /// + /// Returns `MetaGraphDef` proto. Optionally writes it to filename. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static MetaGraphDef export_scoped_meta_graph(string filename = "", + GraphDef graph_def = null, + bool as_text = false, + string unbound_inputs_col_name = "unbound_inputs", + bool clear_devices = false, + SaverDef saver_def = null, + bool clear_extraneous_savers= false, + bool strip_default_attrs= false, + byte[] meta_info_def = null) + { + var graph = ops.get_default_graph(); + + var var_list = new Dictionary(); + var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES); + + foreach(var v in variables as RefVariable[]) + { + var_list[v.name] = v; + } + + var scoped_meta_graph_def = create_meta_graph_def( + graph_def: graph_def, + export_scope: "", + exclude_nodes: "", + clear_extraneous_savers: clear_extraneous_savers, + saver_def: saver_def, + strip_default_attrs: strip_default_attrs); + + throw new NotImplementedException("meta_graph.export_scoped_meta_graph"); + } + + private static bool _should_include_node() + { + return true; + } + + private static byte[] create_meta_graph_def(MetaInfoDef meta_info_def = null, + GraphDef graph_def = null, + string export_scope = "", + string exclude_nodes = "", + SaverDef saver_def = null, + bool clear_extraneous_savers = false, + bool strip_default_attrs = false) + { + // Sets graph to default graph if it's not passed in. + var graph = ops.get_default_graph(); + // Creates a MetaGraphDef proto. + var meta_graph_def = new MetaGraphDef(); + if (meta_info_def == null) + meta_info_def = new MetaInfoDef(); + + // Set the tf version strings to the current tf build. + meta_info_def.TensorflowVersion = tf.VERSION; + meta_info_def.TensorflowGitVersion = "unknown"; + meta_graph_def.MetaInfoDef = meta_info_def; + + // Adds graph_def or the default. + if (graph_def == null) + meta_graph_def.GraphDef = graph._as_graph_def(add_shapes: true); + else + meta_graph_def.GraphDef = graph_def; + + // Fills in meta_info_def.stripped_op_list using the ops from graph_def. + if (meta_graph_def.MetaInfoDef.StrippedOpList.Op.Count == 0) + meta_graph_def.MetaInfoDef.StrippedOpList = stripped_op_list_for_graph(meta_graph_def.GraphDef); + + throw new NotImplementedException("create_meta_graph_def"); + } + + private static OpList stripped_op_list_for_graph(GraphDef graph_def) + { + var used_ops = ops_used_by_graph_def(graph_def); + + // Verify that all used ops are registered. + // var registered_ops = op_def_registry.get_registered_ops(); + + var op_list = new OpList(); + /*used_ops.OrderBy(x => x).Select(x => { + + }).ToArray();*/ + + return op_list; + } + + /// + /// Collect the list of ops used by a graph. + /// + /// + /// + private static string[] ops_used_by_graph_def(GraphDef graph_def) + { + var used_ops = new List(); + + Action mark_op_as_used = (op) => + { + if (!used_ops.Contains(op)) + { + + } + + used_ops.Add(op); + }; + + foreach (var node in graph_def.Node) + { + mark_op_as_used(node.Op); + } + + return used_ops.ToArray(); + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs b/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs new file mode 100644 index 00000000..a10a099a --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs @@ -0,0 +1,27 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class op_def_registry + { + private static Dictionary _registered_ops; + + public static Dictionary get_registered_ops() + { + if(_registered_ops == null) + { + _registered_ops = new Dictionary(); + var handle = c_api.TF_GetAllOpList(); + var buffer = new Buffer(handle); + var op_list = OpList.Parser.ParseFrom(buffer); + + foreach (var op_def in op_list.Op) + _registered_ops[op_def.Name] = op_def; + } + + return _registered_ops; + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Export.cs b/src/TensorFlowNET.Core/Graphs/Graph.Export.cs index 5809d78f..6501de70 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Export.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Export.cs @@ -18,7 +18,7 @@ namespace Tensorflow return buffer; } - public GraphDef _as_graph_def() + public GraphDef _as_graph_def(bool add_shapes = false) { var buffer = ToGraphDef(Status); Status.Check(); diff --git a/src/TensorFlowNET.Core/Protobuf/CheckpointState.cs b/src/TensorFlowNET.Core/Protobuf/CheckpointState.cs new file mode 100644 index 00000000..51151726 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/CheckpointState.cs @@ -0,0 +1,264 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: checkpoint_state.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from checkpoint_state.proto + public static partial class CheckpointStateReflection { + + #region Descriptor + /// File descriptor for checkpoint_state.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static CheckpointStateReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "ChZjaGVja3BvaW50X3N0YXRlLnByb3RvEgp0ZW5zb3JmbG93Ip8BCg9DaGVj", + "a3BvaW50U3RhdGUSHQoVbW9kZWxfY2hlY2twb2ludF9wYXRoGAEgASgJEiIK", + "GmFsbF9tb2RlbF9jaGVja3BvaW50X3BhdGhzGAIgAygJEicKH2FsbF9tb2Rl", + "bF9jaGVja3BvaW50X3RpbWVzdGFtcHMYAyADKAESIAoYbGFzdF9wcmVzZXJ2", + "ZWRfdGltZXN0YW1wGAQgASgBQgP4AQFiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CheckpointState), global::Tensorflow.CheckpointState.Parser, new[]{ "ModelCheckpointPath", "AllModelCheckpointPaths", "AllModelCheckpointTimestamps", "LastPreservedTimestamp" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Protocol buffer representing the checkpoint state. + /// + public sealed partial class CheckpointState : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CheckpointState()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CheckpointStateReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CheckpointState() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CheckpointState(CheckpointState other) : this() { + modelCheckpointPath_ = other.modelCheckpointPath_; + allModelCheckpointPaths_ = other.allModelCheckpointPaths_.Clone(); + allModelCheckpointTimestamps_ = other.allModelCheckpointTimestamps_.Clone(); + lastPreservedTimestamp_ = other.lastPreservedTimestamp_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CheckpointState Clone() { + return new CheckpointState(this); + } + + /// Field number for the "model_checkpoint_path" field. + public const int ModelCheckpointPathFieldNumber = 1; + private string modelCheckpointPath_ = ""; + /// + /// Path to the most-recent model checkpoint. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string ModelCheckpointPath { + get { return modelCheckpointPath_; } + set { + modelCheckpointPath_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "all_model_checkpoint_paths" field. + public const int AllModelCheckpointPathsFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_allModelCheckpointPaths_codec + = pb::FieldCodec.ForString(18); + private readonly pbc::RepeatedField allModelCheckpointPaths_ = new pbc::RepeatedField(); + /// + /// Paths to all not-yet-deleted model checkpoints, sorted from oldest to + /// newest. + /// Note that the value of model_checkpoint_path should be the last item in + /// this list. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField AllModelCheckpointPaths { + get { return allModelCheckpointPaths_; } + } + + /// Field number for the "all_model_checkpoint_timestamps" field. + public const int AllModelCheckpointTimestampsFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_allModelCheckpointTimestamps_codec + = pb::FieldCodec.ForDouble(26); + private readonly pbc::RepeatedField allModelCheckpointTimestamps_ = new pbc::RepeatedField(); + /// + /// Unix timestamps corresponding to all_model_checkpoint_paths, indicating + /// when each checkpoint was created. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField AllModelCheckpointTimestamps { + get { return allModelCheckpointTimestamps_; } + } + + /// Field number for the "last_preserved_timestamp" field. + public const int LastPreservedTimestampFieldNumber = 4; + private double lastPreservedTimestamp_; + /// + /// Unix timestamp indicating the creation time for the last preserved + /// checkpoint. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public double LastPreservedTimestamp { + get { return lastPreservedTimestamp_; } + set { + lastPreservedTimestamp_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as CheckpointState); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(CheckpointState other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (ModelCheckpointPath != other.ModelCheckpointPath) return false; + if(!allModelCheckpointPaths_.Equals(other.allModelCheckpointPaths_)) return false; + if(!allModelCheckpointTimestamps_.Equals(other.allModelCheckpointTimestamps_)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.Equals(LastPreservedTimestamp, other.LastPreservedTimestamp)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (ModelCheckpointPath.Length != 0) hash ^= ModelCheckpointPath.GetHashCode(); + hash ^= allModelCheckpointPaths_.GetHashCode(); + hash ^= allModelCheckpointTimestamps_.GetHashCode(); + if (LastPreservedTimestamp != 0D) hash ^= pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.GetHashCode(LastPreservedTimestamp); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (ModelCheckpointPath.Length != 0) { + output.WriteRawTag(10); + output.WriteString(ModelCheckpointPath); + } + allModelCheckpointPaths_.WriteTo(output, _repeated_allModelCheckpointPaths_codec); + allModelCheckpointTimestamps_.WriteTo(output, _repeated_allModelCheckpointTimestamps_codec); + if (LastPreservedTimestamp != 0D) { + output.WriteRawTag(33); + output.WriteDouble(LastPreservedTimestamp); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (ModelCheckpointPath.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ModelCheckpointPath); + } + size += allModelCheckpointPaths_.CalculateSize(_repeated_allModelCheckpointPaths_codec); + size += allModelCheckpointTimestamps_.CalculateSize(_repeated_allModelCheckpointTimestamps_codec); + if (LastPreservedTimestamp != 0D) { + size += 1 + 8; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(CheckpointState other) { + if (other == null) { + return; + } + if (other.ModelCheckpointPath.Length != 0) { + ModelCheckpointPath = other.ModelCheckpointPath; + } + allModelCheckpointPaths_.Add(other.allModelCheckpointPaths_); + allModelCheckpointTimestamps_.Add(other.allModelCheckpointTimestamps_); + if (other.LastPreservedTimestamp != 0D) { + LastPreservedTimestamp = other.LastPreservedTimestamp; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + ModelCheckpointPath = input.ReadString(); + break; + } + case 18: { + allModelCheckpointPaths_.AddEntriesFrom(input, _repeated_allModelCheckpointPaths_codec); + break; + } + case 26: + case 25: { + allModelCheckpointTimestamps_.AddEntriesFrom(input, _repeated_allModelCheckpointTimestamps_codec); + break; + } + case 33: { + LastPreservedTimestamp = input.ReadDouble(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/MetaGraph.cs b/src/TensorFlowNET.Core/Protobuf/MetaGraph.cs new file mode 100644 index 00000000..4e82f863 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/MetaGraph.cs @@ -0,0 +1,2679 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/protobuf/meta_graph.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/protobuf/meta_graph.proto + public static partial class MetaGraphReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/protobuf/meta_graph.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static MetaGraphReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cil0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvbWV0YV9ncmFwaC5wcm90bxIK", + "dGVuc29yZmxvdxoZZ29vZ2xlL3Byb3RvYnVmL2FueS5wcm90bxoldGVuc29y", + "Zmxvdy9jb3JlL2ZyYW1ld29yay9ncmFwaC5wcm90bxomdGVuc29yZmxvdy9j", + "b3JlL2ZyYW1ld29yay9vcF9kZWYucHJvdG8aLHRlbnNvcmZsb3cvY29yZS9m", + "cmFtZXdvcmsvdGVuc29yX3NoYXBlLnByb3RvGiV0ZW5zb3JmbG93L2NvcmUv", + "ZnJhbWV3b3JrL3R5cGVzLnByb3RvGiR0ZW5zb3JmbG93L2NvcmUvcHJvdG9i", + "dWYvc2F2ZXIucHJvdG8i4wUKDE1ldGFHcmFwaERlZhI7Cg1tZXRhX2luZm9f", + "ZGVmGAEgASgLMiQudGVuc29yZmxvdy5NZXRhR3JhcGhEZWYuTWV0YUluZm9E", + "ZWYSJwoJZ3JhcGhfZGVmGAIgASgLMhQudGVuc29yZmxvdy5HcmFwaERlZhIn", + "CglzYXZlcl9kZWYYAyABKAsyFC50ZW5zb3JmbG93LlNhdmVyRGVmEkMKDmNv", + "bGxlY3Rpb25fZGVmGAQgAygLMisudGVuc29yZmxvdy5NZXRhR3JhcGhEZWYu", + "Q29sbGVjdGlvbkRlZkVudHJ5EkEKDXNpZ25hdHVyZV9kZWYYBSADKAsyKi50", + "ZW5zb3JmbG93Lk1ldGFHcmFwaERlZi5TaWduYXR1cmVEZWZFbnRyeRIwCg5h", + "c3NldF9maWxlX2RlZhgGIAMoCzIYLnRlbnNvcmZsb3cuQXNzZXRGaWxlRGVm", + "GukBCgtNZXRhSW5mb0RlZhIaChJtZXRhX2dyYXBoX3ZlcnNpb24YASABKAkS", + "LAoQc3RyaXBwZWRfb3BfbGlzdBgCIAEoCzISLnRlbnNvcmZsb3cuT3BMaXN0", + "EiYKCGFueV9pbmZvGAMgASgLMhQuZ29vZ2xlLnByb3RvYnVmLkFueRIMCgR0", + "YWdzGAQgAygJEhoKEnRlbnNvcmZsb3dfdmVyc2lvbhgFIAEoCRIeChZ0ZW5z", + "b3JmbG93X2dpdF92ZXJzaW9uGAYgASgJEh4KFnN0cmlwcGVkX2RlZmF1bHRf", + "YXR0cnMYByABKAgaTwoSQ29sbGVjdGlvbkRlZkVudHJ5EgsKA2tleRgBIAEo", + "CRIoCgV2YWx1ZRgCIAEoCzIZLnRlbnNvcmZsb3cuQ29sbGVjdGlvbkRlZjoC", + "OAEaTQoRU2lnbmF0dXJlRGVmRW50cnkSCwoDa2V5GAEgASgJEicKBXZhbHVl", + "GAIgASgLMhgudGVuc29yZmxvdy5TaWduYXR1cmVEZWY6AjgBIt8DCg1Db2xs", + "ZWN0aW9uRGVmEjcKCW5vZGVfbGlzdBgBIAEoCzIiLnRlbnNvcmZsb3cuQ29s", + "bGVjdGlvbkRlZi5Ob2RlTGlzdEgAEjkKCmJ5dGVzX2xpc3QYAiABKAsyIy50", + "ZW5zb3JmbG93LkNvbGxlY3Rpb25EZWYuQnl0ZXNMaXN0SAASOQoKaW50NjRf", + "bGlzdBgDIAEoCzIjLnRlbnNvcmZsb3cuQ29sbGVjdGlvbkRlZi5JbnQ2NExp", + "c3RIABI5CgpmbG9hdF9saXN0GAQgASgLMiMudGVuc29yZmxvdy5Db2xsZWN0", + "aW9uRGVmLkZsb2F0TGlzdEgAEjUKCGFueV9saXN0GAUgASgLMiEudGVuc29y", + "Zmxvdy5Db2xsZWN0aW9uRGVmLkFueUxpc3RIABoZCghOb2RlTGlzdBINCgV2", + "YWx1ZRgBIAMoCRoaCglCeXRlc0xpc3QSDQoFdmFsdWUYASADKAwaHgoJSW50", + "NjRMaXN0EhEKBXZhbHVlGAEgAygDQgIQARoeCglGbG9hdExpc3QSEQoFdmFs", + "dWUYASADKAJCAhABGi4KB0FueUxpc3QSIwoFdmFsdWUYASADKAsyFC5nb29n", + "bGUucHJvdG9idWYuQW55QgYKBGtpbmQioAIKClRlbnNvckluZm8SDgoEbmFt", + "ZRgBIAEoCUgAEjYKCmNvb19zcGFyc2UYBCABKAsyIC50ZW5zb3JmbG93LlRl", + "bnNvckluZm8uQ29vU3BhcnNlSAASIwoFZHR5cGUYAiABKA4yFC50ZW5zb3Jm", + "bG93LkRhdGFUeXBlEjIKDHRlbnNvcl9zaGFwZRgDIAEoCzIcLnRlbnNvcmZs", + "b3cuVGVuc29yU2hhcGVQcm90bxplCglDb29TcGFyc2USGgoSdmFsdWVzX3Rl", + "bnNvcl9uYW1lGAEgASgJEhsKE2luZGljZXNfdGVuc29yX25hbWUYAiABKAkS", + "HwoXZGVuc2Vfc2hhcGVfdGVuc29yX25hbWUYAyABKAlCCgoIZW5jb2Rpbmci", + "oAIKDFNpZ25hdHVyZURlZhI0CgZpbnB1dHMYASADKAsyJC50ZW5zb3JmbG93", + "LlNpZ25hdHVyZURlZi5JbnB1dHNFbnRyeRI2CgdvdXRwdXRzGAIgAygLMiUu", + "dGVuc29yZmxvdy5TaWduYXR1cmVEZWYuT3V0cHV0c0VudHJ5EhMKC21ldGhv", + "ZF9uYW1lGAMgASgJGkUKC0lucHV0c0VudHJ5EgsKA2tleRgBIAEoCRIlCgV2", + "YWx1ZRgCIAEoCzIWLnRlbnNvcmZsb3cuVGVuc29ySW5mbzoCOAEaRgoMT3V0", + "cHV0c0VudHJ5EgsKA2tleRgBIAEoCRIlCgV2YWx1ZRgCIAEoCzIWLnRlbnNv", + "cmZsb3cuVGVuc29ySW5mbzoCOAEiTQoMQXNzZXRGaWxlRGVmEisKC3RlbnNv", + "cl9pbmZvGAEgASgLMhYudGVuc29yZmxvdy5UZW5zb3JJbmZvEhAKCGZpbGVu", + "YW1lGAIgASgJQm4KGG9yZy50ZW5zb3JmbG93LmZyYW1ld29ya0IPTWV0YUdy", + "YXBoUHJvdG9zUAFaPGdpdGh1Yi5jb20vdGVuc29yZmxvdy90ZW5zb3JmbG93", + "L3RlbnNvcmZsb3cvZ28vY29yZS9wcm90b2J1ZvgBAWIGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Google.Protobuf.WellKnownTypes.AnyReflection.Descriptor, global::Tensorflow.GraphReflection.Descriptor, global::Tensorflow.OpDefReflection.Descriptor, global::Tensorflow.TensorShapeReflection.Descriptor, global::Tensorflow.TypesReflection.Descriptor, global::Tensorflow.SaverReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.MetaGraphDef), global::Tensorflow.MetaGraphDef.Parser, new[]{ "MetaInfoDef", "GraphDef", "SaverDef", "CollectionDef", "SignatureDef", "AssetFileDef" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.MetaGraphDef.Types.MetaInfoDef), global::Tensorflow.MetaGraphDef.Types.MetaInfoDef.Parser, new[]{ "MetaGraphVersion", "StrippedOpList", "AnyInfo", "Tags", "TensorflowVersion", "TensorflowGitVersion", "StrippedDefaultAttrs" }, null, null, null), + null, null, }), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CollectionDef), global::Tensorflow.CollectionDef.Parser, new[]{ "NodeList", "BytesList", "Int64List", "FloatList", "AnyList" }, new[]{ "Kind" }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CollectionDef.Types.NodeList), global::Tensorflow.CollectionDef.Types.NodeList.Parser, new[]{ "Value" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CollectionDef.Types.BytesList), global::Tensorflow.CollectionDef.Types.BytesList.Parser, new[]{ "Value" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CollectionDef.Types.Int64List), global::Tensorflow.CollectionDef.Types.Int64List.Parser, new[]{ "Value" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CollectionDef.Types.FloatList), global::Tensorflow.CollectionDef.Types.FloatList.Parser, new[]{ "Value" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CollectionDef.Types.AnyList), global::Tensorflow.CollectionDef.Types.AnyList.Parser, new[]{ "Value" }, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TensorInfo), global::Tensorflow.TensorInfo.Parser, new[]{ "Name", "CooSparse", "Dtype", "TensorShape" }, new[]{ "Encoding" }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TensorInfo.Types.CooSparse), global::Tensorflow.TensorInfo.Types.CooSparse.Parser, new[]{ "ValuesTensorName", "IndicesTensorName", "DenseShapeTensorName" }, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SignatureDef), global::Tensorflow.SignatureDef.Parser, new[]{ "Inputs", "Outputs", "MethodName" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, null, }), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.AssetFileDef), global::Tensorflow.AssetFileDef.Parser, new[]{ "TensorInfo", "Filename" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// NOTE: This protocol buffer is evolving, and will go through revisions in the + /// coming months. + /// + /// Protocol buffer containing the following which are necessary to restart + /// training, run inference. It can be used to serialize/de-serialize memory + /// objects necessary for running computation in a graph when crossing the + /// process boundary. It can be used for long term storage of graphs, + /// cross-language execution of graphs, etc. + /// MetaInfoDef + /// GraphDef + /// SaverDef + /// CollectionDef + /// TensorInfo + /// SignatureDef + /// + public sealed partial class MetaGraphDef : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new MetaGraphDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.MetaGraphReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public MetaGraphDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public MetaGraphDef(MetaGraphDef other) : this() { + metaInfoDef_ = other.metaInfoDef_ != null ? other.metaInfoDef_.Clone() : null; + graphDef_ = other.graphDef_ != null ? other.graphDef_.Clone() : null; + saverDef_ = other.saverDef_ != null ? other.saverDef_.Clone() : null; + collectionDef_ = other.collectionDef_.Clone(); + signatureDef_ = other.signatureDef_.Clone(); + assetFileDef_ = other.assetFileDef_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public MetaGraphDef Clone() { + return new MetaGraphDef(this); + } + + /// Field number for the "meta_info_def" field. + public const int MetaInfoDefFieldNumber = 1; + private global::Tensorflow.MetaGraphDef.Types.MetaInfoDef metaInfoDef_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.MetaGraphDef.Types.MetaInfoDef MetaInfoDef { + get { return metaInfoDef_; } + set { + metaInfoDef_ = value; + } + } + + /// Field number for the "graph_def" field. + public const int GraphDefFieldNumber = 2; + private global::Tensorflow.GraphDef graphDef_; + /// + /// GraphDef. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.GraphDef GraphDef { + get { return graphDef_; } + set { + graphDef_ = value; + } + } + + /// Field number for the "saver_def" field. + public const int SaverDefFieldNumber = 3; + private global::Tensorflow.SaverDef saverDef_; + /// + /// SaverDef. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.SaverDef SaverDef { + get { return saverDef_; } + set { + saverDef_ = value; + } + } + + /// Field number for the "collection_def" field. + public const int CollectionDefFieldNumber = 4; + private static readonly pbc::MapField.Codec _map_collectionDef_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10), pb::FieldCodec.ForMessage(18, global::Tensorflow.CollectionDef.Parser), 34); + private readonly pbc::MapField collectionDef_ = new pbc::MapField(); + /// + /// collection_def: Map from collection name to collections. + /// See CollectionDef section for details. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::MapField CollectionDef { + get { return collectionDef_; } + } + + /// Field number for the "signature_def" field. + public const int SignatureDefFieldNumber = 5; + private static readonly pbc::MapField.Codec _map_signatureDef_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10), pb::FieldCodec.ForMessage(18, global::Tensorflow.SignatureDef.Parser), 42); + private readonly pbc::MapField signatureDef_ = new pbc::MapField(); + /// + /// signature_def: Map from user supplied key for a signature to a single + /// SignatureDef. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::MapField SignatureDef { + get { return signatureDef_; } + } + + /// Field number for the "asset_file_def" field. + public const int AssetFileDefFieldNumber = 6; + private static readonly pb::FieldCodec _repeated_assetFileDef_codec + = pb::FieldCodec.ForMessage(50, global::Tensorflow.AssetFileDef.Parser); + private readonly pbc::RepeatedField assetFileDef_ = new pbc::RepeatedField(); + /// + /// Asset file def to be used with the defined graph. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField AssetFileDef { + get { return assetFileDef_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as MetaGraphDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(MetaGraphDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(MetaInfoDef, other.MetaInfoDef)) return false; + if (!object.Equals(GraphDef, other.GraphDef)) return false; + if (!object.Equals(SaverDef, other.SaverDef)) return false; + if (!CollectionDef.Equals(other.CollectionDef)) return false; + if (!SignatureDef.Equals(other.SignatureDef)) return false; + if(!assetFileDef_.Equals(other.assetFileDef_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (metaInfoDef_ != null) hash ^= MetaInfoDef.GetHashCode(); + if (graphDef_ != null) hash ^= GraphDef.GetHashCode(); + if (saverDef_ != null) hash ^= SaverDef.GetHashCode(); + hash ^= CollectionDef.GetHashCode(); + hash ^= SignatureDef.GetHashCode(); + hash ^= assetFileDef_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (metaInfoDef_ != null) { + output.WriteRawTag(10); + output.WriteMessage(MetaInfoDef); + } + if (graphDef_ != null) { + output.WriteRawTag(18); + output.WriteMessage(GraphDef); + } + if (saverDef_ != null) { + output.WriteRawTag(26); + output.WriteMessage(SaverDef); + } + collectionDef_.WriteTo(output, _map_collectionDef_codec); + signatureDef_.WriteTo(output, _map_signatureDef_codec); + assetFileDef_.WriteTo(output, _repeated_assetFileDef_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (metaInfoDef_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(MetaInfoDef); + } + if (graphDef_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(GraphDef); + } + if (saverDef_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(SaverDef); + } + size += collectionDef_.CalculateSize(_map_collectionDef_codec); + size += signatureDef_.CalculateSize(_map_signatureDef_codec); + size += assetFileDef_.CalculateSize(_repeated_assetFileDef_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(MetaGraphDef other) { + if (other == null) { + return; + } + if (other.metaInfoDef_ != null) { + if (metaInfoDef_ == null) { + metaInfoDef_ = new global::Tensorflow.MetaGraphDef.Types.MetaInfoDef(); + } + MetaInfoDef.MergeFrom(other.MetaInfoDef); + } + if (other.graphDef_ != null) { + if (graphDef_ == null) { + graphDef_ = new global::Tensorflow.GraphDef(); + } + GraphDef.MergeFrom(other.GraphDef); + } + if (other.saverDef_ != null) { + if (saverDef_ == null) { + saverDef_ = new global::Tensorflow.SaverDef(); + } + SaverDef.MergeFrom(other.SaverDef); + } + collectionDef_.Add(other.collectionDef_); + signatureDef_.Add(other.signatureDef_); + assetFileDef_.Add(other.assetFileDef_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (metaInfoDef_ == null) { + metaInfoDef_ = new global::Tensorflow.MetaGraphDef.Types.MetaInfoDef(); + } + input.ReadMessage(metaInfoDef_); + break; + } + case 18: { + if (graphDef_ == null) { + graphDef_ = new global::Tensorflow.GraphDef(); + } + input.ReadMessage(graphDef_); + break; + } + case 26: { + if (saverDef_ == null) { + saverDef_ = new global::Tensorflow.SaverDef(); + } + input.ReadMessage(saverDef_); + break; + } + case 34: { + collectionDef_.AddEntriesFrom(input, _map_collectionDef_codec); + break; + } + case 42: { + signatureDef_.AddEntriesFrom(input, _map_signatureDef_codec); + break; + } + case 50: { + assetFileDef_.AddEntriesFrom(input, _repeated_assetFileDef_codec); + break; + } + } + } + } + + #region Nested types + /// Container for nested types declared in the MetaGraphDef message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static partial class Types { + /// + /// Meta information regarding the graph to be exported. To be used by users + /// of this protocol buffer to encode information regarding their meta graph. + /// + public sealed partial class MetaInfoDef : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new MetaInfoDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.MetaGraphDef.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public MetaInfoDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public MetaInfoDef(MetaInfoDef other) : this() { + metaGraphVersion_ = other.metaGraphVersion_; + strippedOpList_ = other.strippedOpList_ != null ? other.strippedOpList_.Clone() : null; + anyInfo_ = other.anyInfo_ != null ? other.anyInfo_.Clone() : null; + tags_ = other.tags_.Clone(); + tensorflowVersion_ = other.tensorflowVersion_; + tensorflowGitVersion_ = other.tensorflowGitVersion_; + strippedDefaultAttrs_ = other.strippedDefaultAttrs_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public MetaInfoDef Clone() { + return new MetaInfoDef(this); + } + + /// Field number for the "meta_graph_version" field. + public const int MetaGraphVersionFieldNumber = 1; + private string metaGraphVersion_ = ""; + /// + /// User specified Version string. Can be the name of the model and revision, + /// steps this model has been trained to, etc. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string MetaGraphVersion { + get { return metaGraphVersion_; } + set { + metaGraphVersion_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "stripped_op_list" field. + public const int StrippedOpListFieldNumber = 2; + private global::Tensorflow.OpList strippedOpList_; + /// + /// A copy of the OpDefs used by the producer of this graph_def. + /// Descriptions and Ops not used in graph_def are stripped out. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.OpList StrippedOpList { + get { return strippedOpList_; } + set { + strippedOpList_ = value; + } + } + + /// Field number for the "any_info" field. + public const int AnyInfoFieldNumber = 3; + private global::Google.Protobuf.WellKnownTypes.Any anyInfo_; + /// + /// A serialized protobuf. Can be the time this meta graph is created, or + /// modified, or name of the model. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Google.Protobuf.WellKnownTypes.Any AnyInfo { + get { return anyInfo_; } + set { + anyInfo_ = value; + } + } + + /// Field number for the "tags" field. + public const int TagsFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_tags_codec + = pb::FieldCodec.ForString(34); + private readonly pbc::RepeatedField tags_ = new pbc::RepeatedField(); + /// + /// User supplied tag(s) on the meta_graph and included graph_def. + /// + /// MetaGraphDefs should be tagged with their capabilities or use-cases. + /// Examples: "train", "serve", "gpu", "tpu", etc. + /// These tags enable loaders to access the MetaGraph(s) appropriate for a + /// specific use-case or runtime environment. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Tags { + get { return tags_; } + } + + /// Field number for the "tensorflow_version" field. + public const int TensorflowVersionFieldNumber = 5; + private string tensorflowVersion_ = ""; + /// + /// The __version__ string of the tensorflow build used to write this graph. + /// This will be populated by the framework, which will overwrite any user + /// supplied value. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string TensorflowVersion { + get { return tensorflowVersion_; } + set { + tensorflowVersion_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "tensorflow_git_version" field. + public const int TensorflowGitVersionFieldNumber = 6; + private string tensorflowGitVersion_ = ""; + /// + /// The __git_version__ string of the tensorflow build used to write this + /// graph. This will be populated by the framework, which will overwrite any + /// user supplied value. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string TensorflowGitVersion { + get { return tensorflowGitVersion_; } + set { + tensorflowGitVersion_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "stripped_default_attrs" field. + public const int StrippedDefaultAttrsFieldNumber = 7; + private bool strippedDefaultAttrs_; + /// + /// A flag to denote whether default-valued attrs have been stripped from + /// the nodes in this graph_def. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool StrippedDefaultAttrs { + get { return strippedDefaultAttrs_; } + set { + strippedDefaultAttrs_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as MetaInfoDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(MetaInfoDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (MetaGraphVersion != other.MetaGraphVersion) return false; + if (!object.Equals(StrippedOpList, other.StrippedOpList)) return false; + if (!object.Equals(AnyInfo, other.AnyInfo)) return false; + if(!tags_.Equals(other.tags_)) return false; + if (TensorflowVersion != other.TensorflowVersion) return false; + if (TensorflowGitVersion != other.TensorflowGitVersion) return false; + if (StrippedDefaultAttrs != other.StrippedDefaultAttrs) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (MetaGraphVersion.Length != 0) hash ^= MetaGraphVersion.GetHashCode(); + if (strippedOpList_ != null) hash ^= StrippedOpList.GetHashCode(); + if (anyInfo_ != null) hash ^= AnyInfo.GetHashCode(); + hash ^= tags_.GetHashCode(); + if (TensorflowVersion.Length != 0) hash ^= TensorflowVersion.GetHashCode(); + if (TensorflowGitVersion.Length != 0) hash ^= TensorflowGitVersion.GetHashCode(); + if (StrippedDefaultAttrs != false) hash ^= StrippedDefaultAttrs.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (MetaGraphVersion.Length != 0) { + output.WriteRawTag(10); + output.WriteString(MetaGraphVersion); + } + if (strippedOpList_ != null) { + output.WriteRawTag(18); + output.WriteMessage(StrippedOpList); + } + if (anyInfo_ != null) { + output.WriteRawTag(26); + output.WriteMessage(AnyInfo); + } + tags_.WriteTo(output, _repeated_tags_codec); + if (TensorflowVersion.Length != 0) { + output.WriteRawTag(42); + output.WriteString(TensorflowVersion); + } + if (TensorflowGitVersion.Length != 0) { + output.WriteRawTag(50); + output.WriteString(TensorflowGitVersion); + } + if (StrippedDefaultAttrs != false) { + output.WriteRawTag(56); + output.WriteBool(StrippedDefaultAttrs); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (MetaGraphVersion.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(MetaGraphVersion); + } + if (strippedOpList_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(StrippedOpList); + } + if (anyInfo_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(AnyInfo); + } + size += tags_.CalculateSize(_repeated_tags_codec); + if (TensorflowVersion.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(TensorflowVersion); + } + if (TensorflowGitVersion.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(TensorflowGitVersion); + } + if (StrippedDefaultAttrs != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(MetaInfoDef other) { + if (other == null) { + return; + } + if (other.MetaGraphVersion.Length != 0) { + MetaGraphVersion = other.MetaGraphVersion; + } + if (other.strippedOpList_ != null) { + if (strippedOpList_ == null) { + strippedOpList_ = new global::Tensorflow.OpList(); + } + StrippedOpList.MergeFrom(other.StrippedOpList); + } + if (other.anyInfo_ != null) { + if (anyInfo_ == null) { + anyInfo_ = new global::Google.Protobuf.WellKnownTypes.Any(); + } + AnyInfo.MergeFrom(other.AnyInfo); + } + tags_.Add(other.tags_); + if (other.TensorflowVersion.Length != 0) { + TensorflowVersion = other.TensorflowVersion; + } + if (other.TensorflowGitVersion.Length != 0) { + TensorflowGitVersion = other.TensorflowGitVersion; + } + if (other.StrippedDefaultAttrs != false) { + StrippedDefaultAttrs = other.StrippedDefaultAttrs; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + MetaGraphVersion = input.ReadString(); + break; + } + case 18: { + if (strippedOpList_ == null) { + strippedOpList_ = new global::Tensorflow.OpList(); + } + input.ReadMessage(strippedOpList_); + break; + } + case 26: { + if (anyInfo_ == null) { + anyInfo_ = new global::Google.Protobuf.WellKnownTypes.Any(); + } + input.ReadMessage(anyInfo_); + break; + } + case 34: { + tags_.AddEntriesFrom(input, _repeated_tags_codec); + break; + } + case 42: { + TensorflowVersion = input.ReadString(); + break; + } + case 50: { + TensorflowGitVersion = input.ReadString(); + break; + } + case 56: { + StrippedDefaultAttrs = input.ReadBool(); + break; + } + } + } + } + + } + + } + #endregion + + } + + /// + /// CollectionDef should cover most collections. + /// To add a user-defined collection, do one of the following: + /// 1. For simple data types, such as string, int, float: + /// tf.add_to_collection("your_collection_name", your_simple_value) + /// strings will be stored as bytes_list. + /// + /// 2. For Protobuf types, there are three ways to add them: + /// 1) tf.add_to_collection("your_collection_name", + /// your_proto.SerializeToString()) + /// + /// collection_def { + /// key: "user_defined_bytes_collection" + /// value { + /// bytes_list { + /// value: "queue_name: \"test_queue\"\n" + /// } + /// } + /// } + /// + /// or + /// + /// 2) tf.add_to_collection("your_collection_name", str(your_proto)) + /// + /// collection_def { + /// key: "user_defined_string_collection" + /// value { + /// bytes_list { + /// value: "\n\ntest_queue" + /// } + /// } + /// } + /// + /// or + /// + /// 3) any_buf = any_pb2.Any() + /// tf.add_to_collection("your_collection_name", + /// any_buf.Pack(your_proto)) + /// + /// collection_def { + /// key: "user_defined_any_collection" + /// value { + /// any_list { + /// value { + /// type_url: "type.googleapis.com/tensorflow.QueueRunnerDef" + /// value: "\n\ntest_queue" + /// } + /// } + /// } + /// } + /// + /// 3. For Python objects, implement to_proto() and from_proto(), and register + /// them in the following manner: + /// ops.register_proto_function("your_collection_name", + /// proto_type, + /// to_proto=YourPythonObject.to_proto, + /// from_proto=YourPythonObject.from_proto) + /// These functions will be invoked to serialize and de-serialize the + /// collection. For example, + /// ops.register_proto_function(ops.GraphKeys.GLOBAL_VARIABLES, + /// proto_type=variable_pb2.VariableDef, + /// to_proto=Variable.to_proto, + /// from_proto=Variable.from_proto) + /// + public sealed partial class CollectionDef : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CollectionDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.MetaGraphReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CollectionDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CollectionDef(CollectionDef other) : this() { + switch (other.KindCase) { + case KindOneofCase.NodeList: + NodeList = other.NodeList.Clone(); + break; + case KindOneofCase.BytesList: + BytesList = other.BytesList.Clone(); + break; + case KindOneofCase.Int64List: + Int64List = other.Int64List.Clone(); + break; + case KindOneofCase.FloatList: + FloatList = other.FloatList.Clone(); + break; + case KindOneofCase.AnyList: + AnyList = other.AnyList.Clone(); + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CollectionDef Clone() { + return new CollectionDef(this); + } + + /// Field number for the "node_list" field. + public const int NodeListFieldNumber = 1; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.CollectionDef.Types.NodeList NodeList { + get { return kindCase_ == KindOneofCase.NodeList ? (global::Tensorflow.CollectionDef.Types.NodeList) kind_ : null; } + set { + kind_ = value; + kindCase_ = value == null ? KindOneofCase.None : KindOneofCase.NodeList; + } + } + + /// Field number for the "bytes_list" field. + public const int BytesListFieldNumber = 2; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.CollectionDef.Types.BytesList BytesList { + get { return kindCase_ == KindOneofCase.BytesList ? (global::Tensorflow.CollectionDef.Types.BytesList) kind_ : null; } + set { + kind_ = value; + kindCase_ = value == null ? KindOneofCase.None : KindOneofCase.BytesList; + } + } + + /// Field number for the "int64_list" field. + public const int Int64ListFieldNumber = 3; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.CollectionDef.Types.Int64List Int64List { + get { return kindCase_ == KindOneofCase.Int64List ? (global::Tensorflow.CollectionDef.Types.Int64List) kind_ : null; } + set { + kind_ = value; + kindCase_ = value == null ? KindOneofCase.None : KindOneofCase.Int64List; + } + } + + /// Field number for the "float_list" field. + public const int FloatListFieldNumber = 4; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.CollectionDef.Types.FloatList FloatList { + get { return kindCase_ == KindOneofCase.FloatList ? (global::Tensorflow.CollectionDef.Types.FloatList) kind_ : null; } + set { + kind_ = value; + kindCase_ = value == null ? KindOneofCase.None : KindOneofCase.FloatList; + } + } + + /// Field number for the "any_list" field. + public const int AnyListFieldNumber = 5; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.CollectionDef.Types.AnyList AnyList { + get { return kindCase_ == KindOneofCase.AnyList ? (global::Tensorflow.CollectionDef.Types.AnyList) kind_ : null; } + set { + kind_ = value; + kindCase_ = value == null ? KindOneofCase.None : KindOneofCase.AnyList; + } + } + + private object kind_; + /// Enum of possible cases for the "kind" oneof. + public enum KindOneofCase { + None = 0, + NodeList = 1, + BytesList = 2, + Int64List = 3, + FloatList = 4, + AnyList = 5, + } + private KindOneofCase kindCase_ = KindOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public KindOneofCase KindCase { + get { return kindCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void ClearKind() { + kindCase_ = KindOneofCase.None; + kind_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as CollectionDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(CollectionDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(NodeList, other.NodeList)) return false; + if (!object.Equals(BytesList, other.BytesList)) return false; + if (!object.Equals(Int64List, other.Int64List)) return false; + if (!object.Equals(FloatList, other.FloatList)) return false; + if (!object.Equals(AnyList, other.AnyList)) return false; + if (KindCase != other.KindCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (kindCase_ == KindOneofCase.NodeList) hash ^= NodeList.GetHashCode(); + if (kindCase_ == KindOneofCase.BytesList) hash ^= BytesList.GetHashCode(); + if (kindCase_ == KindOneofCase.Int64List) hash ^= Int64List.GetHashCode(); + if (kindCase_ == KindOneofCase.FloatList) hash ^= FloatList.GetHashCode(); + if (kindCase_ == KindOneofCase.AnyList) hash ^= AnyList.GetHashCode(); + hash ^= (int) kindCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (kindCase_ == KindOneofCase.NodeList) { + output.WriteRawTag(10); + output.WriteMessage(NodeList); + } + if (kindCase_ == KindOneofCase.BytesList) { + output.WriteRawTag(18); + output.WriteMessage(BytesList); + } + if (kindCase_ == KindOneofCase.Int64List) { + output.WriteRawTag(26); + output.WriteMessage(Int64List); + } + if (kindCase_ == KindOneofCase.FloatList) { + output.WriteRawTag(34); + output.WriteMessage(FloatList); + } + if (kindCase_ == KindOneofCase.AnyList) { + output.WriteRawTag(42); + output.WriteMessage(AnyList); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (kindCase_ == KindOneofCase.NodeList) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(NodeList); + } + if (kindCase_ == KindOneofCase.BytesList) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(BytesList); + } + if (kindCase_ == KindOneofCase.Int64List) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Int64List); + } + if (kindCase_ == KindOneofCase.FloatList) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(FloatList); + } + if (kindCase_ == KindOneofCase.AnyList) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(AnyList); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(CollectionDef other) { + if (other == null) { + return; + } + switch (other.KindCase) { + case KindOneofCase.NodeList: + if (NodeList == null) { + NodeList = new global::Tensorflow.CollectionDef.Types.NodeList(); + } + NodeList.MergeFrom(other.NodeList); + break; + case KindOneofCase.BytesList: + if (BytesList == null) { + BytesList = new global::Tensorflow.CollectionDef.Types.BytesList(); + } + BytesList.MergeFrom(other.BytesList); + break; + case KindOneofCase.Int64List: + if (Int64List == null) { + Int64List = new global::Tensorflow.CollectionDef.Types.Int64List(); + } + Int64List.MergeFrom(other.Int64List); + break; + case KindOneofCase.FloatList: + if (FloatList == null) { + FloatList = new global::Tensorflow.CollectionDef.Types.FloatList(); + } + FloatList.MergeFrom(other.FloatList); + break; + case KindOneofCase.AnyList: + if (AnyList == null) { + AnyList = new global::Tensorflow.CollectionDef.Types.AnyList(); + } + AnyList.MergeFrom(other.AnyList); + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + global::Tensorflow.CollectionDef.Types.NodeList subBuilder = new global::Tensorflow.CollectionDef.Types.NodeList(); + if (kindCase_ == KindOneofCase.NodeList) { + subBuilder.MergeFrom(NodeList); + } + input.ReadMessage(subBuilder); + NodeList = subBuilder; + break; + } + case 18: { + global::Tensorflow.CollectionDef.Types.BytesList subBuilder = new global::Tensorflow.CollectionDef.Types.BytesList(); + if (kindCase_ == KindOneofCase.BytesList) { + subBuilder.MergeFrom(BytesList); + } + input.ReadMessage(subBuilder); + BytesList = subBuilder; + break; + } + case 26: { + global::Tensorflow.CollectionDef.Types.Int64List subBuilder = new global::Tensorflow.CollectionDef.Types.Int64List(); + if (kindCase_ == KindOneofCase.Int64List) { + subBuilder.MergeFrom(Int64List); + } + input.ReadMessage(subBuilder); + Int64List = subBuilder; + break; + } + case 34: { + global::Tensorflow.CollectionDef.Types.FloatList subBuilder = new global::Tensorflow.CollectionDef.Types.FloatList(); + if (kindCase_ == KindOneofCase.FloatList) { + subBuilder.MergeFrom(FloatList); + } + input.ReadMessage(subBuilder); + FloatList = subBuilder; + break; + } + case 42: { + global::Tensorflow.CollectionDef.Types.AnyList subBuilder = new global::Tensorflow.CollectionDef.Types.AnyList(); + if (kindCase_ == KindOneofCase.AnyList) { + subBuilder.MergeFrom(AnyList); + } + input.ReadMessage(subBuilder); + AnyList = subBuilder; + break; + } + } + } + } + + #region Nested types + /// Container for nested types declared in the CollectionDef message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static partial class Types { + /// + /// NodeList is used for collecting nodes in graph. For example + /// collection_def { + /// key: "summaries" + /// value { + /// node_list { + /// value: "input_producer/ScalarSummary:0" + /// value: "shuffle_batch/ScalarSummary:0" + /// value: "ImageSummary:0" + /// } + /// } + /// + public sealed partial class NodeList : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new NodeList()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CollectionDef.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public NodeList() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public NodeList(NodeList other) : this() { + value_ = other.value_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public NodeList Clone() { + return new NodeList(this); + } + + /// Field number for the "value" field. + public const int ValueFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_value_codec + = pb::FieldCodec.ForString(10); + private readonly pbc::RepeatedField value_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Value { + get { return value_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as NodeList); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(NodeList other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!value_.Equals(other.value_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= value_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + value_.WriteTo(output, _repeated_value_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += value_.CalculateSize(_repeated_value_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(NodeList other) { + if (other == null) { + return; + } + value_.Add(other.value_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + value_.AddEntriesFrom(input, _repeated_value_codec); + break; + } + } + } + } + + } + + /// + /// BytesList is used for collecting strings and serialized protobufs. For + /// example: + /// collection_def { + /// key: "trainable_variables" + /// value { + /// bytes_list { + /// value: "\n\017conv1/weights:0\022\024conv1/weights/Assign + /// \032\024conv1/weights/read:0" + /// value: "\n\016conv1/biases:0\022\023conv1/biases/Assign\032 + /// \023conv1/biases/read:0" + /// } + /// } + /// } + /// + public sealed partial class BytesList : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new BytesList()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CollectionDef.Descriptor.NestedTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BytesList() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BytesList(BytesList other) : this() { + value_ = other.value_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BytesList Clone() { + return new BytesList(this); + } + + /// Field number for the "value" field. + public const int ValueFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_value_codec + = pb::FieldCodec.ForBytes(10); + private readonly pbc::RepeatedField value_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Value { + get { return value_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as BytesList); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(BytesList other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!value_.Equals(other.value_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= value_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + value_.WriteTo(output, _repeated_value_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += value_.CalculateSize(_repeated_value_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(BytesList other) { + if (other == null) { + return; + } + value_.Add(other.value_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + value_.AddEntriesFrom(input, _repeated_value_codec); + break; + } + } + } + } + + } + + /// + /// Int64List is used for collecting int, int64 and long values. + /// + public sealed partial class Int64List : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Int64List()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CollectionDef.Descriptor.NestedTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Int64List() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Int64List(Int64List other) : this() { + value_ = other.value_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Int64List Clone() { + return new Int64List(this); + } + + /// Field number for the "value" field. + public const int ValueFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_value_codec + = pb::FieldCodec.ForInt64(10); + private readonly pbc::RepeatedField value_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Value { + get { return value_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as Int64List); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(Int64List other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!value_.Equals(other.value_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= value_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + value_.WriteTo(output, _repeated_value_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += value_.CalculateSize(_repeated_value_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(Int64List other) { + if (other == null) { + return; + } + value_.Add(other.value_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 8: { + value_.AddEntriesFrom(input, _repeated_value_codec); + break; + } + } + } + } + + } + + /// + /// FloatList is used for collecting float values. + /// + public sealed partial class FloatList : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new FloatList()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CollectionDef.Descriptor.NestedTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FloatList() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FloatList(FloatList other) : this() { + value_ = other.value_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FloatList Clone() { + return new FloatList(this); + } + + /// Field number for the "value" field. + public const int ValueFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_value_codec + = pb::FieldCodec.ForFloat(10); + private readonly pbc::RepeatedField value_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Value { + get { return value_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as FloatList); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(FloatList other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!value_.Equals(other.value_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= value_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + value_.WriteTo(output, _repeated_value_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += value_.CalculateSize(_repeated_value_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(FloatList other) { + if (other == null) { + return; + } + value_.Add(other.value_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 13: { + value_.AddEntriesFrom(input, _repeated_value_codec); + break; + } + } + } + } + + } + + /// + /// AnyList is used for collecting Any protos. + /// + public sealed partial class AnyList : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new AnyList()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CollectionDef.Descriptor.NestedTypes[4]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AnyList() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AnyList(AnyList other) : this() { + value_ = other.value_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AnyList Clone() { + return new AnyList(this); + } + + /// Field number for the "value" field. + public const int ValueFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_value_codec + = pb::FieldCodec.ForMessage(10, global::Google.Protobuf.WellKnownTypes.Any.Parser); + private readonly pbc::RepeatedField value_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Value { + get { return value_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as AnyList); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(AnyList other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!value_.Equals(other.value_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= value_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + value_.WriteTo(output, _repeated_value_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += value_.CalculateSize(_repeated_value_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(AnyList other) { + if (other == null) { + return; + } + value_.Add(other.value_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + value_.AddEntriesFrom(input, _repeated_value_codec); + break; + } + } + } + } + + } + + } + #endregion + + } + + /// + /// Information about a Tensor necessary for feeding or retrieval. + /// + public sealed partial class TensorInfo : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TensorInfo()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.MetaGraphReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TensorInfo() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TensorInfo(TensorInfo other) : this() { + dtype_ = other.dtype_; + tensorShape_ = other.tensorShape_ != null ? other.tensorShape_.Clone() : null; + switch (other.EncodingCase) { + case EncodingOneofCase.Name: + Name = other.Name; + break; + case EncodingOneofCase.CooSparse: + CooSparse = other.CooSparse.Clone(); + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TensorInfo Clone() { + return new TensorInfo(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + /// + /// For dense `Tensor`s, the name of the tensor in the graph. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Name { + get { return encodingCase_ == EncodingOneofCase.Name ? (string) encoding_ : ""; } + set { + encoding_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + encodingCase_ = EncodingOneofCase.Name; + } + } + + /// Field number for the "coo_sparse" field. + public const int CooSparseFieldNumber = 4; + /// + /// There are many possible encodings of sparse matrices + /// (https://en.wikipedia.org/wiki/Sparse_matrix). Currently, TensorFlow + /// uses only the COO encoding. This is supported and documented in the + /// SparseTensor Python class. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.TensorInfo.Types.CooSparse CooSparse { + get { return encodingCase_ == EncodingOneofCase.CooSparse ? (global::Tensorflow.TensorInfo.Types.CooSparse) encoding_ : null; } + set { + encoding_ = value; + encodingCase_ = value == null ? EncodingOneofCase.None : EncodingOneofCase.CooSparse; + } + } + + /// Field number for the "dtype" field. + public const int DtypeFieldNumber = 2; + private global::Tensorflow.DataType dtype_ = 0; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.DataType Dtype { + get { return dtype_; } + set { + dtype_ = value; + } + } + + /// Field number for the "tensor_shape" field. + public const int TensorShapeFieldNumber = 3; + private global::Tensorflow.TensorShapeProto tensorShape_; + /// + /// The static shape should be recorded here, to the extent that it can + /// be known in advance. In the case of a SparseTensor, this field describes + /// the logical shape of the represented tensor (aka dense_shape). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.TensorShapeProto TensorShape { + get { return tensorShape_; } + set { + tensorShape_ = value; + } + } + + private object encoding_; + /// Enum of possible cases for the "encoding" oneof. + public enum EncodingOneofCase { + None = 0, + Name = 1, + CooSparse = 4, + } + private EncodingOneofCase encodingCase_ = EncodingOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public EncodingOneofCase EncodingCase { + get { return encodingCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void ClearEncoding() { + encodingCase_ = EncodingOneofCase.None; + encoding_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as TensorInfo); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(TensorInfo other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (!object.Equals(CooSparse, other.CooSparse)) return false; + if (Dtype != other.Dtype) return false; + if (!object.Equals(TensorShape, other.TensorShape)) return false; + if (EncodingCase != other.EncodingCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (encodingCase_ == EncodingOneofCase.Name) hash ^= Name.GetHashCode(); + if (encodingCase_ == EncodingOneofCase.CooSparse) hash ^= CooSparse.GetHashCode(); + if (Dtype != 0) hash ^= Dtype.GetHashCode(); + if (tensorShape_ != null) hash ^= TensorShape.GetHashCode(); + hash ^= (int) encodingCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (encodingCase_ == EncodingOneofCase.Name) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (Dtype != 0) { + output.WriteRawTag(16); + output.WriteEnum((int) Dtype); + } + if (tensorShape_ != null) { + output.WriteRawTag(26); + output.WriteMessage(TensorShape); + } + if (encodingCase_ == EncodingOneofCase.CooSparse) { + output.WriteRawTag(34); + output.WriteMessage(CooSparse); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (encodingCase_ == EncodingOneofCase.Name) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (encodingCase_ == EncodingOneofCase.CooSparse) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(CooSparse); + } + if (Dtype != 0) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Dtype); + } + if (tensorShape_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(TensorShape); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(TensorInfo other) { + if (other == null) { + return; + } + if (other.Dtype != 0) { + Dtype = other.Dtype; + } + if (other.tensorShape_ != null) { + if (tensorShape_ == null) { + tensorShape_ = new global::Tensorflow.TensorShapeProto(); + } + TensorShape.MergeFrom(other.TensorShape); + } + switch (other.EncodingCase) { + case EncodingOneofCase.Name: + Name = other.Name; + break; + case EncodingOneofCase.CooSparse: + if (CooSparse == null) { + CooSparse = new global::Tensorflow.TensorInfo.Types.CooSparse(); + } + CooSparse.MergeFrom(other.CooSparse); + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 16: { + dtype_ = (global::Tensorflow.DataType) input.ReadEnum(); + break; + } + case 26: { + if (tensorShape_ == null) { + tensorShape_ = new global::Tensorflow.TensorShapeProto(); + } + input.ReadMessage(tensorShape_); + break; + } + case 34: { + global::Tensorflow.TensorInfo.Types.CooSparse subBuilder = new global::Tensorflow.TensorInfo.Types.CooSparse(); + if (encodingCase_ == EncodingOneofCase.CooSparse) { + subBuilder.MergeFrom(CooSparse); + } + input.ReadMessage(subBuilder); + CooSparse = subBuilder; + break; + } + } + } + } + + #region Nested types + /// Container for nested types declared in the TensorInfo message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static partial class Types { + /// + /// For sparse tensors, The COO encoding stores a triple of values, indices, + /// and shape. + /// + public sealed partial class CooSparse : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CooSparse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.TensorInfo.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CooSparse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CooSparse(CooSparse other) : this() { + valuesTensorName_ = other.valuesTensorName_; + indicesTensorName_ = other.indicesTensorName_; + denseShapeTensorName_ = other.denseShapeTensorName_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CooSparse Clone() { + return new CooSparse(this); + } + + /// Field number for the "values_tensor_name" field. + public const int ValuesTensorNameFieldNumber = 1; + private string valuesTensorName_ = ""; + /// + /// The shape of the values Tensor is [?]. Its dtype must be the dtype of + /// the SparseTensor as a whole, given in the enclosing TensorInfo. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string ValuesTensorName { + get { return valuesTensorName_; } + set { + valuesTensorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "indices_tensor_name" field. + public const int IndicesTensorNameFieldNumber = 2; + private string indicesTensorName_ = ""; + /// + /// The indices Tensor must have dtype int64 and shape [?, ?]. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string IndicesTensorName { + get { return indicesTensorName_; } + set { + indicesTensorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "dense_shape_tensor_name" field. + public const int DenseShapeTensorNameFieldNumber = 3; + private string denseShapeTensorName_ = ""; + /// + /// The dynamic logical shape represented by the SparseTensor is recorded in + /// the Tensor referenced here. It must have dtype int64 and shape [?]. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string DenseShapeTensorName { + get { return denseShapeTensorName_; } + set { + denseShapeTensorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as CooSparse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(CooSparse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (ValuesTensorName != other.ValuesTensorName) return false; + if (IndicesTensorName != other.IndicesTensorName) return false; + if (DenseShapeTensorName != other.DenseShapeTensorName) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (ValuesTensorName.Length != 0) hash ^= ValuesTensorName.GetHashCode(); + if (IndicesTensorName.Length != 0) hash ^= IndicesTensorName.GetHashCode(); + if (DenseShapeTensorName.Length != 0) hash ^= DenseShapeTensorName.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (ValuesTensorName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(ValuesTensorName); + } + if (IndicesTensorName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(IndicesTensorName); + } + if (DenseShapeTensorName.Length != 0) { + output.WriteRawTag(26); + output.WriteString(DenseShapeTensorName); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (ValuesTensorName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ValuesTensorName); + } + if (IndicesTensorName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(IndicesTensorName); + } + if (DenseShapeTensorName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(DenseShapeTensorName); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(CooSparse other) { + if (other == null) { + return; + } + if (other.ValuesTensorName.Length != 0) { + ValuesTensorName = other.ValuesTensorName; + } + if (other.IndicesTensorName.Length != 0) { + IndicesTensorName = other.IndicesTensorName; + } + if (other.DenseShapeTensorName.Length != 0) { + DenseShapeTensorName = other.DenseShapeTensorName; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + ValuesTensorName = input.ReadString(); + break; + } + case 18: { + IndicesTensorName = input.ReadString(); + break; + } + case 26: { + DenseShapeTensorName = input.ReadString(); + break; + } + } + } + } + + } + + } + #endregion + + } + + /// + /// SignatureDef defines the signature of a computation supported by a TensorFlow + /// graph. + /// + /// For example, a model with two loss computations, sharing a single input, + /// might have the following signature_def map. + /// + /// Note that across the two SignatureDefs "loss_A" and "loss_B", the input key, + /// output key, and method_name are identical, and will be used by system(s) that + /// implement or rely upon this particular loss method. The output tensor names + /// differ, demonstrating how different outputs can exist for the same method. + /// + /// signature_def { + /// key: "loss_A" + /// value { + /// inputs { + /// key: "input" + /// value { + /// name: "input:0" + /// dtype: DT_STRING + /// tensor_shape: ... + /// } + /// } + /// outputs { + /// key: "loss_output" + /// value { + /// name: "loss_output_A:0" + /// dtype: DT_FLOAT + /// tensor_shape: ... + /// } + /// } + /// } + /// ... + /// method_name: "some/package/compute_loss" + /// } + /// signature_def { + /// key: "loss_B" + /// value { + /// inputs { + /// key: "input" + /// value { + /// name: "input:0" + /// dtype: DT_STRING + /// tensor_shape: ... + /// } + /// } + /// outputs { + /// key: "loss_output" + /// value { + /// name: "loss_output_B:0" + /// dtype: DT_FLOAT + /// tensor_shape: ... + /// } + /// } + /// } + /// ... + /// method_name: "some/package/compute_loss" + /// } + /// + public sealed partial class SignatureDef : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SignatureDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.MetaGraphReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SignatureDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SignatureDef(SignatureDef other) : this() { + inputs_ = other.inputs_.Clone(); + outputs_ = other.outputs_.Clone(); + methodName_ = other.methodName_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SignatureDef Clone() { + return new SignatureDef(this); + } + + /// Field number for the "inputs" field. + public const int InputsFieldNumber = 1; + private static readonly pbc::MapField.Codec _map_inputs_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10), pb::FieldCodec.ForMessage(18, global::Tensorflow.TensorInfo.Parser), 10); + private readonly pbc::MapField inputs_ = new pbc::MapField(); + /// + /// Named input parameters. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::MapField Inputs { + get { return inputs_; } + } + + /// Field number for the "outputs" field. + public const int OutputsFieldNumber = 2; + private static readonly pbc::MapField.Codec _map_outputs_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10), pb::FieldCodec.ForMessage(18, global::Tensorflow.TensorInfo.Parser), 18); + private readonly pbc::MapField outputs_ = new pbc::MapField(); + /// + /// Named output parameters. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::MapField Outputs { + get { return outputs_; } + } + + /// Field number for the "method_name" field. + public const int MethodNameFieldNumber = 3; + private string methodName_ = ""; + /// + /// Extensible method_name information enabling third-party users to mark a + /// SignatureDef as supporting a particular method. This enables producers and + /// consumers of SignatureDefs, e.g. a model definition library and a serving + /// library to have a clear hand-off regarding the semantics of a computation. + /// + /// Note that multiple SignatureDefs in a single MetaGraphDef may have the same + /// method_name. This is commonly used to support multi-headed computation, + /// where a single graph computation may return multiple results. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string MethodName { + get { return methodName_; } + set { + methodName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as SignatureDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(SignatureDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!Inputs.Equals(other.Inputs)) return false; + if (!Outputs.Equals(other.Outputs)) return false; + if (MethodName != other.MethodName) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= Inputs.GetHashCode(); + hash ^= Outputs.GetHashCode(); + if (MethodName.Length != 0) hash ^= MethodName.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + inputs_.WriteTo(output, _map_inputs_codec); + outputs_.WriteTo(output, _map_outputs_codec); + if (MethodName.Length != 0) { + output.WriteRawTag(26); + output.WriteString(MethodName); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += inputs_.CalculateSize(_map_inputs_codec); + size += outputs_.CalculateSize(_map_outputs_codec); + if (MethodName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(MethodName); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(SignatureDef other) { + if (other == null) { + return; + } + inputs_.Add(other.inputs_); + outputs_.Add(other.outputs_); + if (other.MethodName.Length != 0) { + MethodName = other.MethodName; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + inputs_.AddEntriesFrom(input, _map_inputs_codec); + break; + } + case 18: { + outputs_.AddEntriesFrom(input, _map_outputs_codec); + break; + } + case 26: { + MethodName = input.ReadString(); + break; + } + } + } + } + + } + + /// + /// An asset file def for a single file or a set of sharded files with the same + /// name. + /// + public sealed partial class AssetFileDef : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new AssetFileDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.MetaGraphReflection.Descriptor.MessageTypes[4]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AssetFileDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AssetFileDef(AssetFileDef other) : this() { + tensorInfo_ = other.tensorInfo_ != null ? other.tensorInfo_.Clone() : null; + filename_ = other.filename_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AssetFileDef Clone() { + return new AssetFileDef(this); + } + + /// Field number for the "tensor_info" field. + public const int TensorInfoFieldNumber = 1; + private global::Tensorflow.TensorInfo tensorInfo_; + /// + /// The tensor to bind the asset filename to. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.TensorInfo TensorInfo { + get { return tensorInfo_; } + set { + tensorInfo_ = value; + } + } + + /// Field number for the "filename" field. + public const int FilenameFieldNumber = 2; + private string filename_ = ""; + /// + /// The filename within an assets directory. Note: does not include the path + /// prefix, i.e. directories. For an asset at /tmp/path/vocab.txt, the filename + /// would be "vocab.txt". + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Filename { + get { return filename_; } + set { + filename_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as AssetFileDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(AssetFileDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(TensorInfo, other.TensorInfo)) return false; + if (Filename != other.Filename) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (tensorInfo_ != null) hash ^= TensorInfo.GetHashCode(); + if (Filename.Length != 0) hash ^= Filename.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (tensorInfo_ != null) { + output.WriteRawTag(10); + output.WriteMessage(TensorInfo); + } + if (Filename.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Filename); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (tensorInfo_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(TensorInfo); + } + if (Filename.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Filename); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(AssetFileDef other) { + if (other == null) { + return; + } + if (other.tensorInfo_ != null) { + if (tensorInfo_ == null) { + tensorInfo_ = new global::Tensorflow.TensorInfo(); + } + TensorInfo.MergeFrom(other.TensorInfo); + } + if (other.Filename.Length != 0) { + Filename = other.Filename; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (tensorInfo_ == null) { + tensorInfo_ = new global::Tensorflow.TensorInfo(); + } + input.ReadMessage(tensorInfo_); + break; + } + case 18: { + Filename = input.ReadString(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/README.md b/src/TensorFlowNET.Core/Protobuf/README.md index c4ca96ca..64f6e813 100644 --- a/src/TensorFlowNET.Core/Protobuf/README.md +++ b/src/TensorFlowNET.Core/Protobuf/README.md @@ -1,20 +1,18 @@ ### Download compiler from https://github.com/protocolbuffers/protobuf/releases ```shell -set SRC_DIR=D:\Projects\tensorflow\tensorflow\core\framework +set SRC_DIR=D:\Projects\tensorflow set DST_DIR=D:\Projects\TensorFlow.NET\src\TensorFlowNET.Core\Protobuf -protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% resource_handle.proto -protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor_shape.proto -protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% types.proto -protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor.proto -protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% attr_value.proto -protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% node_def.proto -protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% versions.proto -protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% function.proto -protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% graph.proto -``` - -```shell -set SRC_DIR=D:\Projects\tensorflow\tensorflow\core\protobuf -protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% saver.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\resource_handle.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\tensor_shape.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\types.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\tensor.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\attr_value.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\node_def.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\versions.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\function.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\graph.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\protobuf\saver.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\protobuf\meta_graph.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\python\training\checkpoint_state.proto ``` diff --git a/src/TensorFlowNET.Core/Python.cs b/src/TensorFlowNET.Core/Python.cs index a47ac262..b077bfc3 100644 --- a/src/TensorFlowNET.Core/Python.cs +++ b/src/TensorFlowNET.Core/Python.cs @@ -81,6 +81,11 @@ namespace Tensorflow } } + public static float time() + { + return (float)(DateTime.UtcNow - new DateTime(1970, 1, 1)).TotalSeconds; + } + public static IEnumerable<(T, T)> zip(NDArray t1, NDArray t2) { int index = 0; diff --git a/src/TensorFlowNET.Core/Train/Saving/Saver.cs b/src/TensorFlowNET.Core/Train/Saving/Saver.cs index b1ed3322..f6b6b8af 100644 --- a/src/TensorFlowNET.Core/Train/Saving/Saver.cs +++ b/src/TensorFlowNET.Core/Train/Saving/Saver.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.Linq; using System.Text; namespace Tensorflow @@ -28,6 +29,8 @@ namespace Tensorflow private float _next_checkpoint_time; private bool _save_relative_paths; private bool? _object_restore_saver; + private Dictionary _last_checkpoints; + private Dictionary _checkpoints_to_be_deleted; public Saver(RefVariable[] var_list = null, bool reshape = false, @@ -68,6 +71,9 @@ namespace Tensorflow _save_relative_paths = save_relative_paths; _object_restore_saver = null; + + _last_checkpoints = new Dictionary(); + _checkpoints_to_be_deleted = new Dictionary(); } public void build() @@ -121,7 +127,7 @@ namespace Tensorflow _check_saver_def(); - _next_checkpoint_time = (float)(DateTime.UtcNow - new DateTime(1970, 1, 1)).TotalSeconds + _saver_def.KeepCheckpointEveryNHours * 3600; + _next_checkpoint_time = Python.time() + _saver_def.KeepCheckpointEveryNHours * 3600; } private void _check_saver_def() @@ -165,11 +171,119 @@ namespace Tensorflow model_checkpoint_path = sess.run(_saver_def.SaveTensorName, new FeedItem[] { new FeedItem(_saver_def.FilenameTensorName, checkpoint_file) }); + + if (write_state) + { + _RecordLastCheckpoint(model_checkpoint_path); + checkpoint_management.update_checkpoint_state_internal( + save_dir: save_path_parent, + model_checkpoint_path: model_checkpoint_path, + all_model_checkpoint_paths: _last_checkpoints.Keys.Select(x => x).ToList(), + latest_filename: latest_filename, + save_relative_paths: _save_relative_paths); + _MaybeDeleteOldCheckpoints(meta_graph_suffix: meta_graph_suffix); + } + } + + if (write_meta_graph) + { + string meta_graph_filename = checkpoint_management.meta_graph_filename(checkpoint_file, meta_graph_suffix: meta_graph_suffix); + } + + return _is_empty ? string.Empty : model_checkpoint_path; + } + + /// + /// Writes `MetaGraphDef` to save_path/filename. + /// + /// + /// + /// + /// + /// + /// + /// + public MetaGraphDef export_meta_graph(string filename= "", + string[] collection_list = null, + string export_scope = "", + bool as_text= false, + bool clear_devices= false, + bool clear_extraneous_savers= false, + bool strip_default_attrs= false) + { + return export_meta_graph( + filename: filename, + graph_def: ops.get_default_graph()._as_graph_def(add_shapes: true), + saver_def: _saver_def, + collection_list: collection_list, + as_text: as_text, + export_scope: export_scope, + clear_devices: clear_devices, + clear_extraneous_savers: clear_extraneous_savers, + strip_default_attrs: strip_default_attrs); + } + + public MetaGraphDef export_meta_graph(string filename = "", + byte[] meta_info_def = null, + GraphDef graph_def = null, + SaverDef saver_def = null, + string[] collection_list = null, + bool as_text = false, + bool clear_devices= false, + bool clear_extraneous_savers= false, + bool strip_default_attrs= false, + string export_scope = "") + { + var meta_graph_def = meta_graph.export_scoped_meta_graph( + filename: filename, + meta_info_def: meta_info_def, + graph_def: graph_def, + saver_def: saver_def, + // collection_list: collection_list, + as_text: as_text, + clear_devices: clear_devices, + clear_extraneous_savers: clear_extraneous_savers, + strip_default_attrs: strip_default_attrs); + return meta_graph_def; + } + + /// + /// Manages the list of the latest checkpoints. + /// + /// + private void _RecordLastCheckpoint(string latest_save_path) + { + if (_saver_def.MaxToKeep <= 0) return; + + // Remove first from list if the same name was used before. + foreach (var p in _last_checkpoints) + if (latest_save_path == _CheckpointFilename((p.Key, p.Value))) + _last_checkpoints.Remove(p.Key); + + // Append new path to list + _last_checkpoints.Add(latest_save_path, Python.time()); + + // If more than max_to_keep, remove oldest. + if(_last_checkpoints.Count > _saver_def.MaxToKeep) + { + var first = _last_checkpoints.First(); + _last_checkpoints.Remove(first.Key); + _checkpoints_to_be_deleted[first.Key] = first.Value; } + } - throw new NotImplementedException("Saver.save"); + private string _CheckpointFilename((string, float) p) + { + return p.Item1; + } + + /// + /// Deletes old checkpoints if necessary. + /// + /// + private void _MaybeDeleteOldCheckpoints(string meta_graph_suffix = "meta") + { - return model_checkpoint_path; } } } diff --git a/src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs b/src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs new file mode 100644 index 00000000..f301a027 --- /dev/null +++ b/src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs @@ -0,0 +1,109 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; + +namespace Tensorflow +{ + public class checkpoint_management + { + /// + /// Updates the content of the 'checkpoint' file. + /// + /// Directory where the model was saved. + /// The checkpoint file. + /// List of strings. + /// + /// + /// + /// + public static void update_checkpoint_state_internal(string save_dir, + string model_checkpoint_path, + List all_model_checkpoint_paths = null, + string latest_filename = "", + bool save_relative_paths = false, + List all_model_checkpoint_timestamps = null, + float? last_preserved_timestamp = null + ) + { + CheckpointState ckpt = null; + + // Writes the "checkpoint" file for the coordinator for later restoration. + string coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename); + if (save_relative_paths) + { + throw new NotImplementedException("update_checkpoint_state_internal save_relative_paths"); + } + else + { + ckpt = generate_checkpoint_state_proto(save_dir, + model_checkpoint_path, + all_model_checkpoint_paths, + all_model_checkpoint_timestamps, + last_preserved_timestamp); + } + + if (coord_checkpoint_filename == ckpt.ModelCheckpointPath) + throw new RuntimeError($"Save path '{model_checkpoint_path}' conflicts with path used for " + + "checkpoint state. Please use a different save path."); + + File.WriteAllText(coord_checkpoint_filename, ckpt.ToString()); + } + + /// + /// Returns a filename for storing the CheckpointState. + /// + /// The directory for saving and restoring checkpoints. + /// + /// Name of the file in 'save_dir' that is used + /// to store the CheckpointState. + /// + /// he path of the file that contains the CheckpointState proto. + private static string _GetCheckpointFilename(string save_dir, string latest_filename) + { + if (string.IsNullOrEmpty(latest_filename)) + latest_filename = "checkpoint"; + + return Path.Combine(save_dir, latest_filename); + } + + private static CheckpointState generate_checkpoint_state_proto(string save_dir, + string model_checkpoint_path, + List all_model_checkpoint_paths = null, + List all_model_checkpoint_timestamps = null, + float? last_preserved_timestamp = null) + { + if (all_model_checkpoint_paths == null) + all_model_checkpoint_paths = new List(); + + // Relative paths need to be rewritten to be relative to the "save_dir" + // if model_checkpoint_path already contains "save_dir". + all_model_checkpoint_paths.Add(model_checkpoint_path); + + var coord_checkpoint_proto = new CheckpointState() + { + ModelCheckpointPath = model_checkpoint_path, + LastPreservedTimestamp = last_preserved_timestamp.Value + }; + + coord_checkpoint_proto.AllModelCheckpointPaths.AddRange(all_model_checkpoint_paths); + coord_checkpoint_proto.AllModelCheckpointTimestamps.AddRange(all_model_checkpoint_timestamps.Select(x => (double)x)); + + return coord_checkpoint_proto; + } + + /// + /// Returns the meta graph filename. + /// + /// + /// + /// + public static string meta_graph_filename(string checkpoint_filename, string meta_graph_suffix= "meta") + { + string basename = checkpoint_filename; + string suffixed_filename = basename + "." + meta_graph_suffix; + return suffixed_filename; + } + } +} diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index 36c7b1f9..f0d0c8bd 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -20,7 +20,8 @@ namespace TensorFlowNET.UnitTest { var handle = c_api.TF_GetAllOpList(); var buffer = new Buffer(handle); - Assert.IsTrue(buffer.Length == buffer.Length); + var op_list = OpList.Parser.ParseFrom(buffer); + Assert.IsTrue(op_list.Op.Count > 1000); } [TestMethod]