| @@ -23,6 +23,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras.UnitTest", | |||||
| EndProject | EndProject | ||||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Graph.UnitTest", "test\TensorFlowNET.Graph.UnitTest\TensorFlowNET.Graph.UnitTest.csproj", "{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}" | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Graph.UnitTest", "test\TensorFlowNET.Graph.UnitTest\TensorFlowNET.Graph.UnitTest.csproj", "{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}" | ||||
| EndProject | EndProject | ||||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Hub", "src\TensorflowNET.Hub\Tensorflow.Hub.csproj", "{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}" | |||||
| EndProject | |||||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Hub.Unittest", "test\TensorflowNET.Hub.Unittest\Tensorflow.Hub.Unittest.csproj", "{7DEA8760-E401-4872-81F3-405F185A13A0}" | |||||
| EndProject | |||||
| Global | Global | ||||
| GlobalSection(SolutionConfigurationPlatforms) = preSolution | GlobalSection(SolutionConfigurationPlatforms) = preSolution | ||||
| Debug|Any CPU = Debug|Any CPU | Debug|Any CPU = Debug|Any CPU | ||||
| @@ -153,6 +157,30 @@ Global | |||||
| {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.Build.0 = Release|x64 | {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.Build.0 = Release|x64 | ||||
| {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.ActiveCfg = Release|Any CPU | {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.ActiveCfg = Release|Any CPU | ||||
| {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.Build.0 = Release|Any CPU | {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.Build.0 = Release|Any CPU | ||||
| {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||||
| {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||||
| {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x64.ActiveCfg = Debug|Any CPU | |||||
| {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x64.Build.0 = Debug|Any CPU | |||||
| {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x86.ActiveCfg = Debug|Any CPU | |||||
| {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x86.Build.0 = Debug|Any CPU | |||||
| {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||||
| {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|Any CPU.Build.0 = Release|Any CPU | |||||
| {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x64.ActiveCfg = Release|Any CPU | |||||
| {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x64.Build.0 = Release|Any CPU | |||||
| {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x86.ActiveCfg = Release|Any CPU | |||||
| {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x86.Build.0 = Release|Any CPU | |||||
| {7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||||
| {7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||||
| {7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x64.ActiveCfg = Debug|Any CPU | |||||
| {7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x64.Build.0 = Debug|Any CPU | |||||
| {7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x86.ActiveCfg = Debug|Any CPU | |||||
| {7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x86.Build.0 = Debug|Any CPU | |||||
| {7DEA8760-E401-4872-81F3-405F185A13A0}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||||
| {7DEA8760-E401-4872-81F3-405F185A13A0}.Release|Any CPU.Build.0 = Release|Any CPU | |||||
| {7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x64.ActiveCfg = Release|Any CPU | |||||
| {7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x64.Build.0 = Release|Any CPU | |||||
| {7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x86.ActiveCfg = Release|Any CPU | |||||
| {7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x86.Build.0 = Release|Any CPU | |||||
| EndGlobalSection | EndGlobalSection | ||||
| GlobalSection(SolutionProperties) = preSolution | GlobalSection(SolutionProperties) = preSolution | ||||
| HideSolutionNode = FALSE | HideSolutionNode = FALSE | ||||
| @@ -207,9 +207,24 @@ namespace Tensorflow | |||||
| } | } | ||||
| public override string ToString() | public override string ToString() | ||||
| => items.Count() == 1 | |||||
| ? items.First().ToString() | |||||
| : items.Count() + " Tensors" + ". " + string.Join(", ", items.Select(x => x.name)); | |||||
| { | |||||
| if(items.Count == 1) | |||||
| { | |||||
| return items[0].ToString(); | |||||
| } | |||||
| else | |||||
| { | |||||
| StringBuilder sb = new StringBuilder(); | |||||
| sb.Append($"Totally {items.Count} tensors, which are {string.Join(", ", items.Select(x => x.name))}\n[\n"); | |||||
| for(int i = 0; i < items.Count; i++) | |||||
| { | |||||
| var tensor = items[i]; | |||||
| sb.Append($"Tensor {i}({tensor.name}): {tensor.ToString()}\n"); | |||||
| } | |||||
| sb.Append("]\n"); | |||||
| return sb.ToString(); | |||||
| } | |||||
| } | |||||
| public void Dispose() | public void Dispose() | ||||
| { | { | ||||
| @@ -301,6 +301,17 @@ namespace Tensorflow | |||||
| type == TF_DataType.DtInt32Ref || type == TF_DataType.DtInt64Ref; | type == TF_DataType.DtInt32Ref || type == TF_DataType.DtInt64Ref; | ||||
| } | } | ||||
| public static bool is_unsigned(this TF_DataType type) | |||||
| { | |||||
| return type == TF_DataType.TF_UINT8 || type == TF_DataType.TF_UINT16 || type == TF_DataType.TF_UINT32 || | |||||
| type == TF_DataType.TF_UINT64; | |||||
| } | |||||
| public static bool is_bool(this TF_DataType type) | |||||
| { | |||||
| return type == TF_DataType.TF_BOOL; | |||||
| } | |||||
| public static bool is_floating(this TF_DataType type) | public static bool is_floating(this TF_DataType type) | ||||
| { | { | ||||
| return type == TF_DataType.TF_HALF || type == TF_DataType.TF_FLOAT || type == TF_DataType.TF_DOUBLE; | return type == TF_DataType.TF_HALF || type == TF_DataType.TF_FLOAT || type == TF_DataType.TF_DOUBLE; | ||||
| @@ -22,9 +22,9 @@ namespace Tensorflow.Keras.Engine | |||||
| // If dtype is DT_FLOAT, provide a uniform unit scaling initializer | // If dtype is DT_FLOAT, provide a uniform unit scaling initializer | ||||
| if (dtype.is_floating()) | if (dtype.is_floating()) | ||||
| initializer = tf.glorot_uniform_initializer; | initializer = tf.glorot_uniform_initializer; | ||||
| else if (dtype.is_integer()) | |||||
| else if (dtype.is_integer() || dtype.is_unsigned() || dtype.is_bool()) | |||||
| initializer = tf.zeros_initializer; | initializer = tf.zeros_initializer; | ||||
| else | |||||
| else if(getter is null) | |||||
| throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {name}"); | throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {name}"); | ||||
| } | } | ||||
| @@ -36,5 +36,9 @@ namespace Tensorflow.Keras.Saving | |||||
| public bool? Stateful { get; set; } | public bool? Stateful { get; set; } | ||||
| [JsonProperty("model_config")] | [JsonProperty("model_config")] | ||||
| public KerasModelConfig? ModelConfig { get; set; } | public KerasModelConfig? ModelConfig { get; set; } | ||||
| [JsonProperty("sparse")] | |||||
| public bool Sparse { get; set; } | |||||
| [JsonProperty("ragged")] | |||||
| public bool Ragged { get; set; } | |||||
| } | } | ||||
| } | } | ||||
| @@ -401,13 +401,22 @@ namespace Tensorflow.Keras.Saving | |||||
| private (Trackable, Action<object, object, object>) revive_custom_object(string identifier, KerasMetaData metadata) | private (Trackable, Action<object, object, object>) revive_custom_object(string identifier, KerasMetaData metadata) | ||||
| { | { | ||||
| if(identifier == SavedModel.Constants.LAYER_IDENTIFIER) | |||||
| if (identifier == SavedModel.Constants.LAYER_IDENTIFIER) | |||||
| { | { | ||||
| return RevivedLayer.init_from_metadata(metadata); | return RevivedLayer.init_from_metadata(metadata); | ||||
| } | } | ||||
| else if(identifier == SavedModel.Constants.MODEL_IDENTIFIER || identifier == SavedModel.Constants.SEQUENTIAL_IDENTIFIER | |||||
| || identifier == SavedModel.Constants.NETWORK_IDENTIFIER) | |||||
| { | |||||
| return RevivedNetwork.init_from_metadata(metadata); | |||||
| } | |||||
| else if(identifier == SavedModel.Constants.INPUT_LAYER_IDENTIFIER) | |||||
| { | |||||
| return RevivedInputLayer.init_from_metadata(metadata); | |||||
| } | |||||
| else | else | ||||
| { | { | ||||
| throw new NotImplementedException(); | |||||
| throw new ValueError($"Cannot revive the layer {identifier}."); | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,15 +1,46 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.ArgsDefinition; | |||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Layers; | |||||
| namespace Tensorflow.Keras.Saving.SavedModel | namespace Tensorflow.Keras.Saving.SavedModel | ||||
| { | { | ||||
| public class RevivedInputLayer: Layer | |||||
| public class RevivedInputLayer: InputLayer | |||||
| { | { | ||||
| private RevivedInputLayer(): base(null) | |||||
| protected RevivedConfig _config = null; | |||||
| private RevivedInputLayer(InputLayerArgs args): base(args) | |||||
| { | { | ||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override IKerasConfig get_config() | |||||
| { | |||||
| return _config; | |||||
| } | |||||
| public static (RevivedInputLayer, Action<object, object, object>) init_from_metadata(KerasMetaData metadata) | |||||
| { | |||||
| InputLayerArgs args = new InputLayerArgs() | |||||
| { | |||||
| Name = metadata.Name, | |||||
| DType = metadata.DType, | |||||
| Sparse = metadata.Sparse, | |||||
| Ragged = metadata.Ragged, | |||||
| BatchInputShape = metadata.BatchInputShape | |||||
| }; | |||||
| RevivedInputLayer revived_obj = new RevivedInputLayer(args); | |||||
| revived_obj._config = new RevivedConfig() { Config = metadata.Config }; | |||||
| return (revived_obj, Loader.setattr); | |||||
| } | |||||
| public override string ToString() | |||||
| { | |||||
| return $"Customized keras input layer: {Name}."; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -53,7 +53,7 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||||
| return (revived_obj, ReviveUtils._revive_setter); | return (revived_obj, ReviveUtils._revive_setter); | ||||
| } | } | ||||
| private RevivedConfig _config = null; | |||||
| protected RevivedConfig _config = null; | |||||
| public object keras_api | public object keras_api | ||||
| { | { | ||||
| @@ -70,7 +70,7 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||||
| } | } | ||||
| } | } | ||||
| public RevivedLayer(LayerArgs args): base(args) | |||||
| protected RevivedLayer(LayerArgs args): base(args) | |||||
| { | { | ||||
| } | } | ||||
| @@ -84,17 +84,5 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||||
| { | { | ||||
| return _config; | return _config; | ||||
| } | } | ||||
| //protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
| //{ | |||||
| // if(SerializedAttributes is null || !SerializedAttributes.TryGetValue("__call__", out var func) || func is not Function) | |||||
| // { | |||||
| // return base.Call(inputs, state, training); | |||||
| // } | |||||
| // else | |||||
| // { | |||||
| // return (func as Function).Apply(inputs); | |||||
| // } | |||||
| //} | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,40 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Keras.ArgsDefinition; | |||||
| using Tensorflow.Keras.Utils; | |||||
| namespace Tensorflow.Keras.Saving.SavedModel | |||||
| { | |||||
| public class RevivedNetwork: RevivedLayer | |||||
| { | |||||
| private RevivedNetwork(LayerArgs args) : base(args) | |||||
| { | |||||
| } | |||||
| public static (RevivedNetwork, Action<object, object, object>) init_from_metadata(KerasMetaData metadata) | |||||
| { | |||||
| RevivedNetwork revived_obj = new(new LayerArgs() { Name = metadata.Name }); | |||||
| // TODO(Rinne): with utils.no_automatic_dependency_tracking_scope(revived_obj) | |||||
| // TODO(Rinne): revived_obj._expects_training_arg | |||||
| var config = metadata.Config; | |||||
| if (generic_utils.validate_config(config)) | |||||
| { | |||||
| revived_obj._config = new RevivedConfig() { Config = config }; | |||||
| } | |||||
| if(metadata.ActivityRegularizer is not null) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| return (revived_obj, ReviveUtils._revive_setter); | |||||
| } | |||||
| public override string ToString() | |||||
| { | |||||
| return $"Customized keras Network: {Name}."; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,57 @@ | |||||
| using System.IO; | |||||
| using System.Threading.Tasks; | |||||
| namespace Tensorflow.Hub | |||||
| { | |||||
| public class GcsCompressedFileResolver : IResolver | |||||
| { | |||||
| const int LOCK_FILE_TIMEOUT_SEC = 10 * 60; | |||||
| public string Call(string handle) | |||||
| { | |||||
| var module_dir = _module_dir(handle); | |||||
| return resolver.atomic_download_async(handle, download, module_dir, LOCK_FILE_TIMEOUT_SEC) | |||||
| .GetAwaiter().GetResult(); | |||||
| } | |||||
| public bool IsSupported(string handle) | |||||
| { | |||||
| return handle.StartsWith("gs://") && _is_tarfile(handle); | |||||
| } | |||||
| private async Task download(string handle, string tmp_dir) | |||||
| { | |||||
| new resolver.DownloadManager(handle).download_and_uncompress( | |||||
| new FileStream(handle, FileMode.Open, FileAccess.Read), tmp_dir); | |||||
| await Task.Run(() => { }); | |||||
| } | |||||
| private static string _module_dir(string handle) | |||||
| { | |||||
| var cache_dir = resolver.tfhub_cache_dir(use_temp: true); | |||||
| var sha1 = ComputeSha1(handle); | |||||
| return resolver.create_local_module_dir(cache_dir, sha1); | |||||
| } | |||||
| private static bool _is_tarfile(string filename) | |||||
| { | |||||
| return filename.EndsWith(".tar") || filename.EndsWith(".tar.gz") || filename.EndsWith(".tgz"); | |||||
| } | |||||
| private static string ComputeSha1(string s) | |||||
| { | |||||
| using (var sha = new System.Security.Cryptography.SHA1Managed()) | |||||
| { | |||||
| var bytes = System.Text.Encoding.UTF8.GetBytes(s); | |||||
| var hash = sha.ComputeHash(bytes); | |||||
| var stringBuilder = new System.Text.StringBuilder(hash.Length * 2); | |||||
| foreach (var b in hash) | |||||
| { | |||||
| stringBuilder.Append(b.ToString("x2")); | |||||
| } | |||||
| return stringBuilder.ToString(); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,78 @@ | |||||
| using System; | |||||
| using System.Net.Http; | |||||
| using System.Threading.Tasks; | |||||
| namespace Tensorflow.Hub | |||||
| { | |||||
| public class HttpCompressedFileResolver : HttpResolverBase | |||||
| { | |||||
| const int LOCK_FILE_TIMEOUT_SEC = 10 * 60; // 10 minutes | |||||
| private static readonly (string, string) _COMPRESSED_FORMAT_QUERY = | |||||
| ("tf-hub-format", "compressed"); | |||||
| private static string _module_dir(string handle) | |||||
| { | |||||
| var cache_dir = resolver.tfhub_cache_dir(use_temp: true); | |||||
| var sha1 = ComputeSha1(handle); | |||||
| return resolver.create_local_module_dir(cache_dir, sha1); | |||||
| } | |||||
| public override bool IsSupported(string handle) | |||||
| { | |||||
| if (!is_http_protocol(handle)) | |||||
| { | |||||
| return false; | |||||
| } | |||||
| var load_format = resolver.model_load_format(); | |||||
| return load_format == Enum.GetName(typeof(resolver.ModelLoadFormat), resolver.ModelLoadFormat.COMPRESSED) | |||||
| || load_format == Enum.GetName(typeof(resolver.ModelLoadFormat), resolver.ModelLoadFormat.AUTO); | |||||
| } | |||||
| public override string Call(string handle) | |||||
| { | |||||
| var module_dir = _module_dir(handle); | |||||
| return resolver.atomic_download_async( | |||||
| handle, | |||||
| download, | |||||
| module_dir, | |||||
| LOCK_FILE_TIMEOUT_SEC | |||||
| ).GetAwaiter().GetResult(); | |||||
| } | |||||
| private async Task download(string handle, string tmp_dir) | |||||
| { | |||||
| var client = new HttpClient(); | |||||
| var response = await client.GetAsync(_append_compressed_format_query(handle)); | |||||
| using (var httpStream = await response.Content.ReadAsStreamAsync()) | |||||
| { | |||||
| new resolver.DownloadManager(handle).download_and_uncompress(httpStream, tmp_dir); | |||||
| } | |||||
| } | |||||
| private string _append_compressed_format_query(string handle) | |||||
| { | |||||
| return append_format_query(handle, _COMPRESSED_FORMAT_QUERY); | |||||
| } | |||||
| private static string ComputeSha1(string s) | |||||
| { | |||||
| using (var sha = new System.Security.Cryptography.SHA1Managed()) | |||||
| { | |||||
| var bytes = System.Text.Encoding.UTF8.GetBytes(s); | |||||
| var hash = sha.ComputeHash(bytes); | |||||
| var stringBuilder = new System.Text.StringBuilder(hash.Length * 2); | |||||
| foreach (var b in hash) | |||||
| { | |||||
| stringBuilder.Append(b.ToString("x2")); | |||||
| } | |||||
| return stringBuilder.ToString(); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,65 @@ | |||||
| using System; | |||||
| using System.Net; | |||||
| namespace Tensorflow.Hub | |||||
| { | |||||
| public class HttpUncompressedFileResolver : HttpResolverBase | |||||
| { | |||||
| private readonly PathResolver _pathResolver; | |||||
| public HttpUncompressedFileResolver() | |||||
| { | |||||
| _pathResolver = new PathResolver(); | |||||
| } | |||||
| public override string Call(string handle) | |||||
| { | |||||
| handle = AppendUncompressedFormatQuery(handle); | |||||
| var gsLocation = RequestGcsLocation(handle); | |||||
| return _pathResolver.Call(gsLocation); | |||||
| } | |||||
| public override bool IsSupported(string handle) | |||||
| { | |||||
| if (!is_http_protocol(handle)) | |||||
| { | |||||
| return false; | |||||
| } | |||||
| var load_format = resolver.model_load_format(); | |||||
| return load_format == Enum.GetName(typeof(resolver.ModelLoadFormat), resolver.ModelLoadFormat.UNCOMPRESSED); | |||||
| } | |||||
| protected virtual string AppendUncompressedFormatQuery(string handle) | |||||
| { | |||||
| return append_format_query(handle, ("tf-hub-format", "uncompressed")); | |||||
| } | |||||
| protected virtual string RequestGcsLocation(string handleWithParams) | |||||
| { | |||||
| var request = WebRequest.Create(handleWithParams); | |||||
| var response = request.GetResponse() as HttpWebResponse; | |||||
| if (response == null) | |||||
| { | |||||
| throw new Exception("Failed to get a response from the server."); | |||||
| } | |||||
| var statusCode = (int)response.StatusCode; | |||||
| if (statusCode != 303) | |||||
| { | |||||
| throw new Exception($"Expected 303 for GCS location lookup but got HTTP {statusCode} {response.StatusDescription}"); | |||||
| } | |||||
| var location = response.Headers["Location"]; | |||||
| if (!location.StartsWith("gs://")) | |||||
| { | |||||
| throw new Exception($"Expected Location:GS path but received {location}"); | |||||
| } | |||||
| return location; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,157 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using Tensorflow.Keras.Engine; | |||||
| using Tensorflow.Train; | |||||
| using Tensorflow.Training; | |||||
| using Tensorflow.Training.Saving.SavedModel; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Hub | |||||
| { | |||||
| public class KerasLayer : Layer | |||||
| { | |||||
| private string _handle; | |||||
| private LoadOptions? _load_options; | |||||
| private Trackable _func; | |||||
| private Func<Tensors, Tensors> _callable; | |||||
| public KerasLayer(string handle, bool trainable = false, LoadOptions? load_options = null) : | |||||
| base(new Keras.ArgsDefinition.LayerArgs() { Trainable = trainable }) | |||||
| { | |||||
| _handle = handle; | |||||
| _load_options = load_options; | |||||
| _func = load_module(_handle, _load_options); | |||||
| _track_trackable(_func, "_func"); | |||||
| // TODO(Rinne): deal with _is_hub_module_v1. | |||||
| _callable = _get_callable(); | |||||
| _setup_layer(trainable); | |||||
| } | |||||
| private void _setup_layer(bool trainable = false) | |||||
| { | |||||
| HashSet<string> trainable_variables; | |||||
| if (_func is Layer layer) | |||||
| { | |||||
| foreach (var v in layer.TrainableVariables) | |||||
| { | |||||
| _add_existing_weight(v, true); | |||||
| } | |||||
| trainable_variables = new HashSet<string>(layer.TrainableVariables.Select(v => v.UniqueId)); | |||||
| } | |||||
| else if (_func.CustomizedFields.TryGetValue("trainable_variables", out var obj) && obj is IEnumerable<Trackable> trackables) | |||||
| { | |||||
| foreach (var trackable in trackables) | |||||
| { | |||||
| if (trackable is IVariableV1 v) | |||||
| { | |||||
| _add_existing_weight(v, true); | |||||
| } | |||||
| } | |||||
| trainable_variables = new HashSet<string>(trackables.Where(t => t is IVariableV1).Select(t => (t as IVariableV1).UniqueId)); | |||||
| } | |||||
| else | |||||
| { | |||||
| trainable_variables = new HashSet<string>(); | |||||
| } | |||||
| if (_func is Layer) | |||||
| { | |||||
| layer = (Layer)_func; | |||||
| foreach (var v in layer.Variables) | |||||
| { | |||||
| if (!trainable_variables.Contains(v.UniqueId)) | |||||
| { | |||||
| _add_existing_weight(v, false); | |||||
| } | |||||
| } | |||||
| } | |||||
| else if (_func.CustomizedFields.TryGetValue("variables", out var obj) && obj is IEnumerable<Trackable> total_trackables) | |||||
| { | |||||
| foreach (var trackable in total_trackables) | |||||
| { | |||||
| if (trackable is IVariableV1 v && !trainable_variables.Contains(v.UniqueId)) | |||||
| { | |||||
| _add_existing_weight(v, false); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (_func.CustomizedFields.ContainsKey("regularization_losses")) | |||||
| { | |||||
| if ((_func.CustomizedFields["regularization_losses"] as ListWrapper)?.Count > 0) | |||||
| { | |||||
| throw new NotImplementedException("The regularization_losses loading has not been supported yet, " + | |||||
| "please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues to let us know and add a feature."); | |||||
| } | |||||
| } | |||||
| } | |||||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
| { | |||||
| _check_trainability(); | |||||
| // TODO(Rinne): deal with training_argument | |||||
| var result = _callable(inputs); | |||||
| return _apply_output_shape_if_set(inputs, result); | |||||
| } | |||||
| private void _check_trainability() | |||||
| { | |||||
| if (!Trainable) return; | |||||
| // TODO(Rinne): deal with _is_hub_module_v1 and signature | |||||
| if (TrainableWeights is null || TrainableWeights.Count == 0) | |||||
| { | |||||
| tf.Logger.Error("hub.KerasLayer is trainable but has zero trainable weights."); | |||||
| } | |||||
| } | |||||
| private Tensors _apply_output_shape_if_set(Tensors inputs, Tensors result) | |||||
| { | |||||
| // TODO(Rinne): implement it. | |||||
| return result; | |||||
| } | |||||
| private void _add_existing_weight(IVariableV1 weight, bool? trainable = null) | |||||
| { | |||||
| bool is_trainable; | |||||
| if (trainable is null) | |||||
| { | |||||
| is_trainable = weight.Trainable; | |||||
| } | |||||
| else | |||||
| { | |||||
| is_trainable = trainable.Value; | |||||
| } | |||||
| add_weight(weight.Name, weight.shape, weight.dtype, trainable: is_trainable, getter: x => weight); | |||||
| } | |||||
| private Func<Tensors, Tensors> _get_callable() | |||||
| { | |||||
| if (_func is Layer layer) | |||||
| { | |||||
| return x => layer.Apply(x); | |||||
| } | |||||
| if (_func.CustomizedFields.ContainsKey("__call__")) | |||||
| { | |||||
| if (_func.CustomizedFields["__call__"] is RestoredFunction function) | |||||
| { | |||||
| return x => function.Apply(x); | |||||
| } | |||||
| } | |||||
| throw new ValueError("Cannot get the callable from the model."); | |||||
| } | |||||
| private static Trackable load_module(string handle, LoadOptions? load_options = null) | |||||
| { | |||||
| //var set_load_options = load_options ?? LoadContext.get_load_option(); | |||||
| return module_v2.load(handle, load_options); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,17 @@ | |||||
| <Project Sdk="Microsoft.NET.Sdk"> | |||||
| <PropertyGroup> | |||||
| <TargetFrameworks>netstandard2.0;net6;net7</TargetFrameworks> | |||||
| <LangVersion>11</LangVersion> | |||||
| <Nullable>enable</Nullable> | |||||
| </PropertyGroup> | |||||
| <ItemGroup> | |||||
| <PackageReference Include="SharpCompress" Version="0.33.0" /> | |||||
| </ItemGroup> | |||||
| <ItemGroup> | |||||
| <ProjectReference Include="..\\TensorFlowNET.Keras\Tensorflow.Keras.csproj" /> | |||||
| </ItemGroup> | |||||
| </Project> | |||||
| @@ -0,0 +1,74 @@ | |||||
| using SharpCompress.Common; | |||||
| using SharpCompress.Readers; | |||||
| using System; | |||||
| using System.IO; | |||||
| namespace Tensorflow.Hub | |||||
| { | |||||
| internal static class file_utils | |||||
| { | |||||
| //public static void extract_file(TarInputStream tgz, TarEntry tarInfo, string dstPath, uint bufferSize = 10 << 20, Action<long> logFunction = null) | |||||
| //{ | |||||
| // using (var src = tgz.GetNextEntry() == tarInfo ? tgz : null) | |||||
| // { | |||||
| // if (src is null) | |||||
| // { | |||||
| // return; | |||||
| // } | |||||
| // using (var dst = File.Create(dstPath)) | |||||
| // { | |||||
| // var buffer = new byte[bufferSize]; | |||||
| // int count; | |||||
| // while ((count = src.Read(buffer, 0, buffer.Length)) > 0) | |||||
| // { | |||||
| // dst.Write(buffer, 0, count); | |||||
| // logFunction?.Invoke(count); | |||||
| // } | |||||
| // } | |||||
| // } | |||||
| //} | |||||
| public static void extract_tarfile_to_destination(Stream fileobj, string dst_path, Action<long> logFunction = null) | |||||
| { | |||||
| using (IReader reader = ReaderFactory.Open(fileobj)) | |||||
| { | |||||
| while (reader.MoveToNextEntry()) | |||||
| { | |||||
| if (!reader.Entry.IsDirectory) | |||||
| { | |||||
| reader.WriteEntryToDirectory( | |||||
| dst_path, | |||||
| new ExtractionOptions() { ExtractFullPath = true, Overwrite = true } | |||||
| ); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| public static string merge_relative_path(string dstPath, string relPath) | |||||
| { | |||||
| var cleanRelPath = Path.GetFullPath(relPath).TrimStart('/', '\\'); | |||||
| if (cleanRelPath == ".") | |||||
| { | |||||
| return dstPath; | |||||
| } | |||||
| if (cleanRelPath.StartsWith("..") || Path.IsPathRooted(cleanRelPath)) | |||||
| { | |||||
| throw new InvalidDataException($"Relative path '{relPath}' is invalid."); | |||||
| } | |||||
| var merged = Path.Combine(dstPath, cleanRelPath); | |||||
| if (!merged.StartsWith(dstPath)) | |||||
| { | |||||
| throw new InvalidDataException($"Relative path '{relPath}' is invalid. Failed to merge with '{dstPath}'."); | |||||
| } | |||||
| return merged; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,17 @@ | |||||
| using Tensorflow.Hub; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public static class HubAPI | |||||
| { | |||||
| public static HubMethods hub { get; } = new HubMethods(); | |||||
| } | |||||
| public class HubMethods | |||||
| { | |||||
| public KerasLayer KerasLayer(string handle, bool trainable = false, LoadOptions? load_options = null) | |||||
| { | |||||
| return new KerasLayer(handle, trainable, load_options); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,33 @@ | |||||
| using System.IO; | |||||
| using Tensorflow.Train; | |||||
| namespace Tensorflow.Hub | |||||
| { | |||||
| internal static class module_v2 | |||||
| { | |||||
| public static Trackable load(string handle, LoadOptions? options) | |||||
| { | |||||
| var module_path = resolve(handle); | |||||
| // TODO(Rinne): deal with is_hub_module_v1 | |||||
| var saved_model_path = Path.Combine(module_path, Constants.SAVED_MODEL_FILENAME_PB); | |||||
| var saved_model_pb_txt_path = Path.Combine(module_path, Constants.SAVED_MODEL_FILENAME_PBTXT); | |||||
| if (!File.Exists(saved_model_path) && !Directory.Exists(saved_model_path) && !File.Exists(saved_model_pb_txt_path) | |||||
| && !Directory.Exists(saved_model_pb_txt_path)) | |||||
| { | |||||
| throw new ValueError($"Trying to load a model of incompatible/unknown type. " + | |||||
| $"'{module_path}' contains neither '{Constants.SAVED_MODEL_FILENAME_PB}' " + | |||||
| $"nor '{Constants.SAVED_MODEL_FILENAME_PBTXT}'."); | |||||
| } | |||||
| var obj = Loader.load(module_path, options: options); | |||||
| return obj; | |||||
| } | |||||
| public static string resolve(string handle) | |||||
| { | |||||
| return MultiImplRegister.GetResolverRegister().Call(handle); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,55 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| namespace Tensorflow.Hub | |||||
| { | |||||
| internal class MultiImplRegister | |||||
| { | |||||
| private static MultiImplRegister resolver = new MultiImplRegister("resolver", new IResolver[0]); | |||||
| private static MultiImplRegister loader = new MultiImplRegister("loader", new IResolver[0]); | |||||
| static MultiImplRegister() | |||||
| { | |||||
| resolver.add_implementation(new PathResolver()); | |||||
| resolver.add_implementation(new HttpUncompressedFileResolver()); | |||||
| resolver.add_implementation(new GcsCompressedFileResolver()); | |||||
| resolver.add_implementation(new HttpCompressedFileResolver()); | |||||
| } | |||||
| string _name; | |||||
| List<IResolver> _impls; | |||||
| public MultiImplRegister(string name, IEnumerable<IResolver> impls) | |||||
| { | |||||
| _name = name; | |||||
| _impls = impls.ToList(); | |||||
| } | |||||
| public void add_implementation(IResolver resolver) | |||||
| { | |||||
| _impls.Add(resolver); | |||||
| } | |||||
| public string Call(string handle) | |||||
| { | |||||
| foreach (var impl in _impls.Reverse<IResolver>()) | |||||
| { | |||||
| if (impl.IsSupported(handle)) | |||||
| { | |||||
| return impl.Call(handle); | |||||
| } | |||||
| } | |||||
| throw new RuntimeError($"Cannot resolve the handle {handle}"); | |||||
| } | |||||
| public static MultiImplRegister GetResolverRegister() | |||||
| { | |||||
| return resolver; | |||||
| } | |||||
| public static MultiImplRegister GetLoaderRegister() | |||||
| { | |||||
| return loader; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,580 @@ | |||||
| using ICSharpCode.SharpZipLib.Tar; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.ComponentModel; | |||||
| using System.Diagnostics; | |||||
| using System.IO; | |||||
| using System.Linq; | |||||
| using System.Net; | |||||
| using System.Net.Http; | |||||
| using System.Net.Security; | |||||
| using System.Security.Authentication; | |||||
| using System.Threading.Tasks; | |||||
| using System.Web; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Hub | |||||
| { | |||||
| internal static class resolver | |||||
| { | |||||
| public enum ModelLoadFormat | |||||
| { | |||||
| [Description("COMPRESSED")] | |||||
| COMPRESSED, | |||||
| [Description("UNCOMPRESSED")] | |||||
| UNCOMPRESSED, | |||||
| [Description("AUTO")] | |||||
| AUTO | |||||
| } | |||||
| public class DownloadManager | |||||
| { | |||||
| private readonly string _url; | |||||
| private double _last_progress_msg_print_time; | |||||
| private long _total_bytes_downloaded; | |||||
| private int _max_prog_str; | |||||
| private bool _interactive_mode() | |||||
| { | |||||
| return !string.IsNullOrEmpty(Environment.GetEnvironmentVariable("_TFHUB_DOWNLOAD_PROGRESS")); | |||||
| } | |||||
| private void _print_download_progress_msg(string msg, bool flush = false) | |||||
| { | |||||
| if (_interactive_mode()) | |||||
| { | |||||
| // Print progress message to console overwriting previous progress | |||||
| // message. | |||||
| _max_prog_str = Math.Max(_max_prog_str, msg.Length); | |||||
| Console.Write($"\r{msg.PadRight(_max_prog_str)}"); | |||||
| Console.Out.Flush(); | |||||
| //如果flush参数为true,则输出换行符减少干扰交互式界面。 | |||||
| if (flush) | |||||
| Console.WriteLine(); | |||||
| } | |||||
| else | |||||
| { | |||||
| // Interactive progress tracking is disabled. Print progress to the | |||||
| // standard TF log. | |||||
| tf.Logger.Information(msg); | |||||
| } | |||||
| } | |||||
| private void _log_progress(long bytes_downloaded) | |||||
| { | |||||
| // Logs progress information about ongoing module download. | |||||
| _total_bytes_downloaded += bytes_downloaded; | |||||
| var now = DateTime.Now.Ticks / TimeSpan.TicksPerSecond; | |||||
| if (_interactive_mode() || now - _last_progress_msg_print_time > 15) | |||||
| { | |||||
| // Print progress message every 15 secs or if interactive progress | |||||
| // tracking is enabled. | |||||
| _print_download_progress_msg($"Downloading {_url}:" + | |||||
| $"{tf_utils.bytes_to_readable_str(_total_bytes_downloaded, true)}"); | |||||
| _last_progress_msg_print_time = now; | |||||
| } | |||||
| } | |||||
| public DownloadManager(string url) | |||||
| { | |||||
| _url = url; | |||||
| _last_progress_msg_print_time = DateTime.Now.Ticks / TimeSpan.TicksPerSecond; | |||||
| _total_bytes_downloaded = 0; | |||||
| _max_prog_str = 0; | |||||
| } | |||||
| public void download_and_uncompress(Stream fileobj, string dst_path) | |||||
| { | |||||
| // Streams the content for the 'fileobj' and stores the result in dst_path. | |||||
| try | |||||
| { | |||||
| file_utils.extract_tarfile_to_destination(fileobj, dst_path, _log_progress); | |||||
| var total_size_str = tf_utils.bytes_to_readable_str(_total_bytes_downloaded, true); | |||||
| _print_download_progress_msg($"Downloaded {_url}, Total size: {total_size_str}", flush: true); | |||||
| } | |||||
| catch (TarException ex) | |||||
| { | |||||
| throw new IOException($"{_url} does not appear to be a valid module. Inner message:{ex.Message}", ex); | |||||
| } | |||||
| } | |||||
| } | |||||
| private static Dictionary<string, string> _flags = new(); | |||||
| private static readonly string _TFHUB_CACHE_DIR = "TFHUB_CACHE_DIR"; | |||||
| private static readonly string _TFHUB_DOWNLOAD_PROGRESS = "TFHUB_DOWNLOAD_PROGRESS"; | |||||
| private static readonly string _TFHUB_MODEL_LOAD_FORMAT = "TFHUB_MODEL_LOAD_FORMAT"; | |||||
| private static readonly string _TFHUB_DISABLE_CERT_VALIDATION = "TFHUB_DISABLE_CERT_VALIDATION"; | |||||
| private static readonly string _TFHUB_DISABLE_CERT_VALIDATION_VALUE = "true"; | |||||
| static resolver() | |||||
| { | |||||
| set_new_flag("tfhub_model_load_format", "AUTO"); | |||||
| set_new_flag("tfhub_cache_dir", null); | |||||
| } | |||||
| public static string model_load_format() | |||||
| { | |||||
| return get_env_setting(_TFHUB_MODEL_LOAD_FORMAT, "tfhub_model_load_format"); | |||||
| } | |||||
| public static string? get_env_setting(string env_var, string flag_name) | |||||
| { | |||||
| string value = System.Environment.GetEnvironmentVariable(env_var); | |||||
| if (string.IsNullOrEmpty(value)) | |||||
| { | |||||
| if (_flags.ContainsKey(flag_name)) | |||||
| { | |||||
| return _flags[flag_name]; | |||||
| } | |||||
| else | |||||
| { | |||||
| return null; | |||||
| } | |||||
| } | |||||
| else | |||||
| { | |||||
| return value; | |||||
| } | |||||
| } | |||||
| public static string tfhub_cache_dir(string default_cache_dir = null, bool use_temp = false) | |||||
| { | |||||
| var cache_dir = get_env_setting(_TFHUB_CACHE_DIR, "tfhub_cache_dir") ?? default_cache_dir; | |||||
| if (string.IsNullOrWhiteSpace(cache_dir) && use_temp) | |||||
| { | |||||
| // Place all TF-Hub modules under <system's temp>/tfhub_modules. | |||||
| cache_dir = Path.Combine(Path.GetTempPath(), "tfhub_modules"); | |||||
| } | |||||
| if (!string.IsNullOrWhiteSpace(cache_dir)) | |||||
| { | |||||
| Console.WriteLine("Using {0} to cache modules.", cache_dir); | |||||
| } | |||||
| return cache_dir; | |||||
| } | |||||
| public static string create_local_module_dir(string cache_dir, string module_name) | |||||
| { | |||||
| Directory.CreateDirectory(cache_dir); | |||||
| return Path.Combine(cache_dir, module_name); | |||||
| } | |||||
| public static void set_new_flag(string name, string value) | |||||
| { | |||||
| string[] tokens = new string[] {_TFHUB_CACHE_DIR, _TFHUB_DISABLE_CERT_VALIDATION, | |||||
| _TFHUB_DISABLE_CERT_VALIDATION_VALUE, _TFHUB_DOWNLOAD_PROGRESS, _TFHUB_MODEL_LOAD_FORMAT}; | |||||
| if (!tokens.Contains(name)) | |||||
| { | |||||
| tf.Logger.Warning($"You are settinng a flag '{name}' that cannot be recognized. The flag you set" + | |||||
| "may not affect anything in tensorflow.hub."); | |||||
| } | |||||
| _flags[name] = value; | |||||
| } | |||||
| public static string _merge_relative_path(string dstPath, string relPath) | |||||
| { | |||||
| return file_utils.merge_relative_path(dstPath, relPath); | |||||
| } | |||||
| public static string _module_descriptor_file(string moduleDir) | |||||
| { | |||||
| return $"{moduleDir}.descriptor.txt"; | |||||
| } | |||||
| public static void _write_module_descriptor_file(string handle, string moduleDir) | |||||
| { | |||||
| var readme = _module_descriptor_file(moduleDir); | |||||
| var content = $"Module: {handle}\nDownload Time: {DateTime.Now}\nDownloader Hostname: {Environment.MachineName} (PID:{Process.GetCurrentProcess().Id})"; | |||||
| tf_utils.atomic_write_string_to_file(readme, content, overwrite: true); | |||||
| } | |||||
| public static string _lock_file_contents(string task_uid) | |||||
| { | |||||
| return $"{Environment.MachineName}.{Process.GetCurrentProcess().Id}.{task_uid}"; | |||||
| } | |||||
| public static string _lock_filename(string moduleDir) | |||||
| { | |||||
| return tf_utils.absolute_path(moduleDir) + ".lock"; | |||||
| } | |||||
| private static string _module_dir(string lockFilename) | |||||
| { | |||||
| var path = Path.GetDirectoryName(Path.GetFullPath(lockFilename)); | |||||
| if (!string.IsNullOrEmpty(path)) | |||||
| { | |||||
| return Path.Combine(path, "hub_modules"); | |||||
| } | |||||
| throw new Exception("Unable to resolve hub_modules directory from lock file name."); | |||||
| } | |||||
| private static string _task_uid_from_lock_file(string lockFilename) | |||||
| { | |||||
| // Returns task UID of the task that created a given lock file. | |||||
| var lockstring = File.ReadAllText(lockFilename); | |||||
| return lockstring.Split('.').Last(); | |||||
| } | |||||
| private static string _temp_download_dir(string moduleDir, string taskUid) | |||||
| { | |||||
| // Returns the name of a temporary directory to download module to. | |||||
| return $"{Path.GetFullPath(moduleDir)}.{taskUid}.tmp"; | |||||
| } | |||||
| private static long _dir_size(string directory) | |||||
| { | |||||
| // Returns total size (in bytes) of the given 'directory'. | |||||
| long size = 0; | |||||
| foreach (var elem in Directory.EnumerateFileSystemEntries(directory)) | |||||
| { | |||||
| var stat = new FileInfo(elem); | |||||
| size += stat.Length; | |||||
| if ((stat.Attributes & FileAttributes.Directory) != 0) | |||||
| size += _dir_size(stat.FullName); | |||||
| } | |||||
| return size; | |||||
| } | |||||
| public static long _locked_tmp_dir_size(string lockFilename) | |||||
| { | |||||
| //Returns the size of the temp dir pointed to by the given lock file. | |||||
| var taskUid = _task_uid_from_lock_file(lockFilename); | |||||
| try | |||||
| { | |||||
| return _dir_size(_temp_download_dir(_module_dir(lockFilename), taskUid)); | |||||
| } | |||||
| catch (DirectoryNotFoundException) | |||||
| { | |||||
| return 0; | |||||
| } | |||||
| } | |||||
| private static void _wait_for_lock_to_disappear(string handle, string lockFile, double lockFileTimeoutSec) | |||||
| { | |||||
| long? lockedTmpDirSize = null; | |||||
| var lockedTmpDirSizeCheckTime = DateTime.Now; | |||||
| var lockFileContent = ""; | |||||
| while (File.Exists(lockFile)) | |||||
| { | |||||
| try | |||||
| { | |||||
| Console.WriteLine($"Module '{handle}' already being downloaded by '{File.ReadAllText(lockFile)}'. Waiting."); | |||||
| if ((DateTime.Now - lockedTmpDirSizeCheckTime).TotalSeconds > lockFileTimeoutSec) | |||||
| { | |||||
| var curLockedTmpDirSize = _locked_tmp_dir_size(lockFile); | |||||
| var curLockFileContent = File.ReadAllText(lockFile); | |||||
| if (curLockedTmpDirSize == lockedTmpDirSize && curLockFileContent == lockFileContent) | |||||
| { | |||||
| Console.WriteLine($"Deleting lock file {lockFile} due to inactivity."); | |||||
| File.Delete(lockFile); | |||||
| break; | |||||
| } | |||||
| lockedTmpDirSize = curLockedTmpDirSize; | |||||
| lockedTmpDirSizeCheckTime = DateTime.Now; | |||||
| lockFileContent = curLockFileContent; | |||||
| } | |||||
| } | |||||
| catch (FileNotFoundException) | |||||
| { | |||||
| // Lock file or temp directory were deleted during check. Continue | |||||
| // to check whether download succeeded or we need to start our own | |||||
| // download. | |||||
| } | |||||
| System.Threading.Thread.Sleep(5000); | |||||
| } | |||||
| } | |||||
| public static async Task<string> atomic_download_async( | |||||
| string handle, | |||||
| Func<string, string, Task> downloadFn, | |||||
| string moduleDir, | |||||
| int lock_file_timeout_sec = 10 * 60) | |||||
| { | |||||
| var lockFile = _lock_filename(moduleDir); | |||||
| var taskUid = Guid.NewGuid().ToString("N"); | |||||
| var lockContents = _lock_file_contents(taskUid); | |||||
| var tmpDir = _temp_download_dir(moduleDir, taskUid); | |||||
| // Function to check whether model has already been downloaded. | |||||
| Func<bool> checkModuleExists = () => | |||||
| Directory.Exists(moduleDir) && | |||||
| Directory.EnumerateFileSystemEntries(moduleDir).Any(); | |||||
| // Check whether the model has already been downloaded before locking | |||||
| // the destination path. | |||||
| if (checkModuleExists()) | |||||
| { | |||||
| return moduleDir; | |||||
| } | |||||
| // Attempt to protect against cases of processes being cancelled with | |||||
| // KeyboardInterrupt by using a try/finally clause to remove the lock | |||||
| // and tmp_dir. | |||||
| while (true) | |||||
| { | |||||
| try | |||||
| { | |||||
| tf_utils.atomic_write_string_to_file(lockFile, lockContents, false); | |||||
| // Must test condition again, since another process could have created | |||||
| // the module and deleted the old lock file since last test. | |||||
| if (checkModuleExists()) | |||||
| { | |||||
| // Lock file will be deleted in the finally-clause. | |||||
| return moduleDir; | |||||
| } | |||||
| if (Directory.Exists(moduleDir)) | |||||
| { | |||||
| Directory.Delete(moduleDir, true); | |||||
| } | |||||
| break; // Proceed to downloading the module. | |||||
| } | |||||
| // These errors are believed to be permanent problems with the | |||||
| // module_dir that justify failing the download. | |||||
| catch (FileNotFoundException) | |||||
| { | |||||
| throw; | |||||
| } | |||||
| catch (UnauthorizedAccessException) | |||||
| { | |||||
| throw; | |||||
| } | |||||
| catch (IOException) | |||||
| { | |||||
| throw; | |||||
| } | |||||
| // All other errors are retried. | |||||
| // TODO(b/144424849): Retrying an AlreadyExistsError from the atomic write | |||||
| // should be good enough, but see discussion about misc filesystem types. | |||||
| // TODO(b/144475403): How atomic is the overwrite=False check? | |||||
| catch (Exception) | |||||
| { | |||||
| } | |||||
| // Wait for lock file to disappear. | |||||
| _wait_for_lock_to_disappear(handle, lockFile, lock_file_timeout_sec); | |||||
| // At this point we either deleted a lock or a lock got removed by the | |||||
| // owner or another process. Perform one more iteration of the while-loop, | |||||
| // we would either terminate due tf.compat.v1.gfile.Exists(module_dir) or | |||||
| // because we would obtain a lock ourselves, or wait again for the lock to | |||||
| // disappear. | |||||
| } | |||||
| // Lock file acquired. | |||||
| tf.Logger.Information($"Downloading TF-Hub Module '{handle}'..."); | |||||
| Directory.CreateDirectory(tmpDir); | |||||
| await downloadFn(handle, tmpDir); | |||||
| // Write module descriptor to capture information about which module was | |||||
| // downloaded by whom and when. The file stored at the same level as a | |||||
| // directory in order to keep the content of the 'model_dir' exactly as it | |||||
| // was define by the module publisher. | |||||
| // | |||||
| // Note: The descriptor is written purely to help the end-user to identify | |||||
| // which directory belongs to which module. The descriptor is not part of the | |||||
| // module caching protocol and no code in the TF-Hub library reads its | |||||
| // content. | |||||
| _write_module_descriptor_file(handle, moduleDir); | |||||
| try | |||||
| { | |||||
| Directory.Move(tmpDir, moduleDir); | |||||
| Console.WriteLine($"Downloaded TF-Hub Module '{handle}'."); | |||||
| } | |||||
| catch (IOException e) | |||||
| { | |||||
| Console.WriteLine(e.Message); | |||||
| Console.WriteLine($"Failed to move {tmpDir} to {moduleDir}"); | |||||
| // Keep the temp directory so we will retry building vocabulary later. | |||||
| } | |||||
| // Temp directory is owned by the current process, remove it. | |||||
| try | |||||
| { | |||||
| Directory.Delete(tmpDir, true); | |||||
| } | |||||
| catch (DirectoryNotFoundException) | |||||
| { | |||||
| } | |||||
| // Lock file exists and is owned by this process. | |||||
| try | |||||
| { | |||||
| var contents = File.ReadAllText(lockFile); | |||||
| if (contents == lockContents) | |||||
| { | |||||
| File.Delete(lockFile); | |||||
| } | |||||
| } | |||||
| catch (Exception) | |||||
| { | |||||
| } | |||||
| return moduleDir; | |||||
| } | |||||
| } | |||||
| internal interface IResolver | |||||
| { | |||||
| string Call(string handle); | |||||
| bool IsSupported(string handle); | |||||
| } | |||||
| internal class PathResolver : IResolver | |||||
| { | |||||
| public string Call(string handle) | |||||
| { | |||||
| if (!File.Exists(handle) && !Directory.Exists(handle)) | |||||
| { | |||||
| throw new IOException($"{handle} does not exist in file system."); | |||||
| } | |||||
| return handle; | |||||
| } | |||||
| public bool IsSupported(string handle) | |||||
| { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| public abstract class HttpResolverBase : IResolver | |||||
| { | |||||
| private readonly HttpClient httpClient; | |||||
| private SslProtocol sslProtocol; | |||||
| private RemoteCertificateValidationCallback certificateValidator; | |||||
| protected HttpResolverBase() | |||||
| { | |||||
| httpClient = new HttpClient(); | |||||
| _maybe_disable_cert_validation(); | |||||
| } | |||||
| public abstract string Call(string handle); | |||||
| public abstract bool IsSupported(string handle); | |||||
| protected async Task<Stream> GetLocalFileStreamAsync(string filePath) | |||||
| { | |||||
| try | |||||
| { | |||||
| var fs = new FileStream(filePath, FileMode.Open, FileAccess.Read); | |||||
| return await Task.FromResult(fs); | |||||
| } | |||||
| catch (Exception ex) | |||||
| { | |||||
| Console.WriteLine($"Failed to read file stream: {ex.Message}"); | |||||
| return null; | |||||
| } | |||||
| } | |||||
| protected async Task<Stream> GetFileStreamAsync(string filePath) | |||||
| { | |||||
| if (!is_http_protocol(filePath)) | |||||
| { | |||||
| // If filePath is not an HTTP(S) URL, delegate to a file resolver. | |||||
| return await GetLocalFileStreamAsync(filePath); | |||||
| } | |||||
| var request = new HttpRequestMessage(HttpMethod.Get, filePath); | |||||
| var response = await _call_urlopen(request); | |||||
| if (response.IsSuccessStatusCode) | |||||
| { | |||||
| return await response.Content.ReadAsStreamAsync(); | |||||
| } | |||||
| else | |||||
| { | |||||
| Console.WriteLine($"Failed to fetch file stream: {response.StatusCode} - {response.ReasonPhrase}"); | |||||
| return null; | |||||
| } | |||||
| } | |||||
| protected void SetUrlContext(SslProtocol protocol, RemoteCertificateValidationCallback validator) | |||||
| { | |||||
| sslProtocol = protocol; | |||||
| certificateValidator = validator; | |||||
| } | |||||
| public static string append_format_query(string handle, (string, string) formatQuery) | |||||
| { | |||||
| var parsed = new Uri(handle); | |||||
| var queryBuilder = HttpUtility.ParseQueryString(parsed.Query); | |||||
| queryBuilder.Add(formatQuery.Item1, formatQuery.Item2); | |||||
| parsed = new UriBuilder(parsed.Scheme, parsed.Host, parsed.Port, parsed.AbsolutePath, | |||||
| "?" + queryBuilder.ToString()).Uri; | |||||
| return parsed.ToString(); | |||||
| } | |||||
| protected bool is_http_protocol(string handle) | |||||
| { | |||||
| return handle.StartsWith("http://") || handle.StartsWith("https://"); | |||||
| } | |||||
| protected async Task<HttpResponseMessage> _call_urlopen(HttpRequestMessage request) | |||||
| { | |||||
| if (sslProtocol != null) | |||||
| { | |||||
| var handler = new HttpClientHandler() | |||||
| { | |||||
| SslProtocols = sslProtocol.AsEnum(), | |||||
| }; | |||||
| if (certificateValidator != null) | |||||
| { | |||||
| handler.ServerCertificateCustomValidationCallback = (x, y, z, w) => | |||||
| { | |||||
| return certificateValidator(x, y, z, w); | |||||
| }; | |||||
| } | |||||
| var client = new HttpClient(handler); | |||||
| return await client.SendAsync(request); | |||||
| } | |||||
| else | |||||
| { | |||||
| return await httpClient.SendAsync(request); | |||||
| } | |||||
| } | |||||
| protected void _maybe_disable_cert_validation() | |||||
| { | |||||
| if (Environment.GetEnvironmentVariable("_TFHUB_DISABLE_CERT_VALIDATION") == "_TFHUB_DISABLE_CERT_VALIDATION_VALUE") | |||||
| { | |||||
| ServicePointManager.ServerCertificateValidationCallback = (_, _, _, _) => true; | |||||
| Console.WriteLine("Disabled certificate validation for resolving handles."); | |||||
| } | |||||
| } | |||||
| } | |||||
| public class SslProtocol | |||||
| { | |||||
| private readonly string protocolString; | |||||
| public static readonly SslProtocol Tls = new SslProtocol("TLS"); | |||||
| public static readonly SslProtocol Tls11 = new SslProtocol("TLS 1.1"); | |||||
| public static readonly SslProtocol Tls12 = new SslProtocol("TLS 1.2"); | |||||
| private SslProtocol(string protocolString) | |||||
| { | |||||
| this.protocolString = protocolString; | |||||
| } | |||||
| public SslProtocols AsEnum() | |||||
| { | |||||
| switch (protocolString.ToUpper()) | |||||
| { | |||||
| case "TLS": | |||||
| return SslProtocols.Tls; | |||||
| case "TLS 1.1": | |||||
| return SslProtocols.Tls11; | |||||
| case "TLS 1.2": | |||||
| return SslProtocols.Tls12; | |||||
| default: | |||||
| throw new ArgumentException($"Unknown SSL/TLS protocol: {protocolString}"); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,80 @@ | |||||
| using System; | |||||
| using System.IO; | |||||
| namespace Tensorflow.Hub | |||||
| { | |||||
| internal class tf_utils | |||||
| { | |||||
| public static string bytes_to_readable_str(long? numBytes, bool includeB = false) | |||||
| { | |||||
| if (numBytes == null) return numBytes.ToString(); | |||||
| var num = (double)numBytes; | |||||
| if (num < 1024) | |||||
| { | |||||
| return $"{(long)num}{(includeB ? "B" : "")}"; | |||||
| } | |||||
| num /= 1 << 10; | |||||
| if (num < 1024) | |||||
| { | |||||
| return $"{num:F2}k{(includeB ? "B" : "")}"; | |||||
| } | |||||
| num /= 1 << 10; | |||||
| if (num < 1024) | |||||
| { | |||||
| return $"{num:F2}M{(includeB ? "B" : "")}"; | |||||
| } | |||||
| num /= 1 << 10; | |||||
| return $"{num:F2}G{(includeB ? "B" : "")}"; | |||||
| } | |||||
| public static void atomic_write_string_to_file(string filename, string contents, bool overwrite) | |||||
| { | |||||
| var tempPath = $"{filename}.tmp.{Guid.NewGuid():N}"; | |||||
| using (var fileStream = new FileStream(tempPath, FileMode.Create)) | |||||
| { | |||||
| using (var writer = new StreamWriter(fileStream)) | |||||
| { | |||||
| writer.Write(contents); | |||||
| writer.Flush(); | |||||
| } | |||||
| } | |||||
| try | |||||
| { | |||||
| if (File.Exists(filename)) | |||||
| { | |||||
| if (overwrite) | |||||
| { | |||||
| File.Delete(filename); | |||||
| File.Move(tempPath, filename); | |||||
| } | |||||
| } | |||||
| else | |||||
| { | |||||
| File.Move(tempPath, filename); | |||||
| } | |||||
| } | |||||
| catch | |||||
| { | |||||
| File.Delete(tempPath); | |||||
| throw; | |||||
| } | |||||
| } | |||||
| public static string absolute_path(string path) | |||||
| { | |||||
| if (path.Contains("://")) | |||||
| { | |||||
| return path; | |||||
| } | |||||
| return Path.GetFullPath(path); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,46 @@ | |||||
| using static Tensorflow.Binding; | |||||
| using static Tensorflow.HubAPI; | |||||
| namespace Tensorflow.Hub.Unittest | |||||
| { | |||||
| [TestClass] | |||||
| public class KerasLayerTest | |||||
| { | |||||
| [TestMethod] | |||||
| public void SmallBert() | |||||
| { | |||||
| var layer = hub.KerasLayer("https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-256_A-4/1"); | |||||
| var input_type_ids = tf.convert_to_tensor(new int[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, dtype: tf.int32); | |||||
| input_type_ids = tf.reshape(input_type_ids, (1, 128)); | |||||
| var input_word_ids = tf.convert_to_tensor(new int[] { 101, 2129, 2024, 2017, 102, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0 }, dtype: tf.int32); | |||||
| input_word_ids = tf.reshape(input_word_ids, (1, 128)); | |||||
| var input_mask = tf.convert_to_tensor(new int[] { 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, dtype: dtypes.int32); | |||||
| input_mask = tf.reshape(input_mask, (1, 128)); | |||||
| var result = layer.Apply(new Tensors(input_type_ids, input_word_ids, input_mask)); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,23 @@ | |||||
| <Project Sdk="Microsoft.NET.Sdk"> | |||||
| <PropertyGroup> | |||||
| <TargetFramework>net7</TargetFramework> | |||||
| <ImplicitUsings>enable</ImplicitUsings> | |||||
| <Nullable>enable</Nullable> | |||||
| <IsPackable>false</IsPackable> | |||||
| </PropertyGroup> | |||||
| <ItemGroup> | |||||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.3.2" /> | |||||
| <PackageReference Include="MSTest.TestAdapter" Version="2.2.10" /> | |||||
| <PackageReference Include="MSTest.TestFramework" Version="2.2.10" /> | |||||
| <PackageReference Include="coverlet.collector" Version="3.1.2" /> | |||||
| <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.11.2" /> | |||||
| </ItemGroup> | |||||
| <ItemGroup> | |||||
| <ProjectReference Include="..\..\src\TensorflowNET.Hub\Tensorflow.Hub.csproj" /> | |||||
| </ItemGroup> | |||||
| </Project> | |||||
| @@ -0,0 +1 @@ | |||||
| global using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||