| @@ -48,7 +48,7 @@ Docs: https://tensorflownet.readthedocs.io</Description> | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="Google.Protobuf" Version="3.8.0" /> | <PackageReference Include="Google.Protobuf" Version="3.8.0" /> | ||||
| <PackageReference Include="NumSharp" Version="0.10.2" /> | |||||
| <PackageReference Include="NumSharp" Version="0.10.3" /> | |||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| @@ -16,6 +16,12 @@ namespace Tensorflow | |||||
| _write_version = write_version; | _write_version = write_version; | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Create an Op to save 'saveables'. | |||||
| /// </summary> | |||||
| /// <param name="filename_tensor"></param> | |||||
| /// <param name="saveables"></param> | |||||
| /// <returns></returns> | |||||
| public virtual Operation save_op(Tensor filename_tensor, SaveableObject[] saveables) | public virtual Operation save_op(Tensor filename_tensor, SaveableObject[] saveables) | ||||
| { | { | ||||
| var tensor_names = new List<string>(); | var tensor_names = new List<string>(); | ||||
| @@ -105,6 +111,10 @@ namespace Tensorflow | |||||
| } | } | ||||
| var graph = ops.get_default_graph(); | var graph = ops.get_default_graph(); | ||||
| // Do some sanity checking on collections containing | |||||
| // PartitionedVariables. If a saved collection has a PartitionedVariable, | |||||
| // the GraphDef needs to include concat ops to get the value (or there'll | |||||
| // be a lookup error on load). | |||||
| var check_collection_list = graph.get_all_collection_keys(); | var check_collection_list = graph.get_all_collection_keys(); | ||||
| foreach (var collection_type in check_collection_list) | foreach (var collection_type in check_collection_list) | ||||
| { | { | ||||
| @@ -158,7 +158,10 @@ namespace Tensorflow | |||||
| string model_checkpoint_path = ""; | string model_checkpoint_path = ""; | ||||
| string checkpoint_file = ""; | string checkpoint_file = ""; | ||||
| checkpoint_file = $"{save_path}-{global_step}"; | |||||
| if (global_step > 0) | |||||
| checkpoint_file = $"{save_path}-{global_step}"; | |||||
| else | |||||
| checkpoint_file = save_path; | |||||
| var save_path_parent = Path.GetDirectoryName(save_path); | var save_path_parent = Path.GetDirectoryName(save_path); | ||||
| @@ -291,15 +294,13 @@ namespace Tensorflow | |||||
| if (_saver_def.MaxToKeep <= 0) return; | if (_saver_def.MaxToKeep <= 0) return; | ||||
| // Remove first from list if the same name was used before. | // Remove first from list if the same name was used before. | ||||
| foreach (var p in _last_checkpoints) | |||||
| if (latest_save_path == _CheckpointFilename((p.Key, p.Value))) | |||||
| _last_checkpoints.Remove(p.Key); | |||||
| // Append new path to list | |||||
| _last_checkpoints.Add(latest_save_path, Python.time()); | |||||
| var _existed_checkpoints = _last_checkpoints.FirstOrDefault(p => latest_save_path == _CheckpointFilename((p.Key, p.Value))); | |||||
| if (_existed_checkpoints.Key != null) | |||||
| _last_checkpoints.Remove(_existed_checkpoints.Key); | |||||
| _last_checkpoints.Add(latest_save_path, time()); | |||||
| // If more than max_to_keep, remove oldest. | // If more than max_to_keep, remove oldest. | ||||
| if(_last_checkpoints.Count > _saver_def.MaxToKeep) | |||||
| if (_last_checkpoints.Count > _saver_def.MaxToKeep) | |||||
| { | { | ||||
| var first = _last_checkpoints.First(); | var first = _last_checkpoints.First(); | ||||
| _last_checkpoints.Remove(first.Key); | _last_checkpoints.Remove(first.Key); | ||||
| @@ -25,7 +25,7 @@ namespace Tensorflow | |||||
| var saver = _create_saver_from_imported_meta_graph( | var saver = _create_saver_from_imported_meta_graph( | ||||
| meta_graph_def, import_scope, imported_vars); | meta_graph_def, import_scope, imported_vars); | ||||
| return (saver, null); | |||||
| return (saver, imported_return_elements); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -26,7 +26,7 @@ | |||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.1.0" /> | |||||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.1.1" /> | |||||
| <PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | <PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | ||||
| <PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | <PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| @@ -105,6 +105,8 @@ namespace TensorFlowNET.Examples.ImageProcess | |||||
| // Create a train saver that is used to restore values into an eval graph | // Create a train saver that is used to restore values into an eval graph | ||||
| // when exporting models. | // when exporting models. | ||||
| var train_saver = tf.train.Saver(); | var train_saver = tf.train.Saver(); | ||||
| train_saver.save(sess, CHECKPOINT_NAME); | |||||
| sw.Restart(); | sw.Restart(); | ||||
| for (int i = 0; i < how_many_training_steps; i++) | for (int i = 0; i < how_many_training_steps; i++) | ||||
| @@ -17,6 +17,7 @@ | |||||
| <PackageReference Include="Newtonsoft.Json" Version="12.0.2" /> | <PackageReference Include="Newtonsoft.Json" Version="12.0.2" /> | ||||
| <PackageReference Include="SharpZipLib" Version="1.1.0" /> | <PackageReference Include="SharpZipLib" Version="1.1.0" /> | ||||
| <PackageReference Include="System.Drawing.Common" Version="4.5.1" /> | <PackageReference Include="System.Drawing.Common" Version="4.5.1" /> | ||||
| <PackageReference Include="TensorFlow.NET" Version="0.8.0" /> | |||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| @@ -16,7 +16,7 @@ | |||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.0.1" /> | |||||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.1.1" /> | |||||
| <PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | <PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | ||||
| <PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | <PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | ||||
| </ItemGroup> | </ItemGroup> | ||||