| @@ -23,6 +23,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras.UnitTest", | |||
| EndProject | |||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Graph.UnitTest", "test\TensorFlowNET.Graph.UnitTest\TensorFlowNET.Graph.UnitTest.csproj", "{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}" | |||
| EndProject | |||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Hub", "src\TensorflowNET.Hub\Tensorflow.Hub.csproj", "{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}" | |||
| EndProject | |||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Hub.Unittest", "test\TensorflowNET.Hub.Unittest\Tensorflow.Hub.Unittest.csproj", "{7DEA8760-E401-4872-81F3-405F185A13A0}" | |||
| EndProject | |||
| Global | |||
| GlobalSection(SolutionConfigurationPlatforms) = preSolution | |||
| Debug|Any CPU = Debug|Any CPU | |||
| @@ -153,6 +157,30 @@ Global | |||
| {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.Build.0 = Release|x64 | |||
| {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.ActiveCfg = Release|Any CPU | |||
| {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.Build.0 = Release|Any CPU | |||
| {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||
| {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
| {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x64.ActiveCfg = Debug|Any CPU | |||
| {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x64.Build.0 = Debug|Any CPU | |||
| {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x86.ActiveCfg = Debug|Any CPU | |||
| {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x86.Build.0 = Debug|Any CPU | |||
| {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
| {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|Any CPU.Build.0 = Release|Any CPU | |||
| {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x64.ActiveCfg = Release|Any CPU | |||
| {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x64.Build.0 = Release|Any CPU | |||
| {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x86.ActiveCfg = Release|Any CPU | |||
| {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x86.Build.0 = Release|Any CPU | |||
| {7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||
| {7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
| {7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x64.ActiveCfg = Debug|Any CPU | |||
| {7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x64.Build.0 = Debug|Any CPU | |||
| {7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x86.ActiveCfg = Debug|Any CPU | |||
| {7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x86.Build.0 = Debug|Any CPU | |||
| {7DEA8760-E401-4872-81F3-405F185A13A0}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
| {7DEA8760-E401-4872-81F3-405F185A13A0}.Release|Any CPU.Build.0 = Release|Any CPU | |||
| {7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x64.ActiveCfg = Release|Any CPU | |||
| {7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x64.Build.0 = Release|Any CPU | |||
| {7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x86.ActiveCfg = Release|Any CPU | |||
| {7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x86.Build.0 = Release|Any CPU | |||
| EndGlobalSection | |||
| GlobalSection(SolutionProperties) = preSolution | |||
| HideSolutionNode = FALSE | |||
| @@ -207,9 +207,24 @@ namespace Tensorflow | |||
| } | |||
| public override string ToString() | |||
| => items.Count() == 1 | |||
| ? items.First().ToString() | |||
| : items.Count() + " Tensors" + ". " + string.Join(", ", items.Select(x => x.name)); | |||
| { | |||
| if(items.Count == 1) | |||
| { | |||
| return items[0].ToString(); | |||
| } | |||
| else | |||
| { | |||
| StringBuilder sb = new StringBuilder(); | |||
| sb.Append($"Totally {items.Count} tensors, which are {string.Join(", ", items.Select(x => x.name))}\n[\n"); | |||
| for(int i = 0; i < items.Count; i++) | |||
| { | |||
| var tensor = items[i]; | |||
| sb.Append($"Tensor {i}({tensor.name}): {tensor.ToString()}\n"); | |||
| } | |||
| sb.Append("]\n"); | |||
| return sb.ToString(); | |||
| } | |||
| } | |||
| public void Dispose() | |||
| { | |||
| @@ -301,6 +301,17 @@ namespace Tensorflow | |||
| type == TF_DataType.DtInt32Ref || type == TF_DataType.DtInt64Ref; | |||
| } | |||
| public static bool is_unsigned(this TF_DataType type) | |||
| { | |||
| return type == TF_DataType.TF_UINT8 || type == TF_DataType.TF_UINT16 || type == TF_DataType.TF_UINT32 || | |||
| type == TF_DataType.TF_UINT64; | |||
| } | |||
| public static bool is_bool(this TF_DataType type) | |||
| { | |||
| return type == TF_DataType.TF_BOOL; | |||
| } | |||
| public static bool is_floating(this TF_DataType type) | |||
| { | |||
| return type == TF_DataType.TF_HALF || type == TF_DataType.TF_FLOAT || type == TF_DataType.TF_DOUBLE; | |||
| @@ -22,9 +22,9 @@ namespace Tensorflow.Keras.Engine | |||
| // If dtype is DT_FLOAT, provide a uniform unit scaling initializer | |||
| if (dtype.is_floating()) | |||
| initializer = tf.glorot_uniform_initializer; | |||
| else if (dtype.is_integer()) | |||
| else if (dtype.is_integer() || dtype.is_unsigned() || dtype.is_bool()) | |||
| initializer = tf.zeros_initializer; | |||
| else | |||
| else if(getter is null) | |||
| throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {name}"); | |||
| } | |||
| @@ -36,5 +36,9 @@ namespace Tensorflow.Keras.Saving | |||
| public bool? Stateful { get; set; } | |||
| [JsonProperty("model_config")] | |||
| public KerasModelConfig? ModelConfig { get; set; } | |||
| [JsonProperty("sparse")] | |||
| public bool Sparse { get; set; } | |||
| [JsonProperty("ragged")] | |||
| public bool Ragged { get; set; } | |||
| } | |||
| } | |||
| @@ -401,13 +401,22 @@ namespace Tensorflow.Keras.Saving | |||
| private (Trackable, Action<object, object, object>) revive_custom_object(string identifier, KerasMetaData metadata) | |||
| { | |||
| if(identifier == SavedModel.Constants.LAYER_IDENTIFIER) | |||
| if (identifier == SavedModel.Constants.LAYER_IDENTIFIER) | |||
| { | |||
| return RevivedLayer.init_from_metadata(metadata); | |||
| } | |||
| else if(identifier == SavedModel.Constants.MODEL_IDENTIFIER || identifier == SavedModel.Constants.SEQUENTIAL_IDENTIFIER | |||
| || identifier == SavedModel.Constants.NETWORK_IDENTIFIER) | |||
| { | |||
| return RevivedNetwork.init_from_metadata(metadata); | |||
| } | |||
| else if(identifier == SavedModel.Constants.INPUT_LAYER_IDENTIFIER) | |||
| { | |||
| return RevivedInputLayer.init_from_metadata(metadata); | |||
| } | |||
| else | |||
| { | |||
| throw new NotImplementedException(); | |||
| throw new ValueError($"Cannot revive the layer {identifier}."); | |||
| } | |||
| } | |||
| @@ -1,15 +1,46 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Layers; | |||
| namespace Tensorflow.Keras.Saving.SavedModel | |||
| { | |||
| public class RevivedInputLayer: Layer | |||
| public class RevivedInputLayer: InputLayer | |||
| { | |||
| private RevivedInputLayer(): base(null) | |||
| protected RevivedConfig _config = null; | |||
| private RevivedInputLayer(InputLayerArgs args): base(args) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| public override IKerasConfig get_config() | |||
| { | |||
| return _config; | |||
| } | |||
| public static (RevivedInputLayer, Action<object, object, object>) init_from_metadata(KerasMetaData metadata) | |||
| { | |||
| InputLayerArgs args = new InputLayerArgs() | |||
| { | |||
| Name = metadata.Name, | |||
| DType = metadata.DType, | |||
| Sparse = metadata.Sparse, | |||
| Ragged = metadata.Ragged, | |||
| BatchInputShape = metadata.BatchInputShape | |||
| }; | |||
| RevivedInputLayer revived_obj = new RevivedInputLayer(args); | |||
| revived_obj._config = new RevivedConfig() { Config = metadata.Config }; | |||
| return (revived_obj, Loader.setattr); | |||
| } | |||
| public override string ToString() | |||
| { | |||
| return $"Customized keras input layer: {Name}."; | |||
| } | |||
| } | |||
| } | |||
| @@ -53,7 +53,7 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||
| return (revived_obj, ReviveUtils._revive_setter); | |||
| } | |||
| private RevivedConfig _config = null; | |||
| protected RevivedConfig _config = null; | |||
| public object keras_api | |||
| { | |||
| @@ -70,7 +70,7 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||
| } | |||
| } | |||
| public RevivedLayer(LayerArgs args): base(args) | |||
| protected RevivedLayer(LayerArgs args): base(args) | |||
| { | |||
| } | |||
| @@ -84,17 +84,5 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||
| { | |||
| return _config; | |||
| } | |||
| //protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| //{ | |||
| // if(SerializedAttributes is null || !SerializedAttributes.TryGetValue("__call__", out var func) || func is not Function) | |||
| // { | |||
| // return base.Call(inputs, state, training); | |||
| // } | |||
| // else | |||
| // { | |||
| // return (func as Function).Apply(inputs); | |||
| // } | |||
| //} | |||
| } | |||
| } | |||
| @@ -0,0 +1,40 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Utils; | |||
| namespace Tensorflow.Keras.Saving.SavedModel | |||
| { | |||
| public class RevivedNetwork: RevivedLayer | |||
| { | |||
| private RevivedNetwork(LayerArgs args) : base(args) | |||
| { | |||
| } | |||
| public static (RevivedNetwork, Action<object, object, object>) init_from_metadata(KerasMetaData metadata) | |||
| { | |||
| RevivedNetwork revived_obj = new(new LayerArgs() { Name = metadata.Name }); | |||
| // TODO(Rinne): with utils.no_automatic_dependency_tracking_scope(revived_obj) | |||
| // TODO(Rinne): revived_obj._expects_training_arg | |||
| var config = metadata.Config; | |||
| if (generic_utils.validate_config(config)) | |||
| { | |||
| revived_obj._config = new RevivedConfig() { Config = config }; | |||
| } | |||
| if(metadata.ActivityRegularizer is not null) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| return (revived_obj, ReviveUtils._revive_setter); | |||
| } | |||
| public override string ToString() | |||
| { | |||
| return $"Customized keras Network: {Name}."; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,57 @@ | |||
| using System.IO; | |||
| using System.Threading.Tasks; | |||
| namespace Tensorflow.Hub | |||
| { | |||
| public class GcsCompressedFileResolver : IResolver | |||
| { | |||
| const int LOCK_FILE_TIMEOUT_SEC = 10 * 60; | |||
| public string Call(string handle) | |||
| { | |||
| var module_dir = _module_dir(handle); | |||
| return resolver.atomic_download_async(handle, download, module_dir, LOCK_FILE_TIMEOUT_SEC) | |||
| .GetAwaiter().GetResult(); | |||
| } | |||
| public bool IsSupported(string handle) | |||
| { | |||
| return handle.StartsWith("gs://") && _is_tarfile(handle); | |||
| } | |||
| private async Task download(string handle, string tmp_dir) | |||
| { | |||
| new resolver.DownloadManager(handle).download_and_uncompress( | |||
| new FileStream(handle, FileMode.Open, FileAccess.Read), tmp_dir); | |||
| await Task.Run(() => { }); | |||
| } | |||
| private static string _module_dir(string handle) | |||
| { | |||
| var cache_dir = resolver.tfhub_cache_dir(use_temp: true); | |||
| var sha1 = ComputeSha1(handle); | |||
| return resolver.create_local_module_dir(cache_dir, sha1); | |||
| } | |||
| private static bool _is_tarfile(string filename) | |||
| { | |||
| return filename.EndsWith(".tar") || filename.EndsWith(".tar.gz") || filename.EndsWith(".tgz"); | |||
| } | |||
| private static string ComputeSha1(string s) | |||
| { | |||
| using (var sha = new System.Security.Cryptography.SHA1Managed()) | |||
| { | |||
| var bytes = System.Text.Encoding.UTF8.GetBytes(s); | |||
| var hash = sha.ComputeHash(bytes); | |||
| var stringBuilder = new System.Text.StringBuilder(hash.Length * 2); | |||
| foreach (var b in hash) | |||
| { | |||
| stringBuilder.Append(b.ToString("x2")); | |||
| } | |||
| return stringBuilder.ToString(); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,78 @@ | |||
| using System; | |||
| using System.Net.Http; | |||
| using System.Threading.Tasks; | |||
| namespace Tensorflow.Hub | |||
| { | |||
| public class HttpCompressedFileResolver : HttpResolverBase | |||
| { | |||
| const int LOCK_FILE_TIMEOUT_SEC = 10 * 60; // 10 minutes | |||
| private static readonly (string, string) _COMPRESSED_FORMAT_QUERY = | |||
| ("tf-hub-format", "compressed"); | |||
| private static string _module_dir(string handle) | |||
| { | |||
| var cache_dir = resolver.tfhub_cache_dir(use_temp: true); | |||
| var sha1 = ComputeSha1(handle); | |||
| return resolver.create_local_module_dir(cache_dir, sha1); | |||
| } | |||
| public override bool IsSupported(string handle) | |||
| { | |||
| if (!is_http_protocol(handle)) | |||
| { | |||
| return false; | |||
| } | |||
| var load_format = resolver.model_load_format(); | |||
| return load_format == Enum.GetName(typeof(resolver.ModelLoadFormat), resolver.ModelLoadFormat.COMPRESSED) | |||
| || load_format == Enum.GetName(typeof(resolver.ModelLoadFormat), resolver.ModelLoadFormat.AUTO); | |||
| } | |||
| public override string Call(string handle) | |||
| { | |||
| var module_dir = _module_dir(handle); | |||
| return resolver.atomic_download_async( | |||
| handle, | |||
| download, | |||
| module_dir, | |||
| LOCK_FILE_TIMEOUT_SEC | |||
| ).GetAwaiter().GetResult(); | |||
| } | |||
| private async Task download(string handle, string tmp_dir) | |||
| { | |||
| var client = new HttpClient(); | |||
| var response = await client.GetAsync(_append_compressed_format_query(handle)); | |||
| using (var httpStream = await response.Content.ReadAsStreamAsync()) | |||
| { | |||
| new resolver.DownloadManager(handle).download_and_uncompress(httpStream, tmp_dir); | |||
| } | |||
| } | |||
| private string _append_compressed_format_query(string handle) | |||
| { | |||
| return append_format_query(handle, _COMPRESSED_FORMAT_QUERY); | |||
| } | |||
| private static string ComputeSha1(string s) | |||
| { | |||
| using (var sha = new System.Security.Cryptography.SHA1Managed()) | |||
| { | |||
| var bytes = System.Text.Encoding.UTF8.GetBytes(s); | |||
| var hash = sha.ComputeHash(bytes); | |||
| var stringBuilder = new System.Text.StringBuilder(hash.Length * 2); | |||
| foreach (var b in hash) | |||
| { | |||
| stringBuilder.Append(b.ToString("x2")); | |||
| } | |||
| return stringBuilder.ToString(); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,65 @@ | |||
| using System; | |||
| using System.Net; | |||
| namespace Tensorflow.Hub | |||
| { | |||
| public class HttpUncompressedFileResolver : HttpResolverBase | |||
| { | |||
| private readonly PathResolver _pathResolver; | |||
| public HttpUncompressedFileResolver() | |||
| { | |||
| _pathResolver = new PathResolver(); | |||
| } | |||
| public override string Call(string handle) | |||
| { | |||
| handle = AppendUncompressedFormatQuery(handle); | |||
| var gsLocation = RequestGcsLocation(handle); | |||
| return _pathResolver.Call(gsLocation); | |||
| } | |||
| public override bool IsSupported(string handle) | |||
| { | |||
| if (!is_http_protocol(handle)) | |||
| { | |||
| return false; | |||
| } | |||
| var load_format = resolver.model_load_format(); | |||
| return load_format == Enum.GetName(typeof(resolver.ModelLoadFormat), resolver.ModelLoadFormat.UNCOMPRESSED); | |||
| } | |||
| protected virtual string AppendUncompressedFormatQuery(string handle) | |||
| { | |||
| return append_format_query(handle, ("tf-hub-format", "uncompressed")); | |||
| } | |||
| protected virtual string RequestGcsLocation(string handleWithParams) | |||
| { | |||
| var request = WebRequest.Create(handleWithParams); | |||
| var response = request.GetResponse() as HttpWebResponse; | |||
| if (response == null) | |||
| { | |||
| throw new Exception("Failed to get a response from the server."); | |||
| } | |||
| var statusCode = (int)response.StatusCode; | |||
| if (statusCode != 303) | |||
| { | |||
| throw new Exception($"Expected 303 for GCS location lookup but got HTTP {statusCode} {response.StatusDescription}"); | |||
| } | |||
| var location = response.Headers["Location"]; | |||
| if (!location.StartsWith("gs://")) | |||
| { | |||
| throw new Exception($"Expected Location:GS path but received {location}"); | |||
| } | |||
| return location; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,157 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training; | |||
| using Tensorflow.Training.Saving.SavedModel; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Hub | |||
| { | |||
| public class KerasLayer : Layer | |||
| { | |||
| private string _handle; | |||
| private LoadOptions? _load_options; | |||
| private Trackable _func; | |||
| private Func<Tensors, Tensors> _callable; | |||
| public KerasLayer(string handle, bool trainable = false, LoadOptions? load_options = null) : | |||
| base(new Keras.ArgsDefinition.LayerArgs() { Trainable = trainable }) | |||
| { | |||
| _handle = handle; | |||
| _load_options = load_options; | |||
| _func = load_module(_handle, _load_options); | |||
| _track_trackable(_func, "_func"); | |||
| // TODO(Rinne): deal with _is_hub_module_v1. | |||
| _callable = _get_callable(); | |||
| _setup_layer(trainable); | |||
| } | |||
| private void _setup_layer(bool trainable = false) | |||
| { | |||
| HashSet<string> trainable_variables; | |||
| if (_func is Layer layer) | |||
| { | |||
| foreach (var v in layer.TrainableVariables) | |||
| { | |||
| _add_existing_weight(v, true); | |||
| } | |||
| trainable_variables = new HashSet<string>(layer.TrainableVariables.Select(v => v.UniqueId)); | |||
| } | |||
| else if (_func.CustomizedFields.TryGetValue("trainable_variables", out var obj) && obj is IEnumerable<Trackable> trackables) | |||
| { | |||
| foreach (var trackable in trackables) | |||
| { | |||
| if (trackable is IVariableV1 v) | |||
| { | |||
| _add_existing_weight(v, true); | |||
| } | |||
| } | |||
| trainable_variables = new HashSet<string>(trackables.Where(t => t is IVariableV1).Select(t => (t as IVariableV1).UniqueId)); | |||
| } | |||
| else | |||
| { | |||
| trainable_variables = new HashSet<string>(); | |||
| } | |||
| if (_func is Layer) | |||
| { | |||
| layer = (Layer)_func; | |||
| foreach (var v in layer.Variables) | |||
| { | |||
| if (!trainable_variables.Contains(v.UniqueId)) | |||
| { | |||
| _add_existing_weight(v, false); | |||
| } | |||
| } | |||
| } | |||
| else if (_func.CustomizedFields.TryGetValue("variables", out var obj) && obj is IEnumerable<Trackable> total_trackables) | |||
| { | |||
| foreach (var trackable in total_trackables) | |||
| { | |||
| if (trackable is IVariableV1 v && !trainable_variables.Contains(v.UniqueId)) | |||
| { | |||
| _add_existing_weight(v, false); | |||
| } | |||
| } | |||
| } | |||
| if (_func.CustomizedFields.ContainsKey("regularization_losses")) | |||
| { | |||
| if ((_func.CustomizedFields["regularization_losses"] as ListWrapper)?.Count > 0) | |||
| { | |||
| throw new NotImplementedException("The regularization_losses loading has not been supported yet, " + | |||
| "please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues to let us know and add a feature."); | |||
| } | |||
| } | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| { | |||
| _check_trainability(); | |||
| // TODO(Rinne): deal with training_argument | |||
| var result = _callable(inputs); | |||
| return _apply_output_shape_if_set(inputs, result); | |||
| } | |||
| private void _check_trainability() | |||
| { | |||
| if (!Trainable) return; | |||
| // TODO(Rinne): deal with _is_hub_module_v1 and signature | |||
| if (TrainableWeights is null || TrainableWeights.Count == 0) | |||
| { | |||
| tf.Logger.Error("hub.KerasLayer is trainable but has zero trainable weights."); | |||
| } | |||
| } | |||
| private Tensors _apply_output_shape_if_set(Tensors inputs, Tensors result) | |||
| { | |||
| // TODO(Rinne): implement it. | |||
| return result; | |||
| } | |||
| private void _add_existing_weight(IVariableV1 weight, bool? trainable = null) | |||
| { | |||
| bool is_trainable; | |||
| if (trainable is null) | |||
| { | |||
| is_trainable = weight.Trainable; | |||
| } | |||
| else | |||
| { | |||
| is_trainable = trainable.Value; | |||
| } | |||
| add_weight(weight.Name, weight.shape, weight.dtype, trainable: is_trainable, getter: x => weight); | |||
| } | |||
| private Func<Tensors, Tensors> _get_callable() | |||
| { | |||
| if (_func is Layer layer) | |||
| { | |||
| return x => layer.Apply(x); | |||
| } | |||
| if (_func.CustomizedFields.ContainsKey("__call__")) | |||
| { | |||
| if (_func.CustomizedFields["__call__"] is RestoredFunction function) | |||
| { | |||
| return x => function.Apply(x); | |||
| } | |||
| } | |||
| throw new ValueError("Cannot get the callable from the model."); | |||
| } | |||
| private static Trackable load_module(string handle, LoadOptions? load_options = null) | |||
| { | |||
| //var set_load_options = load_options ?? LoadContext.get_load_option(); | |||
| return module_v2.load(handle, load_options); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,17 @@ | |||
| <Project Sdk="Microsoft.NET.Sdk"> | |||
| <PropertyGroup> | |||
| <TargetFrameworks>netstandard2.0;net6;net7</TargetFrameworks> | |||
| <LangVersion>11</LangVersion> | |||
| <Nullable>enable</Nullable> | |||
| </PropertyGroup> | |||
| <ItemGroup> | |||
| <PackageReference Include="SharpCompress" Version="0.33.0" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| <ProjectReference Include="..\\TensorFlowNET.Keras\Tensorflow.Keras.csproj" /> | |||
| </ItemGroup> | |||
| </Project> | |||
| @@ -0,0 +1,74 @@ | |||
| using SharpCompress.Common; | |||
| using SharpCompress.Readers; | |||
| using System; | |||
| using System.IO; | |||
| namespace Tensorflow.Hub | |||
| { | |||
| internal static class file_utils | |||
| { | |||
| //public static void extract_file(TarInputStream tgz, TarEntry tarInfo, string dstPath, uint bufferSize = 10 << 20, Action<long> logFunction = null) | |||
| //{ | |||
| // using (var src = tgz.GetNextEntry() == tarInfo ? tgz : null) | |||
| // { | |||
| // if (src is null) | |||
| // { | |||
| // return; | |||
| // } | |||
| // using (var dst = File.Create(dstPath)) | |||
| // { | |||
| // var buffer = new byte[bufferSize]; | |||
| // int count; | |||
| // while ((count = src.Read(buffer, 0, buffer.Length)) > 0) | |||
| // { | |||
| // dst.Write(buffer, 0, count); | |||
| // logFunction?.Invoke(count); | |||
| // } | |||
| // } | |||
| // } | |||
| //} | |||
| public static void extract_tarfile_to_destination(Stream fileobj, string dst_path, Action<long> logFunction = null) | |||
| { | |||
| using (IReader reader = ReaderFactory.Open(fileobj)) | |||
| { | |||
| while (reader.MoveToNextEntry()) | |||
| { | |||
| if (!reader.Entry.IsDirectory) | |||
| { | |||
| reader.WriteEntryToDirectory( | |||
| dst_path, | |||
| new ExtractionOptions() { ExtractFullPath = true, Overwrite = true } | |||
| ); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| public static string merge_relative_path(string dstPath, string relPath) | |||
| { | |||
| var cleanRelPath = Path.GetFullPath(relPath).TrimStart('/', '\\'); | |||
| if (cleanRelPath == ".") | |||
| { | |||
| return dstPath; | |||
| } | |||
| if (cleanRelPath.StartsWith("..") || Path.IsPathRooted(cleanRelPath)) | |||
| { | |||
| throw new InvalidDataException($"Relative path '{relPath}' is invalid."); | |||
| } | |||
| var merged = Path.Combine(dstPath, cleanRelPath); | |||
| if (!merged.StartsWith(dstPath)) | |||
| { | |||
| throw new InvalidDataException($"Relative path '{relPath}' is invalid. Failed to merge with '{dstPath}'."); | |||
| } | |||
| return merged; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,17 @@ | |||
| using Tensorflow.Hub; | |||
| namespace Tensorflow | |||
| { | |||
| public static class HubAPI | |||
| { | |||
| public static HubMethods hub { get; } = new HubMethods(); | |||
| } | |||
| public class HubMethods | |||
| { | |||
| public KerasLayer KerasLayer(string handle, bool trainable = false, LoadOptions? load_options = null) | |||
| { | |||
| return new KerasLayer(handle, trainable, load_options); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,33 @@ | |||
| using System.IO; | |||
| using Tensorflow.Train; | |||
| namespace Tensorflow.Hub | |||
| { | |||
| internal static class module_v2 | |||
| { | |||
| public static Trackable load(string handle, LoadOptions? options) | |||
| { | |||
| var module_path = resolve(handle); | |||
| // TODO(Rinne): deal with is_hub_module_v1 | |||
| var saved_model_path = Path.Combine(module_path, Constants.SAVED_MODEL_FILENAME_PB); | |||
| var saved_model_pb_txt_path = Path.Combine(module_path, Constants.SAVED_MODEL_FILENAME_PBTXT); | |||
| if (!File.Exists(saved_model_path) && !Directory.Exists(saved_model_path) && !File.Exists(saved_model_pb_txt_path) | |||
| && !Directory.Exists(saved_model_pb_txt_path)) | |||
| { | |||
| throw new ValueError($"Trying to load a model of incompatible/unknown type. " + | |||
| $"'{module_path}' contains neither '{Constants.SAVED_MODEL_FILENAME_PB}' " + | |||
| $"nor '{Constants.SAVED_MODEL_FILENAME_PBTXT}'."); | |||
| } | |||
| var obj = Loader.load(module_path, options: options); | |||
| return obj; | |||
| } | |||
| public static string resolve(string handle) | |||
| { | |||
| return MultiImplRegister.GetResolverRegister().Call(handle); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,55 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| namespace Tensorflow.Hub | |||
| { | |||
| internal class MultiImplRegister | |||
| { | |||
| private static MultiImplRegister resolver = new MultiImplRegister("resolver", new IResolver[0]); | |||
| private static MultiImplRegister loader = new MultiImplRegister("loader", new IResolver[0]); | |||
| static MultiImplRegister() | |||
| { | |||
| resolver.add_implementation(new PathResolver()); | |||
| resolver.add_implementation(new HttpUncompressedFileResolver()); | |||
| resolver.add_implementation(new GcsCompressedFileResolver()); | |||
| resolver.add_implementation(new HttpCompressedFileResolver()); | |||
| } | |||
| string _name; | |||
| List<IResolver> _impls; | |||
| public MultiImplRegister(string name, IEnumerable<IResolver> impls) | |||
| { | |||
| _name = name; | |||
| _impls = impls.ToList(); | |||
| } | |||
| public void add_implementation(IResolver resolver) | |||
| { | |||
| _impls.Add(resolver); | |||
| } | |||
| public string Call(string handle) | |||
| { | |||
| foreach (var impl in _impls.Reverse<IResolver>()) | |||
| { | |||
| if (impl.IsSupported(handle)) | |||
| { | |||
| return impl.Call(handle); | |||
| } | |||
| } | |||
| throw new RuntimeError($"Cannot resolve the handle {handle}"); | |||
| } | |||
| public static MultiImplRegister GetResolverRegister() | |||
| { | |||
| return resolver; | |||
| } | |||
| public static MultiImplRegister GetLoaderRegister() | |||
| { | |||
| return loader; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,580 @@ | |||
| using ICSharpCode.SharpZipLib.Tar; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.ComponentModel; | |||
| using System.Diagnostics; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using System.Net; | |||
| using System.Net.Http; | |||
| using System.Net.Security; | |||
| using System.Security.Authentication; | |||
| using System.Threading.Tasks; | |||
| using System.Web; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Hub | |||
| { | |||
| internal static class resolver | |||
| { | |||
| public enum ModelLoadFormat | |||
| { | |||
| [Description("COMPRESSED")] | |||
| COMPRESSED, | |||
| [Description("UNCOMPRESSED")] | |||
| UNCOMPRESSED, | |||
| [Description("AUTO")] | |||
| AUTO | |||
| } | |||
| public class DownloadManager | |||
| { | |||
| private readonly string _url; | |||
| private double _last_progress_msg_print_time; | |||
| private long _total_bytes_downloaded; | |||
| private int _max_prog_str; | |||
| private bool _interactive_mode() | |||
| { | |||
| return !string.IsNullOrEmpty(Environment.GetEnvironmentVariable("_TFHUB_DOWNLOAD_PROGRESS")); | |||
| } | |||
| private void _print_download_progress_msg(string msg, bool flush = false) | |||
| { | |||
| if (_interactive_mode()) | |||
| { | |||
| // Print progress message to console overwriting previous progress | |||
| // message. | |||
| _max_prog_str = Math.Max(_max_prog_str, msg.Length); | |||
| Console.Write($"\r{msg.PadRight(_max_prog_str)}"); | |||
| Console.Out.Flush(); | |||
| //如果flush参数为true,则输出换行符减少干扰交互式界面。 | |||
| if (flush) | |||
| Console.WriteLine(); | |||
| } | |||
| else | |||
| { | |||
| // Interactive progress tracking is disabled. Print progress to the | |||
| // standard TF log. | |||
| tf.Logger.Information(msg); | |||
| } | |||
| } | |||
| private void _log_progress(long bytes_downloaded) | |||
| { | |||
| // Logs progress information about ongoing module download. | |||
| _total_bytes_downloaded += bytes_downloaded; | |||
| var now = DateTime.Now.Ticks / TimeSpan.TicksPerSecond; | |||
| if (_interactive_mode() || now - _last_progress_msg_print_time > 15) | |||
| { | |||
| // Print progress message every 15 secs or if interactive progress | |||
| // tracking is enabled. | |||
| _print_download_progress_msg($"Downloading {_url}:" + | |||
| $"{tf_utils.bytes_to_readable_str(_total_bytes_downloaded, true)}"); | |||
| _last_progress_msg_print_time = now; | |||
| } | |||
| } | |||
| public DownloadManager(string url) | |||
| { | |||
| _url = url; | |||
| _last_progress_msg_print_time = DateTime.Now.Ticks / TimeSpan.TicksPerSecond; | |||
| _total_bytes_downloaded = 0; | |||
| _max_prog_str = 0; | |||
| } | |||
| public void download_and_uncompress(Stream fileobj, string dst_path) | |||
| { | |||
| // Streams the content for the 'fileobj' and stores the result in dst_path. | |||
| try | |||
| { | |||
| file_utils.extract_tarfile_to_destination(fileobj, dst_path, _log_progress); | |||
| var total_size_str = tf_utils.bytes_to_readable_str(_total_bytes_downloaded, true); | |||
| _print_download_progress_msg($"Downloaded {_url}, Total size: {total_size_str}", flush: true); | |||
| } | |||
| catch (TarException ex) | |||
| { | |||
| throw new IOException($"{_url} does not appear to be a valid module. Inner message:{ex.Message}", ex); | |||
| } | |||
| } | |||
| } | |||
| private static Dictionary<string, string> _flags = new(); | |||
| private static readonly string _TFHUB_CACHE_DIR = "TFHUB_CACHE_DIR"; | |||
| private static readonly string _TFHUB_DOWNLOAD_PROGRESS = "TFHUB_DOWNLOAD_PROGRESS"; | |||
| private static readonly string _TFHUB_MODEL_LOAD_FORMAT = "TFHUB_MODEL_LOAD_FORMAT"; | |||
| private static readonly string _TFHUB_DISABLE_CERT_VALIDATION = "TFHUB_DISABLE_CERT_VALIDATION"; | |||
| private static readonly string _TFHUB_DISABLE_CERT_VALIDATION_VALUE = "true"; | |||
| static resolver() | |||
| { | |||
| set_new_flag("tfhub_model_load_format", "AUTO"); | |||
| set_new_flag("tfhub_cache_dir", null); | |||
| } | |||
| public static string model_load_format() | |||
| { | |||
| return get_env_setting(_TFHUB_MODEL_LOAD_FORMAT, "tfhub_model_load_format"); | |||
| } | |||
| public static string? get_env_setting(string env_var, string flag_name) | |||
| { | |||
| string value = System.Environment.GetEnvironmentVariable(env_var); | |||
| if (string.IsNullOrEmpty(value)) | |||
| { | |||
| if (_flags.ContainsKey(flag_name)) | |||
| { | |||
| return _flags[flag_name]; | |||
| } | |||
| else | |||
| { | |||
| return null; | |||
| } | |||
| } | |||
| else | |||
| { | |||
| return value; | |||
| } | |||
| } | |||
| public static string tfhub_cache_dir(string default_cache_dir = null, bool use_temp = false) | |||
| { | |||
| var cache_dir = get_env_setting(_TFHUB_CACHE_DIR, "tfhub_cache_dir") ?? default_cache_dir; | |||
| if (string.IsNullOrWhiteSpace(cache_dir) && use_temp) | |||
| { | |||
| // Place all TF-Hub modules under <system's temp>/tfhub_modules. | |||
| cache_dir = Path.Combine(Path.GetTempPath(), "tfhub_modules"); | |||
| } | |||
| if (!string.IsNullOrWhiteSpace(cache_dir)) | |||
| { | |||
| Console.WriteLine("Using {0} to cache modules.", cache_dir); | |||
| } | |||
| return cache_dir; | |||
| } | |||
| public static string create_local_module_dir(string cache_dir, string module_name) | |||
| { | |||
| Directory.CreateDirectory(cache_dir); | |||
| return Path.Combine(cache_dir, module_name); | |||
| } | |||
| public static void set_new_flag(string name, string value) | |||
| { | |||
| string[] tokens = new string[] {_TFHUB_CACHE_DIR, _TFHUB_DISABLE_CERT_VALIDATION, | |||
| _TFHUB_DISABLE_CERT_VALIDATION_VALUE, _TFHUB_DOWNLOAD_PROGRESS, _TFHUB_MODEL_LOAD_FORMAT}; | |||
| if (!tokens.Contains(name)) | |||
| { | |||
| tf.Logger.Warning($"You are settinng a flag '{name}' that cannot be recognized. The flag you set" + | |||
| "may not affect anything in tensorflow.hub."); | |||
| } | |||
| _flags[name] = value; | |||
| } | |||
| public static string _merge_relative_path(string dstPath, string relPath) | |||
| { | |||
| return file_utils.merge_relative_path(dstPath, relPath); | |||
| } | |||
| public static string _module_descriptor_file(string moduleDir) | |||
| { | |||
| return $"{moduleDir}.descriptor.txt"; | |||
| } | |||
| public static void _write_module_descriptor_file(string handle, string moduleDir) | |||
| { | |||
| var readme = _module_descriptor_file(moduleDir); | |||
| var content = $"Module: {handle}\nDownload Time: {DateTime.Now}\nDownloader Hostname: {Environment.MachineName} (PID:{Process.GetCurrentProcess().Id})"; | |||
| tf_utils.atomic_write_string_to_file(readme, content, overwrite: true); | |||
| } | |||
| public static string _lock_file_contents(string task_uid) | |||
| { | |||
| return $"{Environment.MachineName}.{Process.GetCurrentProcess().Id}.{task_uid}"; | |||
| } | |||
| public static string _lock_filename(string moduleDir) | |||
| { | |||
| return tf_utils.absolute_path(moduleDir) + ".lock"; | |||
| } | |||
| private static string _module_dir(string lockFilename) | |||
| { | |||
| var path = Path.GetDirectoryName(Path.GetFullPath(lockFilename)); | |||
| if (!string.IsNullOrEmpty(path)) | |||
| { | |||
| return Path.Combine(path, "hub_modules"); | |||
| } | |||
| throw new Exception("Unable to resolve hub_modules directory from lock file name."); | |||
| } | |||
| private static string _task_uid_from_lock_file(string lockFilename) | |||
| { | |||
| // Returns task UID of the task that created a given lock file. | |||
| var lockstring = File.ReadAllText(lockFilename); | |||
| return lockstring.Split('.').Last(); | |||
| } | |||
| private static string _temp_download_dir(string moduleDir, string taskUid) | |||
| { | |||
| // Returns the name of a temporary directory to download module to. | |||
| return $"{Path.GetFullPath(moduleDir)}.{taskUid}.tmp"; | |||
| } | |||
| private static long _dir_size(string directory) | |||
| { | |||
| // Returns total size (in bytes) of the given 'directory'. | |||
| long size = 0; | |||
| foreach (var elem in Directory.EnumerateFileSystemEntries(directory)) | |||
| { | |||
| var stat = new FileInfo(elem); | |||
| size += stat.Length; | |||
| if ((stat.Attributes & FileAttributes.Directory) != 0) | |||
| size += _dir_size(stat.FullName); | |||
| } | |||
| return size; | |||
| } | |||
| public static long _locked_tmp_dir_size(string lockFilename) | |||
| { | |||
| //Returns the size of the temp dir pointed to by the given lock file. | |||
| var taskUid = _task_uid_from_lock_file(lockFilename); | |||
| try | |||
| { | |||
| return _dir_size(_temp_download_dir(_module_dir(lockFilename), taskUid)); | |||
| } | |||
| catch (DirectoryNotFoundException) | |||
| { | |||
| return 0; | |||
| } | |||
| } | |||
| private static void _wait_for_lock_to_disappear(string handle, string lockFile, double lockFileTimeoutSec) | |||
| { | |||
| long? lockedTmpDirSize = null; | |||
| var lockedTmpDirSizeCheckTime = DateTime.Now; | |||
| var lockFileContent = ""; | |||
| while (File.Exists(lockFile)) | |||
| { | |||
| try | |||
| { | |||
| Console.WriteLine($"Module '{handle}' already being downloaded by '{File.ReadAllText(lockFile)}'. Waiting."); | |||
| if ((DateTime.Now - lockedTmpDirSizeCheckTime).TotalSeconds > lockFileTimeoutSec) | |||
| { | |||
| var curLockedTmpDirSize = _locked_tmp_dir_size(lockFile); | |||
| var curLockFileContent = File.ReadAllText(lockFile); | |||
| if (curLockedTmpDirSize == lockedTmpDirSize && curLockFileContent == lockFileContent) | |||
| { | |||
| Console.WriteLine($"Deleting lock file {lockFile} due to inactivity."); | |||
| File.Delete(lockFile); | |||
| break; | |||
| } | |||
| lockedTmpDirSize = curLockedTmpDirSize; | |||
| lockedTmpDirSizeCheckTime = DateTime.Now; | |||
| lockFileContent = curLockFileContent; | |||
| } | |||
| } | |||
| catch (FileNotFoundException) | |||
| { | |||
| // Lock file or temp directory were deleted during check. Continue | |||
| // to check whether download succeeded or we need to start our own | |||
| // download. | |||
| } | |||
| System.Threading.Thread.Sleep(5000); | |||
| } | |||
| } | |||
| public static async Task<string> atomic_download_async( | |||
| string handle, | |||
| Func<string, string, Task> downloadFn, | |||
| string moduleDir, | |||
| int lock_file_timeout_sec = 10 * 60) | |||
| { | |||
| var lockFile = _lock_filename(moduleDir); | |||
| var taskUid = Guid.NewGuid().ToString("N"); | |||
| var lockContents = _lock_file_contents(taskUid); | |||
| var tmpDir = _temp_download_dir(moduleDir, taskUid); | |||
| // Function to check whether model has already been downloaded. | |||
| Func<bool> checkModuleExists = () => | |||
| Directory.Exists(moduleDir) && | |||
| Directory.EnumerateFileSystemEntries(moduleDir).Any(); | |||
| // Check whether the model has already been downloaded before locking | |||
| // the destination path. | |||
| if (checkModuleExists()) | |||
| { | |||
| return moduleDir; | |||
| } | |||
| // Attempt to protect against cases of processes being cancelled with | |||
| // KeyboardInterrupt by using a try/finally clause to remove the lock | |||
| // and tmp_dir. | |||
| while (true) | |||
| { | |||
| try | |||
| { | |||
| tf_utils.atomic_write_string_to_file(lockFile, lockContents, false); | |||
| // Must test condition again, since another process could have created | |||
| // the module and deleted the old lock file since last test. | |||
| if (checkModuleExists()) | |||
| { | |||
| // Lock file will be deleted in the finally-clause. | |||
| return moduleDir; | |||
| } | |||
| if (Directory.Exists(moduleDir)) | |||
| { | |||
| Directory.Delete(moduleDir, true); | |||
| } | |||
| break; // Proceed to downloading the module. | |||
| } | |||
| // These errors are believed to be permanent problems with the | |||
| // module_dir that justify failing the download. | |||
| catch (FileNotFoundException) | |||
| { | |||
| throw; | |||
| } | |||
| catch (UnauthorizedAccessException) | |||
| { | |||
| throw; | |||
| } | |||
| catch (IOException) | |||
| { | |||
| throw; | |||
| } | |||
| // All other errors are retried. | |||
| // TODO(b/144424849): Retrying an AlreadyExistsError from the atomic write | |||
| // should be good enough, but see discussion about misc filesystem types. | |||
| // TODO(b/144475403): How atomic is the overwrite=False check? | |||
| catch (Exception) | |||
| { | |||
| } | |||
| // Wait for lock file to disappear. | |||
| _wait_for_lock_to_disappear(handle, lockFile, lock_file_timeout_sec); | |||
| // At this point we either deleted a lock or a lock got removed by the | |||
| // owner or another process. Perform one more iteration of the while-loop, | |||
| // we would either terminate due tf.compat.v1.gfile.Exists(module_dir) or | |||
| // because we would obtain a lock ourselves, or wait again for the lock to | |||
| // disappear. | |||
| } | |||
| // Lock file acquired. | |||
| tf.Logger.Information($"Downloading TF-Hub Module '{handle}'..."); | |||
| Directory.CreateDirectory(tmpDir); | |||
| await downloadFn(handle, tmpDir); | |||
| // Write module descriptor to capture information about which module was | |||
| // downloaded by whom and when. The file stored at the same level as a | |||
| // directory in order to keep the content of the 'model_dir' exactly as it | |||
| // was define by the module publisher. | |||
| // | |||
| // Note: The descriptor is written purely to help the end-user to identify | |||
| // which directory belongs to which module. The descriptor is not part of the | |||
| // module caching protocol and no code in the TF-Hub library reads its | |||
| // content. | |||
| _write_module_descriptor_file(handle, moduleDir); | |||
| try | |||
| { | |||
| Directory.Move(tmpDir, moduleDir); | |||
| Console.WriteLine($"Downloaded TF-Hub Module '{handle}'."); | |||
| } | |||
| catch (IOException e) | |||
| { | |||
| Console.WriteLine(e.Message); | |||
| Console.WriteLine($"Failed to move {tmpDir} to {moduleDir}"); | |||
| // Keep the temp directory so we will retry building vocabulary later. | |||
| } | |||
| // Temp directory is owned by the current process, remove it. | |||
| try | |||
| { | |||
| Directory.Delete(tmpDir, true); | |||
| } | |||
| catch (DirectoryNotFoundException) | |||
| { | |||
| } | |||
| // Lock file exists and is owned by this process. | |||
| try | |||
| { | |||
| var contents = File.ReadAllText(lockFile); | |||
| if (contents == lockContents) | |||
| { | |||
| File.Delete(lockFile); | |||
| } | |||
| } | |||
| catch (Exception) | |||
| { | |||
| } | |||
| return moduleDir; | |||
| } | |||
| } | |||
| internal interface IResolver | |||
| { | |||
| string Call(string handle); | |||
| bool IsSupported(string handle); | |||
| } | |||
| internal class PathResolver : IResolver | |||
| { | |||
| public string Call(string handle) | |||
| { | |||
| if (!File.Exists(handle) && !Directory.Exists(handle)) | |||
| { | |||
| throw new IOException($"{handle} does not exist in file system."); | |||
| } | |||
| return handle; | |||
| } | |||
| public bool IsSupported(string handle) | |||
| { | |||
| return true; | |||
| } | |||
| } | |||
| public abstract class HttpResolverBase : IResolver | |||
| { | |||
| private readonly HttpClient httpClient; | |||
| private SslProtocol sslProtocol; | |||
| private RemoteCertificateValidationCallback certificateValidator; | |||
| protected HttpResolverBase() | |||
| { | |||
| httpClient = new HttpClient(); | |||
| _maybe_disable_cert_validation(); | |||
| } | |||
| public abstract string Call(string handle); | |||
| public abstract bool IsSupported(string handle); | |||
| protected async Task<Stream> GetLocalFileStreamAsync(string filePath) | |||
| { | |||
| try | |||
| { | |||
| var fs = new FileStream(filePath, FileMode.Open, FileAccess.Read); | |||
| return await Task.FromResult(fs); | |||
| } | |||
| catch (Exception ex) | |||
| { | |||
| Console.WriteLine($"Failed to read file stream: {ex.Message}"); | |||
| return null; | |||
| } | |||
| } | |||
| protected async Task<Stream> GetFileStreamAsync(string filePath) | |||
| { | |||
| if (!is_http_protocol(filePath)) | |||
| { | |||
| // If filePath is not an HTTP(S) URL, delegate to a file resolver. | |||
| return await GetLocalFileStreamAsync(filePath); | |||
| } | |||
| var request = new HttpRequestMessage(HttpMethod.Get, filePath); | |||
| var response = await _call_urlopen(request); | |||
| if (response.IsSuccessStatusCode) | |||
| { | |||
| return await response.Content.ReadAsStreamAsync(); | |||
| } | |||
| else | |||
| { | |||
| Console.WriteLine($"Failed to fetch file stream: {response.StatusCode} - {response.ReasonPhrase}"); | |||
| return null; | |||
| } | |||
| } | |||
| protected void SetUrlContext(SslProtocol protocol, RemoteCertificateValidationCallback validator) | |||
| { | |||
| sslProtocol = protocol; | |||
| certificateValidator = validator; | |||
| } | |||
| public static string append_format_query(string handle, (string, string) formatQuery) | |||
| { | |||
| var parsed = new Uri(handle); | |||
| var queryBuilder = HttpUtility.ParseQueryString(parsed.Query); | |||
| queryBuilder.Add(formatQuery.Item1, formatQuery.Item2); | |||
| parsed = new UriBuilder(parsed.Scheme, parsed.Host, parsed.Port, parsed.AbsolutePath, | |||
| "?" + queryBuilder.ToString()).Uri; | |||
| return parsed.ToString(); | |||
| } | |||
| protected bool is_http_protocol(string handle) | |||
| { | |||
| return handle.StartsWith("http://") || handle.StartsWith("https://"); | |||
| } | |||
| protected async Task<HttpResponseMessage> _call_urlopen(HttpRequestMessage request) | |||
| { | |||
| if (sslProtocol != null) | |||
| { | |||
| var handler = new HttpClientHandler() | |||
| { | |||
| SslProtocols = sslProtocol.AsEnum(), | |||
| }; | |||
| if (certificateValidator != null) | |||
| { | |||
| handler.ServerCertificateCustomValidationCallback = (x, y, z, w) => | |||
| { | |||
| return certificateValidator(x, y, z, w); | |||
| }; | |||
| } | |||
| var client = new HttpClient(handler); | |||
| return await client.SendAsync(request); | |||
| } | |||
| else | |||
| { | |||
| return await httpClient.SendAsync(request); | |||
| } | |||
| } | |||
| protected void _maybe_disable_cert_validation() | |||
| { | |||
| if (Environment.GetEnvironmentVariable("_TFHUB_DISABLE_CERT_VALIDATION") == "_TFHUB_DISABLE_CERT_VALIDATION_VALUE") | |||
| { | |||
| ServicePointManager.ServerCertificateValidationCallback = (_, _, _, _) => true; | |||
| Console.WriteLine("Disabled certificate validation for resolving handles."); | |||
| } | |||
| } | |||
| } | |||
| public class SslProtocol | |||
| { | |||
| private readonly string protocolString; | |||
| public static readonly SslProtocol Tls = new SslProtocol("TLS"); | |||
| public static readonly SslProtocol Tls11 = new SslProtocol("TLS 1.1"); | |||
| public static readonly SslProtocol Tls12 = new SslProtocol("TLS 1.2"); | |||
| private SslProtocol(string protocolString) | |||
| { | |||
| this.protocolString = protocolString; | |||
| } | |||
| public SslProtocols AsEnum() | |||
| { | |||
| switch (protocolString.ToUpper()) | |||
| { | |||
| case "TLS": | |||
| return SslProtocols.Tls; | |||
| case "TLS 1.1": | |||
| return SslProtocols.Tls11; | |||
| case "TLS 1.2": | |||
| return SslProtocols.Tls12; | |||
| default: | |||
| throw new ArgumentException($"Unknown SSL/TLS protocol: {protocolString}"); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,80 @@ | |||
| using System; | |||
| using System.IO; | |||
| namespace Tensorflow.Hub | |||
| { | |||
| internal class tf_utils | |||
| { | |||
| public static string bytes_to_readable_str(long? numBytes, bool includeB = false) | |||
| { | |||
| if (numBytes == null) return numBytes.ToString(); | |||
| var num = (double)numBytes; | |||
| if (num < 1024) | |||
| { | |||
| return $"{(long)num}{(includeB ? "B" : "")}"; | |||
| } | |||
| num /= 1 << 10; | |||
| if (num < 1024) | |||
| { | |||
| return $"{num:F2}k{(includeB ? "B" : "")}"; | |||
| } | |||
| num /= 1 << 10; | |||
| if (num < 1024) | |||
| { | |||
| return $"{num:F2}M{(includeB ? "B" : "")}"; | |||
| } | |||
| num /= 1 << 10; | |||
| return $"{num:F2}G{(includeB ? "B" : "")}"; | |||
| } | |||
| public static void atomic_write_string_to_file(string filename, string contents, bool overwrite) | |||
| { | |||
| var tempPath = $"{filename}.tmp.{Guid.NewGuid():N}"; | |||
| using (var fileStream = new FileStream(tempPath, FileMode.Create)) | |||
| { | |||
| using (var writer = new StreamWriter(fileStream)) | |||
| { | |||
| writer.Write(contents); | |||
| writer.Flush(); | |||
| } | |||
| } | |||
| try | |||
| { | |||
| if (File.Exists(filename)) | |||
| { | |||
| if (overwrite) | |||
| { | |||
| File.Delete(filename); | |||
| File.Move(tempPath, filename); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| File.Move(tempPath, filename); | |||
| } | |||
| } | |||
| catch | |||
| { | |||
| File.Delete(tempPath); | |||
| throw; | |||
| } | |||
| } | |||
| public static string absolute_path(string path) | |||
| { | |||
| if (path.Contains("://")) | |||
| { | |||
| return path; | |||
| } | |||
| return Path.GetFullPath(path); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,46 @@ | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.HubAPI; | |||
| namespace Tensorflow.Hub.Unittest | |||
| { | |||
| [TestClass] | |||
| public class KerasLayerTest | |||
| { | |||
| [TestMethod] | |||
| public void SmallBert() | |||
| { | |||
| var layer = hub.KerasLayer("https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-256_A-4/1"); | |||
| var input_type_ids = tf.convert_to_tensor(new int[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, dtype: tf.int32); | |||
| input_type_ids = tf.reshape(input_type_ids, (1, 128)); | |||
| var input_word_ids = tf.convert_to_tensor(new int[] { 101, 2129, 2024, 2017, 102, 0, 0, 0, 0, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 0 }, dtype: tf.int32); | |||
| input_word_ids = tf.reshape(input_word_ids, (1, 128)); | |||
| var input_mask = tf.convert_to_tensor(new int[] { 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, dtype: dtypes.int32); | |||
| input_mask = tf.reshape(input_mask, (1, 128)); | |||
| var result = layer.Apply(new Tensors(input_type_ids, input_word_ids, input_mask)); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,23 @@ | |||
| <Project Sdk="Microsoft.NET.Sdk"> | |||
| <PropertyGroup> | |||
| <TargetFramework>net7</TargetFramework> | |||
| <ImplicitUsings>enable</ImplicitUsings> | |||
| <Nullable>enable</Nullable> | |||
| <IsPackable>false</IsPackable> | |||
| </PropertyGroup> | |||
| <ItemGroup> | |||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.3.2" /> | |||
| <PackageReference Include="MSTest.TestAdapter" Version="2.2.10" /> | |||
| <PackageReference Include="MSTest.TestFramework" Version="2.2.10" /> | |||
| <PackageReference Include="coverlet.collector" Version="3.1.2" /> | |||
| <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.11.2" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| <ProjectReference Include="..\..\src\TensorflowNET.Hub\Tensorflow.Hub.csproj" /> | |||
| </ItemGroup> | |||
| </Project> | |||
| @@ -0,0 +1 @@ | |||
| global using Microsoft.VisualStudio.TestTools.UnitTesting; | |||