| @@ -2,4 +2,11 @@ | |||
| ### Saver | |||
| The `tf.train.saver` class provides methods to save and restore models. | |||
| The `tf.train.saver` class provides methods to save and restore models. | |||
| ### Saver Builder | |||
| ##### Bulk Saver Builder | |||
| @@ -0,0 +1,21 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public interface IPyClass | |||
| { | |||
| /// <summary> | |||
| /// Called when the instance is created. | |||
| /// </summary> | |||
| /// <param name="args"></param> | |||
| void __init__(IPyClass self, dynamic args); | |||
| void __enter__(IPyClass self); | |||
| void __exit__(IPyClass self); | |||
| void __del__(IPyClass self); | |||
| } | |||
| } | |||
| @@ -116,6 +116,10 @@ namespace Tensorflow | |||
| values = new Tensor[] { keywords[input_name] as Tensor }; | |||
| } | |||
| inputs.AddRange(values as Tensor[]); | |||
| base_types.AddRange((values as Tensor[]).Select(x => x.dtype.as_base_dtype())); | |||
| input_types.AddRange(base_types); | |||
| if (!string.IsNullOrEmpty(input_arg.NumberAttr)) | |||
| { | |||
| if (attrs.ContainsKey(input_arg.NumberAttr)) | |||
| @@ -144,10 +148,32 @@ namespace Tensorflow | |||
| var type_attr = op_def.Attr.First(x => x.Name == input_arg.TypeAttr); | |||
| } | |||
| } | |||
| else if (!string.IsNullOrEmpty(input_arg.TypeAttr)) | |||
| { | |||
| var attr_value = base_types[0]; | |||
| if (attrs.ContainsKey(input_arg.TypeAttr)) | |||
| { | |||
| inputs.AddRange(values as Tensor[]); | |||
| base_types.AddRange((values as Tensor[]).Select(x => x.dtype.as_base_dtype())); | |||
| input_types.AddRange(base_types); | |||
| } | |||
| else | |||
| { | |||
| attrs[input_arg.TypeAttr] = attr_value; | |||
| inferred_from[input_arg.TypeAttr] = input_name; | |||
| } | |||
| } | |||
| else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) | |||
| { | |||
| var attr_value = base_types; | |||
| if (attrs.ContainsKey(input_arg.TypeListAttr)) | |||
| { | |||
| } | |||
| else | |||
| { | |||
| attrs[input_arg.TypeListAttr] = attr_value; | |||
| inferred_from[input_arg.TypeListAttr] = input_name; | |||
| } | |||
| } | |||
| } | |||
| // Process remaining attrs | |||
| @@ -213,6 +239,11 @@ namespace Tensorflow | |||
| case "type": | |||
| attr_value.Type = _MakeType((TF_DataType)value, attr_def); | |||
| break; | |||
| case "list(type)": | |||
| if (attr_value.List == null) | |||
| attr_value.List = new AttrValue.Types.ListValue(); | |||
| attr_value.List.Type.AddRange((value as IList<TF_DataType>).Select(x => _MakeType(x, attr_def))); | |||
| break; | |||
| case "bool": | |||
| attr_value.B = (bool)value; | |||
| break; | |||
| @@ -225,9 +256,14 @@ namespace Tensorflow | |||
| throw new ValueError($"Attr '{attr_def.Name}' of '{op_def.Name}' Op passed {attr_value.I} less than minimum {attr_def.Minimum}."); | |||
| break; | |||
| case "shape": | |||
| attr_value.Shape = value == null ? | |||
| attr_def.DefaultValue.Shape : | |||
| tensor_util.as_shape((long[])value); | |||
| if (value == null && attr_def.DefaultValue != null) | |||
| attr_value.Shape = attr_def.DefaultValue.Shape; | |||
| if(value is TensorShape val1) | |||
| attr_value.Shape = val1.as_proto(); | |||
| else if(value is long[] val2) | |||
| attr_value.Shape = tensor_util.as_shape(val2); | |||
| break; | |||
| default: | |||
| throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos."); | |||
| @@ -106,6 +106,19 @@ namespace Tensorflow | |||
| throw new NotImplementedException("where"); | |||
| } | |||
| /// <summary> | |||
| /// A placeholder op that passes through `input` when its output is not fed. | |||
| /// </summary> | |||
| /// <param name="input">The default value to produce when output is not fed.</param> | |||
| /// <param name="shape"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public static Tensor placeholder_with_default<T>(T input, TensorShape shape, string name = "") | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("PlaceholderWithDefault", name, new { input, shape, name }); | |||
| return _op.outputs[0]; | |||
| } | |||
| public static Tensor select(Tensor condition, Tensor t, Tensor e, string name = "") | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("Select", name, new { condition, t, e }); | |||
| @@ -0,0 +1,18 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class gen_io_ops | |||
| { | |||
| public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | |||
| public static Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = "") | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors }); | |||
| return _op; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,401 @@ | |||
| // <auto-generated> | |||
| // Generated by the protocol buffer compiler. DO NOT EDIT! | |||
| // source: saver.proto | |||
| // </auto-generated> | |||
| #pragma warning disable 1591, 0612, 3021 | |||
| #region Designer generated code | |||
| using pb = global::Google.Protobuf; | |||
| using pbc = global::Google.Protobuf.Collections; | |||
| using pbr = global::Google.Protobuf.Reflection; | |||
| using scg = global::System.Collections.Generic; | |||
| namespace Tensorflow { | |||
| /// <summary>Holder for reflection information generated from saver.proto</summary> | |||
| public static partial class SaverReflection { | |||
| #region Descriptor | |||
| /// <summary>File descriptor for saver.proto</summary> | |||
| public static pbr::FileDescriptor Descriptor { | |||
| get { return descriptor; } | |||
| } | |||
| private static pbr::FileDescriptor descriptor; | |||
| static SaverReflection() { | |||
| byte[] descriptorData = global::System.Convert.FromBase64String( | |||
| string.Concat( | |||
| "CgtzYXZlci5wcm90bxIKdGVuc29yZmxvdyKeAgoIU2F2ZXJEZWYSHAoUZmls", | |||
| "ZW5hbWVfdGVuc29yX25hbWUYASABKAkSGAoQc2F2ZV90ZW5zb3JfbmFtZRgC", | |||
| "IAEoCRIXCg9yZXN0b3JlX29wX25hbWUYAyABKAkSEwoLbWF4X3RvX2tlZXAY", | |||
| "BCABKAUSDwoHc2hhcmRlZBgFIAEoCBIlCh1rZWVwX2NoZWNrcG9pbnRfZXZl", | |||
| "cnlfbl9ob3VycxgGIAEoAhI9Cgd2ZXJzaW9uGAcgASgOMiwudGVuc29yZmxv", | |||
| "dy5TYXZlckRlZi5DaGVja3BvaW50Rm9ybWF0VmVyc2lvbiI1ChdDaGVja3Bv", | |||
| "aW50Rm9ybWF0VmVyc2lvbhIKCgZMRUdBQ1kQABIGCgJWMRABEgYKAlYyEAJC", | |||
| "ZQoTb3JnLnRlbnNvcmZsb3cudXRpbEILU2F2ZXJQcm90b3NQAVo8Z2l0aHVi", | |||
| "LmNvbS90ZW5zb3JmbG93L3RlbnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3Jl", | |||
| "L3Byb3RvYnVm+AEBYgZwcm90bzM=")); | |||
| descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, | |||
| new pbr::FileDescriptor[] { }, | |||
| new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SaverDef), global::Tensorflow.SaverDef.Parser, new[]{ "FilenameTensorName", "SaveTensorName", "RestoreOpName", "MaxToKeep", "Sharded", "KeepCheckpointEveryNHours", "Version" }, null, new[]{ typeof(global::Tensorflow.SaverDef.Types.CheckpointFormatVersion) }, null) | |||
| })); | |||
| } | |||
| #endregion | |||
| } | |||
| #region Messages | |||
| /// <summary> | |||
| /// Protocol buffer representing the configuration of a Saver. | |||
| /// </summary> | |||
| public sealed partial class SaverDef : pb::IMessage<SaverDef> { | |||
| private static readonly pb::MessageParser<SaverDef> _parser = new pb::MessageParser<SaverDef>(() => new SaverDef()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public static pb::MessageParser<SaverDef> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.SaverReflection.Descriptor.MessageTypes[0]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public SaverDef() { | |||
| OnConstruction(); | |||
| } | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public SaverDef(SaverDef other) : this() { | |||
| filenameTensorName_ = other.filenameTensorName_; | |||
| saveTensorName_ = other.saveTensorName_; | |||
| restoreOpName_ = other.restoreOpName_; | |||
| maxToKeep_ = other.maxToKeep_; | |||
| sharded_ = other.sharded_; | |||
| keepCheckpointEveryNHours_ = other.keepCheckpointEveryNHours_; | |||
| version_ = other.version_; | |||
| _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public SaverDef Clone() { | |||
| return new SaverDef(this); | |||
| } | |||
| /// <summary>Field number for the "filename_tensor_name" field.</summary> | |||
| public const int FilenameTensorNameFieldNumber = 1; | |||
| private string filenameTensorName_ = ""; | |||
| /// <summary> | |||
| /// The name of the tensor in which to specify the filename when saving or | |||
| /// restoring a model checkpoint. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public string FilenameTensorName { | |||
| get { return filenameTensorName_; } | |||
| set { | |||
| filenameTensorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); | |||
| } | |||
| } | |||
| /// <summary>Field number for the "save_tensor_name" field.</summary> | |||
| public const int SaveTensorNameFieldNumber = 2; | |||
| private string saveTensorName_ = ""; | |||
| /// <summary> | |||
| /// The operation to run when saving a model checkpoint. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public string SaveTensorName { | |||
| get { return saveTensorName_; } | |||
| set { | |||
| saveTensorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); | |||
| } | |||
| } | |||
| /// <summary>Field number for the "restore_op_name" field.</summary> | |||
| public const int RestoreOpNameFieldNumber = 3; | |||
| private string restoreOpName_ = ""; | |||
| /// <summary> | |||
| /// The operation to run when restoring a model checkpoint. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public string RestoreOpName { | |||
| get { return restoreOpName_; } | |||
| set { | |||
| restoreOpName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); | |||
| } | |||
| } | |||
| /// <summary>Field number for the "max_to_keep" field.</summary> | |||
| public const int MaxToKeepFieldNumber = 4; | |||
| private int maxToKeep_; | |||
| /// <summary> | |||
| /// Maximum number of checkpoints to keep. If 0, no checkpoints are deleted. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public int MaxToKeep { | |||
| get { return maxToKeep_; } | |||
| set { | |||
| maxToKeep_ = value; | |||
| } | |||
| } | |||
| /// <summary>Field number for the "sharded" field.</summary> | |||
| public const int ShardedFieldNumber = 5; | |||
| private bool sharded_; | |||
| /// <summary> | |||
| /// Shard the save files, one per device that has Variable nodes. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public bool Sharded { | |||
| get { return sharded_; } | |||
| set { | |||
| sharded_ = value; | |||
| } | |||
| } | |||
| /// <summary>Field number for the "keep_checkpoint_every_n_hours" field.</summary> | |||
| public const int KeepCheckpointEveryNHoursFieldNumber = 6; | |||
| private float keepCheckpointEveryNHours_; | |||
| /// <summary> | |||
| /// How often to keep an additional checkpoint. If not specified, only the last | |||
| /// "max_to_keep" checkpoints are kept; if specified, in addition to keeping | |||
| /// the last "max_to_keep" checkpoints, an additional checkpoint will be kept | |||
| /// for every n hours of training. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public float KeepCheckpointEveryNHours { | |||
| get { return keepCheckpointEveryNHours_; } | |||
| set { | |||
| keepCheckpointEveryNHours_ = value; | |||
| } | |||
| } | |||
| /// <summary>Field number for the "version" field.</summary> | |||
| public const int VersionFieldNumber = 7; | |||
| private global::Tensorflow.SaverDef.Types.CheckpointFormatVersion version_ = 0; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public global::Tensorflow.SaverDef.Types.CheckpointFormatVersion Version { | |||
| get { return version_; } | |||
| set { | |||
| version_ = value; | |||
| } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as SaverDef); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public bool Equals(SaverDef other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| } | |||
| if (ReferenceEquals(other, this)) { | |||
| return true; | |||
| } | |||
| if (FilenameTensorName != other.FilenameTensorName) return false; | |||
| if (SaveTensorName != other.SaveTensorName) return false; | |||
| if (RestoreOpName != other.RestoreOpName) return false; | |||
| if (MaxToKeep != other.MaxToKeep) return false; | |||
| if (Sharded != other.Sharded) return false; | |||
| if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(KeepCheckpointEveryNHours, other.KeepCheckpointEveryNHours)) return false; | |||
| if (Version != other.Version) return false; | |||
| return Equals(_unknownFields, other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| if (FilenameTensorName.Length != 0) hash ^= FilenameTensorName.GetHashCode(); | |||
| if (SaveTensorName.Length != 0) hash ^= SaveTensorName.GetHashCode(); | |||
| if (RestoreOpName.Length != 0) hash ^= RestoreOpName.GetHashCode(); | |||
| if (MaxToKeep != 0) hash ^= MaxToKeep.GetHashCode(); | |||
| if (Sharded != false) hash ^= Sharded.GetHashCode(); | |||
| if (KeepCheckpointEveryNHours != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(KeepCheckpointEveryNHours); | |||
| if (Version != 0) hash ^= Version.GetHashCode(); | |||
| if (_unknownFields != null) { | |||
| hash ^= _unknownFields.GetHashCode(); | |||
| } | |||
| return hash; | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| if (FilenameTensorName.Length != 0) { | |||
| output.WriteRawTag(10); | |||
| output.WriteString(FilenameTensorName); | |||
| } | |||
| if (SaveTensorName.Length != 0) { | |||
| output.WriteRawTag(18); | |||
| output.WriteString(SaveTensorName); | |||
| } | |||
| if (RestoreOpName.Length != 0) { | |||
| output.WriteRawTag(26); | |||
| output.WriteString(RestoreOpName); | |||
| } | |||
| if (MaxToKeep != 0) { | |||
| output.WriteRawTag(32); | |||
| output.WriteInt32(MaxToKeep); | |||
| } | |||
| if (Sharded != false) { | |||
| output.WriteRawTag(40); | |||
| output.WriteBool(Sharded); | |||
| } | |||
| if (KeepCheckpointEveryNHours != 0F) { | |||
| output.WriteRawTag(53); | |||
| output.WriteFloat(KeepCheckpointEveryNHours); | |||
| } | |||
| if (Version != 0) { | |||
| output.WriteRawTag(56); | |||
| output.WriteEnum((int) Version); | |||
| } | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| if (FilenameTensorName.Length != 0) { | |||
| size += 1 + pb::CodedOutputStream.ComputeStringSize(FilenameTensorName); | |||
| } | |||
| if (SaveTensorName.Length != 0) { | |||
| size += 1 + pb::CodedOutputStream.ComputeStringSize(SaveTensorName); | |||
| } | |||
| if (RestoreOpName.Length != 0) { | |||
| size += 1 + pb::CodedOutputStream.ComputeStringSize(RestoreOpName); | |||
| } | |||
| if (MaxToKeep != 0) { | |||
| size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaxToKeep); | |||
| } | |||
| if (Sharded != false) { | |||
| size += 1 + 1; | |||
| } | |||
| if (KeepCheckpointEveryNHours != 0F) { | |||
| size += 1 + 4; | |||
| } | |||
| if (Version != 0) { | |||
| size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Version); | |||
| } | |||
| if (_unknownFields != null) { | |||
| size += _unknownFields.CalculateSize(); | |||
| } | |||
| return size; | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public void MergeFrom(SaverDef other) { | |||
| if (other == null) { | |||
| return; | |||
| } | |||
| if (other.FilenameTensorName.Length != 0) { | |||
| FilenameTensorName = other.FilenameTensorName; | |||
| } | |||
| if (other.SaveTensorName.Length != 0) { | |||
| SaveTensorName = other.SaveTensorName; | |||
| } | |||
| if (other.RestoreOpName.Length != 0) { | |||
| RestoreOpName = other.RestoreOpName; | |||
| } | |||
| if (other.MaxToKeep != 0) { | |||
| MaxToKeep = other.MaxToKeep; | |||
| } | |||
| if (other.Sharded != false) { | |||
| Sharded = other.Sharded; | |||
| } | |||
| if (other.KeepCheckpointEveryNHours != 0F) { | |||
| KeepCheckpointEveryNHours = other.KeepCheckpointEveryNHours; | |||
| } | |||
| if (other.Version != 0) { | |||
| Version = other.Version; | |||
| } | |||
| _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||
| break; | |||
| case 10: { | |||
| FilenameTensorName = input.ReadString(); | |||
| break; | |||
| } | |||
| case 18: { | |||
| SaveTensorName = input.ReadString(); | |||
| break; | |||
| } | |||
| case 26: { | |||
| RestoreOpName = input.ReadString(); | |||
| break; | |||
| } | |||
| case 32: { | |||
| MaxToKeep = input.ReadInt32(); | |||
| break; | |||
| } | |||
| case 40: { | |||
| Sharded = input.ReadBool(); | |||
| break; | |||
| } | |||
| case 53: { | |||
| KeepCheckpointEveryNHours = input.ReadFloat(); | |||
| break; | |||
| } | |||
| case 56: { | |||
| version_ = (global::Tensorflow.SaverDef.Types.CheckpointFormatVersion) input.ReadEnum(); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #region Nested types | |||
| /// <summary>Container for nested types declared in the SaverDef message type.</summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public static partial class Types { | |||
| /// <summary> | |||
| /// A version number that identifies a different on-disk checkpoint format. | |||
| /// Usually, each subclass of BaseSaverBuilder works with a particular | |||
| /// version/format. However, it is possible that the same builder may be | |||
| /// upgraded to support a newer checkpoint format in the future. | |||
| /// </summary> | |||
| public enum CheckpointFormatVersion { | |||
| /// <summary> | |||
| /// Internal legacy format. | |||
| /// </summary> | |||
| [pbr::OriginalName("LEGACY")] Legacy = 0, | |||
| /// <summary> | |||
| /// Deprecated format: tf.Saver() which works with tensorflow::table::Table. | |||
| /// </summary> | |||
| [pbr::OriginalName("V1")] V1 = 1, | |||
| /// <summary> | |||
| /// Current format: more efficient. | |||
| /// </summary> | |||
| [pbr::OriginalName("V2")] V2 = 2, | |||
| } | |||
| } | |||
| #endregion | |||
| } | |||
| #endregion | |||
| } | |||
| #endregion Designer generated code | |||
| @@ -15,6 +15,15 @@ namespace Tensorflow | |||
| Console.WriteLine(obj.ToString()); | |||
| } | |||
| public static T New<T>(object args) where T : IPyClass | |||
| { | |||
| var instance = Activator.CreateInstance<T>(); | |||
| instance.__init__(instance, args); | |||
| return instance; | |||
| } | |||
| public static void with(IPython py, Action<IPython> action) | |||
| { | |||
| try | |||
| @@ -63,7 +72,7 @@ namespace Tensorflow | |||
| catch (Exception ex) | |||
| { | |||
| Console.WriteLine(ex.ToString()); | |||
| throw ex; | |||
| return default(TOut); | |||
| } | |||
| finally | |||
| { | |||
| @@ -97,4 +106,9 @@ namespace Tensorflow | |||
| void __exit__(); | |||
| } | |||
| public class PyObject<T> where T : IPyClass | |||
| { | |||
| public T Instance { get; set; } | |||
| } | |||
| } | |||
| @@ -0,0 +1,98 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class BaseSaverBuilder | |||
| { | |||
| protected int _write_version; | |||
| public BaseSaverBuilder(int write_version = 2) | |||
| { | |||
| _write_version = write_version; | |||
| } | |||
| public virtual Operation save_op(Tensor filename_tensor, SaveableObject[] saveables) | |||
| { | |||
| var tensor_names = new List<string>(); | |||
| var tensors = new List<Tensor>(); | |||
| var tensor_slices = new List<string>(); | |||
| foreach (var saveable in saveables) | |||
| { | |||
| foreach(var spec in saveable.specs) | |||
| { | |||
| tensor_names.Add(spec.name); | |||
| tensors.Add(spec.tensor); | |||
| tensor_slices.Add(spec.slice_spec); | |||
| } | |||
| } | |||
| if (_write_version == 2) | |||
| { | |||
| return gen_io_ops.save_v2(filename_tensor, tensor_names.ToArray(), tensor_slices.ToArray(), tensors.ToArray()); | |||
| } | |||
| else | |||
| { | |||
| throw new NotImplementedException("_write_version v1"); | |||
| } | |||
| } | |||
| public virtual Tensor[] bulk_restore(Tensor filename_tensor, SaveableObject[] saveables, int preferred_shard, bool restore_sequentially) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| public virtual SaverDef _build_internal(RefVariable[] names_to_saveables, | |||
| bool reshape = false, | |||
| bool sharded = false, | |||
| int max_to_keep = 5, | |||
| double keep_checkpoint_every_n_hours = 10000, | |||
| string name = "", | |||
| bool restore_sequentially = false, | |||
| string filename = "model", | |||
| bool build_save = true, | |||
| bool build_restore = true) | |||
| { | |||
| if (!build_save || !build_restore) | |||
| throw new ValueError("save and restore operations need to be built together " + | |||
| " when eager execution is not enabled."); | |||
| var saveables = saveable_object_util.validate_and_slice_inputs(names_to_saveables); | |||
| if (max_to_keep < 0) | |||
| max_to_keep = 0; | |||
| Python.with<ops.name_scope>(new ops.name_scope(name, "save", saveables.Select(x => x.op).ToArray()), scope => | |||
| { | |||
| name = scope; | |||
| // Add a placeholder string tensor for the filename. | |||
| var filename_tensor = gen_array_ops.placeholder_with_default( string.IsNullOrEmpty(filename) ? "model" : filename, shape: new TensorShape(), name: "filename"); | |||
| filename_tensor = gen_array_ops.placeholder_with_default(filename_tensor, shape: new TensorShape(), name: "Const"); | |||
| // Keep the name "Const" for backwards compatibility. | |||
| // Add the save ops. | |||
| if (sharded) | |||
| { | |||
| } | |||
| else | |||
| { | |||
| if (build_save) | |||
| _AddSaveOps(filename_tensor, saveables); | |||
| } | |||
| }); | |||
| throw new NotImplementedException(""); | |||
| } | |||
| public Tensor _AddSaveOps(Tensor filename_tensor, SaveableObject[] saveables) | |||
| { | |||
| var save = save_op(filename_tensor, saveables); | |||
| return control_flow_ops.with_dependencies(new Operation[] { save }, filename_tensor); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,14 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class BulkSaverBuilder : BaseSaverBuilder, ISaverBuilder | |||
| { | |||
| public BulkSaverBuilder(int write_version = 2) : base(write_version) | |||
| { | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,24 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public interface ISaverBuilder | |||
| { | |||
| Operation save_op(Tensor filename_tensor, SaveableObject[] saveables); | |||
| Tensor[] bulk_restore(Tensor filename_tensor, SaveableObject[] saveables, int preferred_shard, bool restore_sequentially); | |||
| SaverDef _build_internal(RefVariable[] names_to_saveables, | |||
| bool reshape = false, | |||
| bool sharded = false, | |||
| int max_to_keep = 5, | |||
| double keep_checkpoint_every_n_hours = 10000, | |||
| string name = "", | |||
| bool restore_sequentially = false, | |||
| string filename = "model", | |||
| bool build_save = true, | |||
| bool build_restore = true); | |||
| } | |||
| } | |||
| @@ -0,0 +1,19 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class ReferenceVariableSaveable : SaveableObject | |||
| { | |||
| private SaveSpec _spec; | |||
| public ReferenceVariableSaveable(Tensor var, string slice_spec, string name) | |||
| { | |||
| _spec = new SaveSpec(var, slice_spec, name, dtype: var.dtype); | |||
| op = var; | |||
| specs = new SaveSpec[] { _spec }; | |||
| this.name = name; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,32 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| /// Class used to describe tensor slices that need to be saved. | |||
| /// </summary> | |||
| public class SaveSpec | |||
| { | |||
| private Tensor _tensor; | |||
| public Tensor tensor => _tensor; | |||
| private string _slice_spec; | |||
| public string slice_spec => _slice_spec; | |||
| private string _name; | |||
| public string name => _name; | |||
| private TF_DataType _dtype; | |||
| public TF_DataType dtype => _dtype; | |||
| public SaveSpec(Tensor tensor, string slice_spec, string name, TF_DataType dtype = TF_DataType.DtInvalid) | |||
| { | |||
| _tensor = tensor; | |||
| _slice_spec = slice_spec; | |||
| _name = name; | |||
| _dtype = dtype; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,31 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class SaveableObject | |||
| { | |||
| public Tensor op; | |||
| public SaveSpec[] specs; | |||
| public string name; | |||
| public string device; | |||
| public SaveableObject() | |||
| { | |||
| } | |||
| public SaveableObject(Tensor var, string slice_spec, string name) | |||
| { | |||
| } | |||
| public SaveableObject(Tensor op, SaveSpec[] specs, string name) | |||
| { | |||
| this.op = op; | |||
| this.specs = specs; | |||
| this.name = name; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,113 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| /// Saves and restores variables. | |||
| /// </summary> | |||
| public class Saver | |||
| { | |||
| private RefVariable[] _var_list; | |||
| private bool _reshape; | |||
| private bool _sharded; | |||
| private int _max_to_keep; | |||
| private double _keep_checkpoint_every_n_hours; | |||
| private string _name; | |||
| private bool _restore_sequentially; | |||
| private SaverDef _saver_def; | |||
| private ISaverBuilder _builder; | |||
| private bool _allow_empty; | |||
| private bool _is_built; | |||
| private int _write_version; | |||
| private bool _pad_step_number; | |||
| private string _filename; | |||
| private bool _is_empty; | |||
| public Saver(RefVariable[] var_list = null, | |||
| bool reshape = false, | |||
| bool sharded = false, | |||
| int max_to_keep = 5, | |||
| double keep_checkpoint_every_n_hours = 10000, | |||
| string name = "", | |||
| bool restore_sequentially = false, | |||
| SaverDef saver_def = null, | |||
| ISaverBuilder builder = null, | |||
| bool defer_build = false, | |||
| bool allow_empty = false, | |||
| int write_version = 2, | |||
| bool pad_step_number = false, | |||
| bool save_relative_paths = false, | |||
| string filename = "") | |||
| { | |||
| _var_list = var_list; | |||
| _reshape = reshape; | |||
| _sharded = sharded; | |||
| _max_to_keep = max_to_keep; | |||
| _keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours; | |||
| _name = name; | |||
| _restore_sequentially = restore_sequentially; | |||
| _builder = builder; | |||
| _is_built = false; | |||
| _allow_empty = allow_empty; | |||
| _write_version = write_version; | |||
| _pad_step_number = pad_step_number; | |||
| if (!defer_build) | |||
| build(); | |||
| } | |||
| public void build() | |||
| { | |||
| _build(_filename, build_save: true, build_restore: true); | |||
| } | |||
| private void _build(string checkpoint_path, bool build_save, bool build_restore) | |||
| { | |||
| if (_is_built) return; | |||
| _is_built = true; | |||
| if (_saver_def == null) | |||
| { | |||
| if (_builder == null) | |||
| _builder = new BulkSaverBuilder(_write_version); | |||
| if (_var_list == null) | |||
| _var_list = variables._all_saveable_objects(); | |||
| if (_var_list == null || _var_list.Length == 0) | |||
| { | |||
| if (_allow_empty) | |||
| { | |||
| _is_empty = true; | |||
| return; | |||
| } | |||
| else | |||
| { | |||
| throw new ValueError("No variables to save"); | |||
| } | |||
| } | |||
| _is_empty = false; | |||
| _saver_def = _builder._build_internal(_var_list, | |||
| reshape: _reshape, | |||
| sharded: _sharded, | |||
| max_to_keep: _max_to_keep, | |||
| keep_checkpoint_every_n_hours: _keep_checkpoint_every_n_hours, | |||
| name: _name, | |||
| restore_sequentially: _restore_sequentially, | |||
| filename: checkpoint_path, | |||
| build_save: build_save, | |||
| build_restore: build_restore); | |||
| } | |||
| else if (_saver_def != null && !string.IsNullOrEmpty(_name)) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,102 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class saveable_object_util | |||
| { | |||
| /// <summary> | |||
| /// Returns the variables and names that will be used for a Saver. | |||
| /// </summary> | |||
| /// <param name="names_to_saveables"></param> | |||
| /// <returns></returns> | |||
| public static SaveableObject[] validate_and_slice_inputs(RefVariable[] names_to_saveables) | |||
| { | |||
| var names_to_saveables_dict = op_list_to_dict(names_to_saveables); | |||
| var saveables = new List<SaveableObject>(); | |||
| var seen_ops = new List<Tensor>(); | |||
| foreach (var item in names_to_saveables_dict) | |||
| { | |||
| foreach (var converted_saveable_object in saveable_objects_for_op(item.Value, item.Key)) | |||
| _add_saveable(saveables, seen_ops, converted_saveable_object); | |||
| } | |||
| return saveables.ToArray(); | |||
| } | |||
| private static void _add_saveable<T>(List<T> saveables, List<Tensor> seen_ops, T saveable) where T : SaveableObject | |||
| { | |||
| if (seen_ops.Contains(saveable.op)) | |||
| throw new ValueError($"The same saveable will be restored with two names: {saveable.name}"); | |||
| saveables.Add(saveable); | |||
| seen_ops.Add(saveable.op); | |||
| } | |||
| /// <summary> | |||
| /// Create `SaveableObject`s from an operation. | |||
| /// </summary> | |||
| /// <param name="op"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public static IEnumerable<SaveableObject> saveable_objects_for_op(Tensor op, string name) | |||
| { | |||
| if (false) | |||
| { | |||
| } | |||
| else | |||
| { | |||
| ops.init_scope(); | |||
| var variable = ops.internal_convert_to_tensor(op, as_ref: true); | |||
| if (variable.op.type == "VariableV2") | |||
| yield return new ReferenceVariableSaveable(variable, "", name); | |||
| } | |||
| } | |||
| public static Dictionary<string, Tensor> op_list_to_dict(RefVariable[] op_list, bool convert_variable_to_tensor = true) | |||
| { | |||
| op_list = op_list.OrderBy(x => x.name).ToArray(); | |||
| var names_to_saveables = new Dictionary<string, Tensor>(); | |||
| foreach(var var in op_list) | |||
| { | |||
| if (false) | |||
| { | |||
| throw new NotImplementedException("op_list_to_dict"); | |||
| } | |||
| else | |||
| { | |||
| if(false) // eager | |||
| { | |||
| } | |||
| else | |||
| { | |||
| string name = ""; | |||
| Tensor tensor = null; | |||
| if (convert_variable_to_tensor) | |||
| { | |||
| tensor = ops.internal_convert_to_tensor(var, as_ref: true); | |||
| } | |||
| if (var.op.type == "ReadVariableOp") | |||
| name = var.op.inputs[0].op.Name; | |||
| else | |||
| name = var.op.Name; | |||
| if (names_to_saveables.ContainsKey(name)) | |||
| throw new ValueError($"At least two variables have the same name: {name}"); | |||
| names_to_saveables[name] = tensor; | |||
| } | |||
| } | |||
| } | |||
| return names_to_saveables; | |||
| } | |||
| } | |||
| } | |||
| @@ -8,10 +8,9 @@ namespace Tensorflow | |||
| { | |||
| public static class train | |||
| { | |||
| public static Optimizer GradientDescentOptimizer(double learning_rate) | |||
| { | |||
| return new GradientDescentOptimizer(learning_rate); | |||
| } | |||
| public static Optimizer GradientDescentOptimizer(double learning_rate) => new GradientDescentOptimizer(learning_rate); | |||
| public static Saver Saver() => new Saver(); | |||
| } | |||
| } | |||
| } | |||
| @@ -16,6 +16,26 @@ namespace Tensorflow | |||
| return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES); | |||
| } | |||
| /// <summary> | |||
| /// Returns all variables and `SaveableObject`s that must be checkpointed. | |||
| /// </summary> | |||
| /// <param name="scope"></param> | |||
| /// <returns></returns> | |||
| public static RefVariable[] _all_saveable_objects(string scope = "") | |||
| { | |||
| var all = new List<RefVariable>(); | |||
| var collection = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope); | |||
| if(collection != null) | |||
| all.AddRange(collection as List<RefVariable>); | |||
| collection = ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope); | |||
| if (collection != null) | |||
| all.AddRange(collection as List<RefVariable>); | |||
| return all.ToArray(); | |||
| } | |||
| /// <summary> | |||
| /// Returns global variables. | |||
| /// </summary> | |||
| @@ -27,6 +27,11 @@ namespace Tensorflow | |||
| /// Default collection for all variables, except local ones. | |||
| /// </summary> | |||
| public static string GLOBAL_VARIABLES = "variables"; | |||
| /// <summary> | |||
| /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. | |||
| /// </summary> | |||
| public static string SAVEABLE_OBJECTS = "saveable_objects"; | |||
| } | |||
| } | |||
| } | |||
| @@ -387,6 +387,10 @@ namespace Tensorflow | |||
| { | |||
| case "Tensor": | |||
| return value as Tensor; | |||
| case "String": | |||
| return constant_op.constant(Convert.ToString(value), name); | |||
| case "String[]": | |||
| return constant_op.constant(value as string[], name); | |||
| case "Int32": | |||
| return constant_op.constant(Convert.ToInt32(value), name); | |||
| case "Double": | |||
| @@ -7,7 +7,7 @@ using Tensorflow; | |||
| namespace TensorFlowNET.UnitTest | |||
| { | |||
| [TestClass] | |||
| public class TrainSaverTest | |||
| public class TrainSaverTest : Python | |||
| { | |||
| [TestMethod] | |||
| public void Save() | |||
| @@ -20,6 +20,14 @@ namespace TensorFlowNET.UnitTest | |||
| // Add an op to initialize the variables. | |||
| var init_op = tf.global_variables_initializer(); | |||
| // Add ops to save and restore all the variables. | |||
| var saver = tf.train.Saver(); | |||
| with<Session>(tf.Session(), sess => | |||
| { | |||
| sess.run(init_op); | |||
| }); | |||
| } | |||
| } | |||
| } | |||