diff --git a/docs/source/Train.md b/docs/source/Train.md
index c71b31c7..85d441ba 100644
--- a/docs/source/Train.md
+++ b/docs/source/Train.md
@@ -2,4 +2,11 @@
### Saver
-The `tf.train.saver` class provides methods to save and restore models.
\ No newline at end of file
+The `tf.train.saver` class provides methods to save and restore models.
+
+
+
+### Saver Builder
+
+##### Bulk Saver Builder
+
diff --git a/src/TensorFlowNET.Core/IPyClass.cs b/src/TensorFlowNET.Core/IPyClass.cs
new file mode 100644
index 00000000..fd08ab82
--- /dev/null
+++ b/src/TensorFlowNET.Core/IPyClass.cs
@@ -0,0 +1,21 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public interface IPyClass
+ {
+ ///
+ /// Called when the instance is created.
+ ///
+ ///
+ void __init__(IPyClass self, dynamic args);
+
+ void __enter__(IPyClass self);
+
+ void __exit__(IPyClass self);
+
+ void __del__(IPyClass self);
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
index 24b39239..406dac9d 100644
--- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
+++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
@@ -116,6 +116,10 @@ namespace Tensorflow
values = new Tensor[] { keywords[input_name] as Tensor };
}
+ inputs.AddRange(values as Tensor[]);
+ base_types.AddRange((values as Tensor[]).Select(x => x.dtype.as_base_dtype()));
+ input_types.AddRange(base_types);
+
if (!string.IsNullOrEmpty(input_arg.NumberAttr))
{
if (attrs.ContainsKey(input_arg.NumberAttr))
@@ -144,10 +148,32 @@ namespace Tensorflow
var type_attr = op_def.Attr.First(x => x.Name == input_arg.TypeAttr);
}
}
+ else if (!string.IsNullOrEmpty(input_arg.TypeAttr))
+ {
+ var attr_value = base_types[0];
+ if (attrs.ContainsKey(input_arg.TypeAttr))
+ {
- inputs.AddRange(values as Tensor[]);
- base_types.AddRange((values as Tensor[]).Select(x => x.dtype.as_base_dtype()));
- input_types.AddRange(base_types);
+ }
+ else
+ {
+ attrs[input_arg.TypeAttr] = attr_value;
+ inferred_from[input_arg.TypeAttr] = input_name;
+ }
+ }
+ else if (!string.IsNullOrEmpty(input_arg.TypeListAttr))
+ {
+ var attr_value = base_types;
+ if (attrs.ContainsKey(input_arg.TypeListAttr))
+ {
+
+ }
+ else
+ {
+ attrs[input_arg.TypeListAttr] = attr_value;
+ inferred_from[input_arg.TypeListAttr] = input_name;
+ }
+ }
}
// Process remaining attrs
@@ -213,6 +239,11 @@ namespace Tensorflow
case "type":
attr_value.Type = _MakeType((TF_DataType)value, attr_def);
break;
+ case "list(type)":
+ if (attr_value.List == null)
+ attr_value.List = new AttrValue.Types.ListValue();
+ attr_value.List.Type.AddRange((value as IList).Select(x => _MakeType(x, attr_def)));
+ break;
case "bool":
attr_value.B = (bool)value;
break;
@@ -225,9 +256,14 @@ namespace Tensorflow
throw new ValueError($"Attr '{attr_def.Name}' of '{op_def.Name}' Op passed {attr_value.I} less than minimum {attr_def.Minimum}.");
break;
case "shape":
- attr_value.Shape = value == null ?
- attr_def.DefaultValue.Shape :
- tensor_util.as_shape((long[])value);
+ if (value == null && attr_def.DefaultValue != null)
+ attr_value.Shape = attr_def.DefaultValue.Shape;
+
+ if(value is TensorShape val1)
+ attr_value.Shape = val1.as_proto();
+ else if(value is long[] val2)
+ attr_value.Shape = tensor_util.as_shape(val2);
+
break;
default:
throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos.");
diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
index efc8e4e0..b3a3e607 100644
--- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
@@ -106,6 +106,19 @@ namespace Tensorflow
throw new NotImplementedException("where");
}
+ ///
+ /// A placeholder op that passes through `input` when its output is not fed.
+ ///
+ /// The default value to produce when output is not fed.
+ ///
+ ///
+ ///
+ public static Tensor placeholder_with_default(T input, TensorShape shape, string name = "")
+ {
+ var _op = _op_def_lib._apply_op_helper("PlaceholderWithDefault", name, new { input, shape, name });
+ return _op.outputs[0];
+ }
+
public static Tensor select(Tensor condition, Tensor t, Tensor e, string name = "")
{
var _op = _op_def_lib._apply_op_helper("Select", name, new { condition, t, e });
diff --git a/src/TensorFlowNET.Core/Operations/gen_io_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_io_ops.py.cs
new file mode 100644
index 00000000..ce57e834
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/gen_io_ops.py.cs
@@ -0,0 +1,18 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public class gen_io_ops
+ {
+ public static OpDefLibrary _op_def_lib = new OpDefLibrary();
+
+ public static Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = "")
+ {
+ var _op = _op_def_lib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors });
+
+ return _op;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Protobuf/Saver.cs b/src/TensorFlowNET.Core/Protobuf/Saver.cs
new file mode 100644
index 00000000..e031f2f6
--- /dev/null
+++ b/src/TensorFlowNET.Core/Protobuf/Saver.cs
@@ -0,0 +1,401 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: saver.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace Tensorflow {
+
+ /// Holder for reflection information generated from saver.proto
+ public static partial class SaverReflection {
+
+ #region Descriptor
+ /// File descriptor for saver.proto
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static SaverReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "CgtzYXZlci5wcm90bxIKdGVuc29yZmxvdyKeAgoIU2F2ZXJEZWYSHAoUZmls",
+ "ZW5hbWVfdGVuc29yX25hbWUYASABKAkSGAoQc2F2ZV90ZW5zb3JfbmFtZRgC",
+ "IAEoCRIXCg9yZXN0b3JlX29wX25hbWUYAyABKAkSEwoLbWF4X3RvX2tlZXAY",
+ "BCABKAUSDwoHc2hhcmRlZBgFIAEoCBIlCh1rZWVwX2NoZWNrcG9pbnRfZXZl",
+ "cnlfbl9ob3VycxgGIAEoAhI9Cgd2ZXJzaW9uGAcgASgOMiwudGVuc29yZmxv",
+ "dy5TYXZlckRlZi5DaGVja3BvaW50Rm9ybWF0VmVyc2lvbiI1ChdDaGVja3Bv",
+ "aW50Rm9ybWF0VmVyc2lvbhIKCgZMRUdBQ1kQABIGCgJWMRABEgYKAlYyEAJC",
+ "ZQoTb3JnLnRlbnNvcmZsb3cudXRpbEILU2F2ZXJQcm90b3NQAVo8Z2l0aHVi",
+ "LmNvbS90ZW5zb3JmbG93L3RlbnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3Jl",
+ "L3Byb3RvYnVm+AEBYgZwcm90bzM="));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { },
+ new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
+ new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SaverDef), global::Tensorflow.SaverDef.Parser, new[]{ "FilenameTensorName", "SaveTensorName", "RestoreOpName", "MaxToKeep", "Sharded", "KeepCheckpointEveryNHours", "Version" }, null, new[]{ typeof(global::Tensorflow.SaverDef.Types.CheckpointFormatVersion) }, null)
+ }));
+ }
+ #endregion
+
+ }
+ #region Messages
+ ///
+ /// Protocol buffer representing the configuration of a Saver.
+ ///
+ public sealed partial class SaverDef : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SaverDef());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::Tensorflow.SaverReflection.Descriptor.MessageTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public SaverDef() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public SaverDef(SaverDef other) : this() {
+ filenameTensorName_ = other.filenameTensorName_;
+ saveTensorName_ = other.saveTensorName_;
+ restoreOpName_ = other.restoreOpName_;
+ maxToKeep_ = other.maxToKeep_;
+ sharded_ = other.sharded_;
+ keepCheckpointEveryNHours_ = other.keepCheckpointEveryNHours_;
+ version_ = other.version_;
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public SaverDef Clone() {
+ return new SaverDef(this);
+ }
+
+ /// Field number for the "filename_tensor_name" field.
+ public const int FilenameTensorNameFieldNumber = 1;
+ private string filenameTensorName_ = "";
+ ///
+ /// The name of the tensor in which to specify the filename when saving or
+ /// restoring a model checkpoint.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public string FilenameTensorName {
+ get { return filenameTensorName_; }
+ set {
+ filenameTensorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+ }
+ }
+
+ /// Field number for the "save_tensor_name" field.
+ public const int SaveTensorNameFieldNumber = 2;
+ private string saveTensorName_ = "";
+ ///
+ /// The operation to run when saving a model checkpoint.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public string SaveTensorName {
+ get { return saveTensorName_; }
+ set {
+ saveTensorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+ }
+ }
+
+ /// Field number for the "restore_op_name" field.
+ public const int RestoreOpNameFieldNumber = 3;
+ private string restoreOpName_ = "";
+ ///
+ /// The operation to run when restoring a model checkpoint.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public string RestoreOpName {
+ get { return restoreOpName_; }
+ set {
+ restoreOpName_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+ }
+ }
+
+ /// Field number for the "max_to_keep" field.
+ public const int MaxToKeepFieldNumber = 4;
+ private int maxToKeep_;
+ ///
+ /// Maximum number of checkpoints to keep. If 0, no checkpoints are deleted.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int MaxToKeep {
+ get { return maxToKeep_; }
+ set {
+ maxToKeep_ = value;
+ }
+ }
+
+ /// Field number for the "sharded" field.
+ public const int ShardedFieldNumber = 5;
+ private bool sharded_;
+ ///
+ /// Shard the save files, one per device that has Variable nodes.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Sharded {
+ get { return sharded_; }
+ set {
+ sharded_ = value;
+ }
+ }
+
+ /// Field number for the "keep_checkpoint_every_n_hours" field.
+ public const int KeepCheckpointEveryNHoursFieldNumber = 6;
+ private float keepCheckpointEveryNHours_;
+ ///
+ /// How often to keep an additional checkpoint. If not specified, only the last
+ /// "max_to_keep" checkpoints are kept; if specified, in addition to keeping
+ /// the last "max_to_keep" checkpoints, an additional checkpoint will be kept
+ /// for every n hours of training.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public float KeepCheckpointEveryNHours {
+ get { return keepCheckpointEveryNHours_; }
+ set {
+ keepCheckpointEveryNHours_ = value;
+ }
+ }
+
+ /// Field number for the "version" field.
+ public const int VersionFieldNumber = 7;
+ private global::Tensorflow.SaverDef.Types.CheckpointFormatVersion version_ = 0;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public global::Tensorflow.SaverDef.Types.CheckpointFormatVersion Version {
+ get { return version_; }
+ set {
+ version_ = value;
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as SaverDef);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(SaverDef other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if (FilenameTensorName != other.FilenameTensorName) return false;
+ if (SaveTensorName != other.SaveTensorName) return false;
+ if (RestoreOpName != other.RestoreOpName) return false;
+ if (MaxToKeep != other.MaxToKeep) return false;
+ if (Sharded != other.Sharded) return false;
+ if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(KeepCheckpointEveryNHours, other.KeepCheckpointEveryNHours)) return false;
+ if (Version != other.Version) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ if (FilenameTensorName.Length != 0) hash ^= FilenameTensorName.GetHashCode();
+ if (SaveTensorName.Length != 0) hash ^= SaveTensorName.GetHashCode();
+ if (RestoreOpName.Length != 0) hash ^= RestoreOpName.GetHashCode();
+ if (MaxToKeep != 0) hash ^= MaxToKeep.GetHashCode();
+ if (Sharded != false) hash ^= Sharded.GetHashCode();
+ if (KeepCheckpointEveryNHours != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(KeepCheckpointEveryNHours);
+ if (Version != 0) hash ^= Version.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ if (FilenameTensorName.Length != 0) {
+ output.WriteRawTag(10);
+ output.WriteString(FilenameTensorName);
+ }
+ if (SaveTensorName.Length != 0) {
+ output.WriteRawTag(18);
+ output.WriteString(SaveTensorName);
+ }
+ if (RestoreOpName.Length != 0) {
+ output.WriteRawTag(26);
+ output.WriteString(RestoreOpName);
+ }
+ if (MaxToKeep != 0) {
+ output.WriteRawTag(32);
+ output.WriteInt32(MaxToKeep);
+ }
+ if (Sharded != false) {
+ output.WriteRawTag(40);
+ output.WriteBool(Sharded);
+ }
+ if (KeepCheckpointEveryNHours != 0F) {
+ output.WriteRawTag(53);
+ output.WriteFloat(KeepCheckpointEveryNHours);
+ }
+ if (Version != 0) {
+ output.WriteRawTag(56);
+ output.WriteEnum((int) Version);
+ }
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ if (FilenameTensorName.Length != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeStringSize(FilenameTensorName);
+ }
+ if (SaveTensorName.Length != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeStringSize(SaveTensorName);
+ }
+ if (RestoreOpName.Length != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeStringSize(RestoreOpName);
+ }
+ if (MaxToKeep != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaxToKeep);
+ }
+ if (Sharded != false) {
+ size += 1 + 1;
+ }
+ if (KeepCheckpointEveryNHours != 0F) {
+ size += 1 + 4;
+ }
+ if (Version != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Version);
+ }
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(SaverDef other) {
+ if (other == null) {
+ return;
+ }
+ if (other.FilenameTensorName.Length != 0) {
+ FilenameTensorName = other.FilenameTensorName;
+ }
+ if (other.SaveTensorName.Length != 0) {
+ SaveTensorName = other.SaveTensorName;
+ }
+ if (other.RestoreOpName.Length != 0) {
+ RestoreOpName = other.RestoreOpName;
+ }
+ if (other.MaxToKeep != 0) {
+ MaxToKeep = other.MaxToKeep;
+ }
+ if (other.Sharded != false) {
+ Sharded = other.Sharded;
+ }
+ if (other.KeepCheckpointEveryNHours != 0F) {
+ KeepCheckpointEveryNHours = other.KeepCheckpointEveryNHours;
+ }
+ if (other.Version != 0) {
+ Version = other.Version;
+ }
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 10: {
+ FilenameTensorName = input.ReadString();
+ break;
+ }
+ case 18: {
+ SaveTensorName = input.ReadString();
+ break;
+ }
+ case 26: {
+ RestoreOpName = input.ReadString();
+ break;
+ }
+ case 32: {
+ MaxToKeep = input.ReadInt32();
+ break;
+ }
+ case 40: {
+ Sharded = input.ReadBool();
+ break;
+ }
+ case 53: {
+ KeepCheckpointEveryNHours = input.ReadFloat();
+ break;
+ }
+ case 56: {
+ version_ = (global::Tensorflow.SaverDef.Types.CheckpointFormatVersion) input.ReadEnum();
+ break;
+ }
+ }
+ }
+ }
+
+ #region Nested types
+ /// Container for nested types declared in the SaverDef message type.
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static partial class Types {
+ ///
+ /// A version number that identifies a different on-disk checkpoint format.
+ /// Usually, each subclass of BaseSaverBuilder works with a particular
+ /// version/format. However, it is possible that the same builder may be
+ /// upgraded to support a newer checkpoint format in the future.
+ ///
+ public enum CheckpointFormatVersion {
+ ///
+ /// Internal legacy format.
+ ///
+ [pbr::OriginalName("LEGACY")] Legacy = 0,
+ ///
+ /// Deprecated format: tf.Saver() which works with tensorflow::table::Table.
+ ///
+ [pbr::OriginalName("V1")] V1 = 1,
+ ///
+ /// Current format: more efficient.
+ ///
+ [pbr::OriginalName("V2")] V2 = 2,
+ }
+
+ }
+ #endregion
+
+ }
+
+ #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/src/TensorFlowNET.Core/Python.cs b/src/TensorFlowNET.Core/Python.cs
index d351b91e..a47ac262 100644
--- a/src/TensorFlowNET.Core/Python.cs
+++ b/src/TensorFlowNET.Core/Python.cs
@@ -15,6 +15,15 @@ namespace Tensorflow
Console.WriteLine(obj.ToString());
}
+ public static T New(object args) where T : IPyClass
+ {
+ var instance = Activator.CreateInstance();
+
+ instance.__init__(instance, args);
+
+ return instance;
+ }
+
public static void with(IPython py, Action action)
{
try
@@ -63,7 +72,7 @@ namespace Tensorflow
catch (Exception ex)
{
Console.WriteLine(ex.ToString());
- throw ex;
+ return default(TOut);
}
finally
{
@@ -97,4 +106,9 @@ namespace Tensorflow
void __exit__();
}
+
+ public class PyObject where T : IPyClass
+ {
+ public T Instance { get; set; }
+ }
}
diff --git a/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs
new file mode 100644
index 00000000..edbb8010
--- /dev/null
+++ b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs
@@ -0,0 +1,98 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace Tensorflow
+{
+ public class BaseSaverBuilder
+ {
+ protected int _write_version;
+
+ public BaseSaverBuilder(int write_version = 2)
+ {
+ _write_version = write_version;
+ }
+
+ public virtual Operation save_op(Tensor filename_tensor, SaveableObject[] saveables)
+ {
+ var tensor_names = new List();
+ var tensors = new List();
+ var tensor_slices = new List();
+
+ foreach (var saveable in saveables)
+ {
+ foreach(var spec in saveable.specs)
+ {
+ tensor_names.Add(spec.name);
+ tensors.Add(spec.tensor);
+ tensor_slices.Add(spec.slice_spec);
+ }
+ }
+
+ if (_write_version == 2)
+ {
+ return gen_io_ops.save_v2(filename_tensor, tensor_names.ToArray(), tensor_slices.ToArray(), tensors.ToArray());
+ }
+ else
+ {
+ throw new NotImplementedException("_write_version v1");
+ }
+ }
+
+ public virtual Tensor[] bulk_restore(Tensor filename_tensor, SaveableObject[] saveables, int preferred_shard, bool restore_sequentially)
+ {
+ throw new NotImplementedException();
+ }
+
+ public virtual SaverDef _build_internal(RefVariable[] names_to_saveables,
+ bool reshape = false,
+ bool sharded = false,
+ int max_to_keep = 5,
+ double keep_checkpoint_every_n_hours = 10000,
+ string name = "",
+ bool restore_sequentially = false,
+ string filename = "model",
+ bool build_save = true,
+ bool build_restore = true)
+ {
+ if (!build_save || !build_restore)
+ throw new ValueError("save and restore operations need to be built together " +
+ " when eager execution is not enabled.");
+
+ var saveables = saveable_object_util.validate_and_slice_inputs(names_to_saveables);
+
+ if (max_to_keep < 0)
+ max_to_keep = 0;
+
+ Python.with(new ops.name_scope(name, "save", saveables.Select(x => x.op).ToArray()), scope =>
+ {
+ name = scope;
+
+ // Add a placeholder string tensor for the filename.
+ var filename_tensor = gen_array_ops.placeholder_with_default( string.IsNullOrEmpty(filename) ? "model" : filename, shape: new TensorShape(), name: "filename");
+ filename_tensor = gen_array_ops.placeholder_with_default(filename_tensor, shape: new TensorShape(), name: "Const");
+ // Keep the name "Const" for backwards compatibility.
+
+ // Add the save ops.
+ if (sharded)
+ {
+
+ }
+ else
+ {
+ if (build_save)
+ _AddSaveOps(filename_tensor, saveables);
+ }
+ });
+
+ throw new NotImplementedException("");
+ }
+
+ public Tensor _AddSaveOps(Tensor filename_tensor, SaveableObject[] saveables)
+ {
+ var save = save_op(filename_tensor, saveables);
+ return control_flow_ops.with_dependencies(new Operation[] { save }, filename_tensor);
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Train/Saving/BulkSaverBuilder.cs b/src/TensorFlowNET.Core/Train/Saving/BulkSaverBuilder.cs
new file mode 100644
index 00000000..b99b75f0
--- /dev/null
+++ b/src/TensorFlowNET.Core/Train/Saving/BulkSaverBuilder.cs
@@ -0,0 +1,14 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public class BulkSaverBuilder : BaseSaverBuilder, ISaverBuilder
+ {
+ public BulkSaverBuilder(int write_version = 2) : base(write_version)
+ {
+
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Train/Saving/ISaverBuilder.cs b/src/TensorFlowNET.Core/Train/Saving/ISaverBuilder.cs
new file mode 100644
index 00000000..ed69919e
--- /dev/null
+++ b/src/TensorFlowNET.Core/Train/Saving/ISaverBuilder.cs
@@ -0,0 +1,24 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public interface ISaverBuilder
+ {
+ Operation save_op(Tensor filename_tensor, SaveableObject[] saveables);
+
+ Tensor[] bulk_restore(Tensor filename_tensor, SaveableObject[] saveables, int preferred_shard, bool restore_sequentially);
+
+ SaverDef _build_internal(RefVariable[] names_to_saveables,
+ bool reshape = false,
+ bool sharded = false,
+ int max_to_keep = 5,
+ double keep_checkpoint_every_n_hours = 10000,
+ string name = "",
+ bool restore_sequentially = false,
+ string filename = "model",
+ bool build_save = true,
+ bool build_restore = true);
+ }
+}
diff --git a/src/TensorFlowNET.Core/Train/Saving/ReferenceVariableSaveable.cs b/src/TensorFlowNET.Core/Train/Saving/ReferenceVariableSaveable.cs
new file mode 100644
index 00000000..583ef889
--- /dev/null
+++ b/src/TensorFlowNET.Core/Train/Saving/ReferenceVariableSaveable.cs
@@ -0,0 +1,19 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public class ReferenceVariableSaveable : SaveableObject
+ {
+ private SaveSpec _spec;
+
+ public ReferenceVariableSaveable(Tensor var, string slice_spec, string name)
+ {
+ _spec = new SaveSpec(var, slice_spec, name, dtype: var.dtype);
+ op = var;
+ specs = new SaveSpec[] { _spec };
+ this.name = name;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Train/Saving/SaveSpec.cs b/src/TensorFlowNET.Core/Train/Saving/SaveSpec.cs
new file mode 100644
index 00000000..1e932209
--- /dev/null
+++ b/src/TensorFlowNET.Core/Train/Saving/SaveSpec.cs
@@ -0,0 +1,32 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ ///
+ /// Class used to describe tensor slices that need to be saved.
+ ///
+ public class SaveSpec
+ {
+ private Tensor _tensor;
+ public Tensor tensor => _tensor;
+
+ private string _slice_spec;
+ public string slice_spec => _slice_spec;
+
+ private string _name;
+ public string name => _name;
+
+ private TF_DataType _dtype;
+ public TF_DataType dtype => _dtype;
+
+ public SaveSpec(Tensor tensor, string slice_spec, string name, TF_DataType dtype = TF_DataType.DtInvalid)
+ {
+ _tensor = tensor;
+ _slice_spec = slice_spec;
+ _name = name;
+ _dtype = dtype;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Train/Saving/SaveableObject.cs b/src/TensorFlowNET.Core/Train/Saving/SaveableObject.cs
new file mode 100644
index 00000000..79be269b
--- /dev/null
+++ b/src/TensorFlowNET.Core/Train/Saving/SaveableObject.cs
@@ -0,0 +1,31 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public class SaveableObject
+ {
+ public Tensor op;
+ public SaveSpec[] specs;
+ public string name;
+ public string device;
+
+ public SaveableObject()
+ {
+
+ }
+
+ public SaveableObject(Tensor var, string slice_spec, string name)
+ {
+
+ }
+
+ public SaveableObject(Tensor op, SaveSpec[] specs, string name)
+ {
+ this.op = op;
+ this.specs = specs;
+ this.name = name;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Train/Saving/Saver.cs b/src/TensorFlowNET.Core/Train/Saving/Saver.cs
new file mode 100644
index 00000000..5e7d6333
--- /dev/null
+++ b/src/TensorFlowNET.Core/Train/Saving/Saver.cs
@@ -0,0 +1,113 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ ///
+ /// Saves and restores variables.
+ ///
+ public class Saver
+ {
+ private RefVariable[] _var_list;
+ private bool _reshape;
+ private bool _sharded;
+ private int _max_to_keep;
+ private double _keep_checkpoint_every_n_hours;
+ private string _name;
+ private bool _restore_sequentially;
+ private SaverDef _saver_def;
+ private ISaverBuilder _builder;
+ private bool _allow_empty;
+ private bool _is_built;
+ private int _write_version;
+ private bool _pad_step_number;
+ private string _filename;
+ private bool _is_empty;
+
+ public Saver(RefVariable[] var_list = null,
+ bool reshape = false,
+ bool sharded = false,
+ int max_to_keep = 5,
+ double keep_checkpoint_every_n_hours = 10000,
+ string name = "",
+ bool restore_sequentially = false,
+ SaverDef saver_def = null,
+ ISaverBuilder builder = null,
+ bool defer_build = false,
+ bool allow_empty = false,
+ int write_version = 2,
+ bool pad_step_number = false,
+ bool save_relative_paths = false,
+ string filename = "")
+ {
+ _var_list = var_list;
+ _reshape = reshape;
+ _sharded = sharded;
+ _max_to_keep = max_to_keep;
+ _keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours;
+ _name = name;
+ _restore_sequentially = restore_sequentially;
+ _builder = builder;
+ _is_built = false;
+ _allow_empty = allow_empty;
+ _write_version = write_version;
+ _pad_step_number = pad_step_number;
+
+ if (!defer_build)
+ build();
+ }
+
+ public void build()
+ {
+ _build(_filename, build_save: true, build_restore: true);
+ }
+
+ private void _build(string checkpoint_path, bool build_save, bool build_restore)
+ {
+ if (_is_built) return;
+
+ _is_built = true;
+
+ if (_saver_def == null)
+ {
+ if (_builder == null)
+ _builder = new BulkSaverBuilder(_write_version);
+
+ if (_var_list == null)
+ _var_list = variables._all_saveable_objects();
+
+ if (_var_list == null || _var_list.Length == 0)
+ {
+ if (_allow_empty)
+ {
+ _is_empty = true;
+ return;
+ }
+ else
+ {
+ throw new ValueError("No variables to save");
+ }
+ }
+ _is_empty = false;
+
+ _saver_def = _builder._build_internal(_var_list,
+ reshape: _reshape,
+ sharded: _sharded,
+ max_to_keep: _max_to_keep,
+ keep_checkpoint_every_n_hours: _keep_checkpoint_every_n_hours,
+ name: _name,
+ restore_sequentially: _restore_sequentially,
+ filename: checkpoint_path,
+ build_save: build_save,
+ build_restore: build_restore);
+ }
+ else if (_saver_def != null && !string.IsNullOrEmpty(_name))
+ {
+ throw new NotImplementedException("");
+ }
+
+
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Train/Saving/saveable_object_util.py.cs b/src/TensorFlowNET.Core/Train/Saving/saveable_object_util.py.cs
new file mode 100644
index 00000000..3a244e90
--- /dev/null
+++ b/src/TensorFlowNET.Core/Train/Saving/saveable_object_util.py.cs
@@ -0,0 +1,102 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace Tensorflow
+{
+ public class saveable_object_util
+ {
+ ///
+ /// Returns the variables and names that will be used for a Saver.
+ ///
+ ///
+ ///
+ public static SaveableObject[] validate_and_slice_inputs(RefVariable[] names_to_saveables)
+ {
+ var names_to_saveables_dict = op_list_to_dict(names_to_saveables);
+ var saveables = new List();
+ var seen_ops = new List();
+
+ foreach (var item in names_to_saveables_dict)
+ {
+ foreach (var converted_saveable_object in saveable_objects_for_op(item.Value, item.Key))
+ _add_saveable(saveables, seen_ops, converted_saveable_object);
+ }
+ return saveables.ToArray();
+ }
+
+ private static void _add_saveable(List saveables, List seen_ops, T saveable) where T : SaveableObject
+ {
+ if (seen_ops.Contains(saveable.op))
+ throw new ValueError($"The same saveable will be restored with two names: {saveable.name}");
+
+ saveables.Add(saveable);
+ seen_ops.Add(saveable.op);
+ }
+
+ ///
+ /// Create `SaveableObject`s from an operation.
+ ///
+ ///
+ ///
+ ///
+ public static IEnumerable saveable_objects_for_op(Tensor op, string name)
+ {
+ if (false)
+ {
+
+ }
+ else
+ {
+ ops.init_scope();
+ var variable = ops.internal_convert_to_tensor(op, as_ref: true);
+ if (variable.op.type == "VariableV2")
+ yield return new ReferenceVariableSaveable(variable, "", name);
+ }
+ }
+
+ public static Dictionary op_list_to_dict(RefVariable[] op_list, bool convert_variable_to_tensor = true)
+ {
+ op_list = op_list.OrderBy(x => x.name).ToArray();
+ var names_to_saveables = new Dictionary();
+
+ foreach(var var in op_list)
+ {
+ if (false)
+ {
+ throw new NotImplementedException("op_list_to_dict");
+ }
+ else
+ {
+ if(false) // eager
+ {
+
+ }
+ else
+ {
+ string name = "";
+ Tensor tensor = null;
+
+ if (convert_variable_to_tensor)
+ {
+ tensor = ops.internal_convert_to_tensor(var, as_ref: true);
+ }
+
+ if (var.op.type == "ReadVariableOp")
+ name = var.op.inputs[0].op.Name;
+ else
+ name = var.op.Name;
+
+ if (names_to_saveables.ContainsKey(name))
+ throw new ValueError($"At least two variables have the same name: {name}");
+
+ names_to_saveables[name] = tensor;
+ }
+ }
+ }
+
+ return names_to_saveables;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Train/tf.optimizers.cs b/src/TensorFlowNET.Core/Train/tf.optimizers.cs
index 00fe846b..ba4fbea8 100644
--- a/src/TensorFlowNET.Core/Train/tf.optimizers.cs
+++ b/src/TensorFlowNET.Core/Train/tf.optimizers.cs
@@ -8,10 +8,9 @@ namespace Tensorflow
{
public static class train
{
- public static Optimizer GradientDescentOptimizer(double learning_rate)
- {
- return new GradientDescentOptimizer(learning_rate);
- }
+ public static Optimizer GradientDescentOptimizer(double learning_rate) => new GradientDescentOptimizer(learning_rate);
+
+ public static Saver Saver() => new Saver();
}
}
}
diff --git a/src/TensorFlowNET.Core/Variables/variables.py.cs b/src/TensorFlowNET.Core/Variables/variables.py.cs
index 9cefa6fd..5cde1359 100644
--- a/src/TensorFlowNET.Core/Variables/variables.py.cs
+++ b/src/TensorFlowNET.Core/Variables/variables.py.cs
@@ -16,6 +16,26 @@ namespace Tensorflow
return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES);
}
+ ///
+ /// Returns all variables and `SaveableObject`s that must be checkpointed.
+ ///
+ ///
+ ///
+ public static RefVariable[] _all_saveable_objects(string scope = "")
+ {
+ var all = new List();
+
+ var collection = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope);
+ if(collection != null)
+ all.AddRange(collection as List);
+
+ collection = ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope);
+ if (collection != null)
+ all.AddRange(collection as List);
+
+ return all.ToArray();
+ }
+
///
/// Returns global variables.
///
diff --git a/src/TensorFlowNET.Core/ops.GraphKeys.cs b/src/TensorFlowNET.Core/ops.GraphKeys.cs
index cfc74aff..78e25bd8 100644
--- a/src/TensorFlowNET.Core/ops.GraphKeys.cs
+++ b/src/TensorFlowNET.Core/ops.GraphKeys.cs
@@ -27,6 +27,11 @@ namespace Tensorflow
/// Default collection for all variables, except local ones.
///
public static string GLOBAL_VARIABLES = "variables";
+
+ ///
+ /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.
+ ///
+ public static string SAVEABLE_OBJECTS = "saveable_objects";
}
}
}
diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs
index 74c8a5f7..1fa09224 100644
--- a/src/TensorFlowNET.Core/ops.py.cs
+++ b/src/TensorFlowNET.Core/ops.py.cs
@@ -387,6 +387,10 @@ namespace Tensorflow
{
case "Tensor":
return value as Tensor;
+ case "String":
+ return constant_op.constant(Convert.ToString(value), name);
+ case "String[]":
+ return constant_op.constant(value as string[], name);
case "Int32":
return constant_op.constant(Convert.ToInt32(value), name);
case "Double":
diff --git a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs
index 99e7ee20..40775fcb 100644
--- a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs
+++ b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs
@@ -7,7 +7,7 @@ using Tensorflow;
namespace TensorFlowNET.UnitTest
{
[TestClass]
- public class TrainSaverTest
+ public class TrainSaverTest : Python
{
[TestMethod]
public void Save()
@@ -20,6 +20,14 @@ namespace TensorFlowNET.UnitTest
// Add an op to initialize the variables.
var init_op = tf.global_variables_initializer();
+
+ // Add ops to save and restore all the variables.
+ var saver = tf.train.Saver();
+
+ with(tf.Session(), sess =>
+ {
+ sess.run(init_op);
+ });
}
}
}