@@ -15,14 +15,15 @@ namespace Tensorflow.Hub
         private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz";
         private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz";

-        public static async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null)
+        public static async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null, bool showProgressInConsole = false)
         {
             var loader = new MnistModelLoader();

             var setting = new ModelLoadSetting
             {
                 TrainDir = trainDir,
-                OneHot = oneHot
+                OneHot = oneHot,
+                ShowProgressInConsole = showProgressInConsole
             };

             if (trainSize.HasValue)
@@ -48,37 +49,37 @@ namespace Tensorflow.Hub
                 sourceUrl = DEFAULT_SOURCE_URL;

             // load train images
-            await this.DownloadAsync(sourceUrl + TRAIN_IMAGES, setting.TrainDir, TRAIN_IMAGES)
+            await this.DownloadAsync(sourceUrl + TRAIN_IMAGES, setting.TrainDir, TRAIN_IMAGES, showProgressInConsole: setting.ShowProgressInConsole)
                 .ShowProgressInConsole(setting.ShowProgressInConsole);

-            await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_IMAGES), setting.TrainDir)
+            await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_IMAGES), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole)
                 .ShowProgressInConsole(setting.ShowProgressInConsole);

             var trainImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_IMAGES)), limit: setting.TrainSize);

             // load train labels
-            await this.DownloadAsync(sourceUrl + TRAIN_LABELS, setting.TrainDir, TRAIN_LABELS)
+            await this.DownloadAsync(sourceUrl + TRAIN_LABELS, setting.TrainDir, TRAIN_LABELS, showProgressInConsole: setting.ShowProgressInConsole)
                 .ShowProgressInConsole(setting.ShowProgressInConsole);

-            await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_LABELS), setting.TrainDir)
+            await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_LABELS), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole)
                 .ShowProgressInConsole(setting.ShowProgressInConsole);

             var trainLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_LABELS)), one_hot: setting.OneHot, limit: setting.TrainSize);

             // load test images
-            await this.DownloadAsync(sourceUrl + TEST_IMAGES, setting.TrainDir, TEST_IMAGES)
+            await this.DownloadAsync(sourceUrl + TEST_IMAGES, setting.TrainDir, TEST_IMAGES, showProgressInConsole: setting.ShowProgressInConsole)
                 .ShowProgressInConsole(setting.ShowProgressInConsole);

-            await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_IMAGES), setting.TrainDir)
+            await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_IMAGES), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole)
                 .ShowProgressInConsole(setting.ShowProgressInConsole);

             var testImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_IMAGES)), limit: setting.TestSize);

             // load test labels
-            await this.DownloadAsync(sourceUrl + TEST_LABELS, setting.TrainDir, TEST_LABELS)
+            await this.DownloadAsync(sourceUrl + TEST_LABELS, setting.TrainDir, TEST_LABELS, showProgressInConsole: setting.ShowProgressInConsole)
                 .ShowProgressInConsole(setting.ShowProgressInConsole);

-            await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_LABELS), setting.TrainDir)
+            await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_LABELS), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole)
                 .ShowProgressInConsole(setting.ShowProgressInConsole);

             var testLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_LABELS)), one_hot: setting.OneHot, limit: setting.TestSize);
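Note: callers opt in through the new showProgressInConsole flag; a minimal usage sketch, mirroring the example projects updated later in this diff:

    // Downloads, unzips and parses MNIST, printing progress dots to the console.
    var mnist = await MnistModelLoader.LoadAsync(
        ".resources/mnist",
        oneHot: true,
        showProgressInConsole: true);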
@@ -2,7 +2,7 @@
   <PropertyGroup>
     <RootNamespace>Tensorflow.Hub</RootNamespace>
     <TargetFramework>netstandard2.0</TargetFramework>
-    <Version>0.0.1</Version>
+    <Version>0.0.2</Version>
     <Authors>Kerry Jiang</Authors>
     <Company>SciSharp STACK</Company>
     <Copyright>Apache 2.0</Copyright>
@@ -13,7 +13,7 @@
     <Description>TensorFlow Hub is a library to foster the publication, discovery, and consumption of reusable parts of machine learning models.</Description>
     <PackageId>SciSharp.TensorFlowHub</PackageId>
     <GeneratePackageOnBuild>true</GeneratePackageOnBuild>
-    <PackageReleaseNotes>1. Add MNIST loader.</PackageReleaseNotes>
+    <PackageReleaseNotes></PackageReleaseNotes>
     <PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&v=4</PackageIconUrl>
   </PropertyGroup>
   <ItemGroup>
@@ -19,7 +19,7 @@ namespace Tensorflow.Hub
             await modelLoader.DownloadAsync(url, dir, fileName);
         }

-        public static async Task DownloadAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string url, string dirSaveTo, string fileName)
+        public static async Task DownloadAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string url, string dirSaveTo, string fileName, bool showProgressInConsole = false)
             where TDataSet : IDataSet
         {
             if (!Path.IsPathRooted(dirSaveTo))
@@ -27,18 +27,30 @@ namespace Tensorflow.Hub
             var fileSaveTo = Path.Combine(dirSaveTo, fileName);

+            if (showProgressInConsole)
+            {
+                Console.WriteLine($"Downloading {fileName}");
+            }
+
             if (File.Exists(fileSaveTo))
+            {
+                if (showProgressInConsole)
+                {
+                    Console.WriteLine($"The file {fileName} already exists");
+                }
+
                 return;
+            }

             Directory.CreateDirectory(dirSaveTo);

             using (var wc = new WebClient())
             {
-                await wc.DownloadFileTaskAsync(url, fileSaveTo);
+                await wc.DownloadFileTaskAsync(url, fileSaveTo).ConfigureAwait(false);
             }
         }

-        public static async Task UnzipAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string zipFile, string saveTo)
+        public static async Task UnzipAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string zipFile, string saveTo, bool showProgressInConsole = false)
             where TDataSet : IDataSet
         {
             if (!Path.IsPathRooted(saveTo))
@@ -49,67 +61,76 @@ namespace Tensorflow.Hub
             if (!Path.IsPathRooted(zipFile))
                 zipFile = Path.Combine(AppContext.BaseDirectory, zipFile);

-            var destFilePath = Path.Combine(saveTo, Path.GetFileNameWithoutExtension(zipFile));
+            var destFileName = Path.GetFileNameWithoutExtension(zipFile);
+            var destFilePath = Path.Combine(saveTo, destFileName);
+
+            if (showProgressInConsole)
+                Console.WriteLine($"Unzipping {Path.GetFileName(zipFile)}");

             if (File.Exists(destFilePath))
-                File.Delete(destFilePath);
+            {
+                if (showProgressInConsole)
+                    Console.WriteLine($"The file {destFileName} already exists");
+            }

             using (GZipStream unzipStream = new GZipStream(File.OpenRead(zipFile), CompressionMode.Decompress))
             {
                 using (var destStream = File.Create(destFilePath))
                 {
-                    await unzipStream.CopyToAsync(destStream);
-                    await destStream.FlushAsync();
+                    await unzipStream.CopyToAsync(destStream).ConfigureAwait(false);
+                    await destStream.FlushAsync().ConfigureAwait(false);
                     destStream.Close();
                 }

                 unzipStream.Close();
             }
         }

-        public static async Task ShowProgressInConsole(this Task task)
-        {
-            await ShowProgressInConsole(task, true);
-        }

         public static async Task ShowProgressInConsole(this Task task, bool enable)
         {
             if (!enable)
             {
                 await task;
+                return;
             }

             var cts = new CancellationTokenSource();
             var showProgressTask = ShowProgressInConsole(cts);

             try
             {
                 await task;
             }
             finally
             {
                 cts.Cancel();
             }
+
+            await showProgressTask;
+            Console.WriteLine("Done.");
         }

         private static async Task ShowProgressInConsole(CancellationTokenSource cts)
         {
             var cols = 0;
+            await Task.Delay(1000);
             while (!cts.IsCancellationRequested)
             {
                 await Task.Delay(1000);
                 Console.Write(".");
                 cols++;
-                if (cols >= 50)
+                if (cols % 50 == 0)
                 {
-                    cols = 0;
                     Console.WriteLine();
                 }
             }
-            Console.WriteLine();
+
+            if (cols > 0)
+                Console.WriteLine();
         }
     }
 }
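For context, the ShowProgressInConsole extension can wrap any pending Task; a minimal consumption sketch (the delay below is only a stand-in for a real download):

    // Writes a dot every second until the awaited task completes, then prints "Done.".
    await Task.Delay(TimeSpan.FromSeconds(5))
        .ShowProgressInConsole(enable: true);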
@@ -192,6 +192,12 @@ namespace Tensorflow
         public static Tensor logical_and(Tensor x, Tensor y, string name = null)
             => gen_math_ops.logical_and(x, y, name);

+        public static Tensor logical_not(Tensor x, string name = null)
+            => gen_math_ops.logical_not(x, name);
+
+        public static Tensor logical_or(Tensor x, Tensor y, string name = null)
+            => gen_math_ops.logical_or(x, y, name);
+
         /// <summary>
         /// Clips tensor values to a specified min and max.
         /// </summary>
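These wrappers expose the new ops on the public math API; a small usage sketch consistent with the unit test added at the end of this diff:

    var a = tf.constant(new[] { 1f, 2f, -1f });
    var b = tf.less(a, 0f);                  // bool tensor: { false, false, true }
    var not_b = tf.logical_not(b);           // element-wise NOT
    var either = tf.logical_or(b, not_b);    // element-wise OR, all true here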
@@ -34,10 +34,17 @@ namespace Tensorflow
         public Graph get_controller()
         {
-            if (stack.Count == 0)
+            if (stack.Count(x => x.IsDefault) == 0)
                 stack.Add(new StackModel { Graph = tf.Graph(), IsDefault = true });
-            return stack.First(x => x.IsDefault).Graph;
+            return stack.Last(x => x.IsDefault).Graph;
+        }
+
+        public bool remove(Graph g)
+        {
+            var sm = stack.FirstOrDefault(x => x.Graph == g);
+            if (sm == null) return false;
+            return stack.Remove(sm);
         }

         public void reset()
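The switch from First to Last matters once graphs nest; a rough illustration of the intended semantics, assuming as_default pushes an IsDefault entry onto this stack:

    var g1 = tf.Graph().as_default();
    var g2 = tf.Graph().as_default();
    // get_controller() now resolves to g2, the most recently defaulted graph;
    // with First() it would keep returning g1. remove(g2), called from the
    // Graph disposal changes below, lets g1 become the current default again.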
@@ -73,9 +73,8 @@ namespace Tensorflow
     all variables that are created during the construction of a graph. The caller
     may define additional collections by specifying a new name.
     */
-    public partial class Graph : IPython, IDisposable, IEnumerable<Operation>
+    public partial class Graph : DisposableObject, IEnumerable<Operation>
     {
-        private IntPtr _handle;
         private Dictionary<int, ITensorOrOperation> _nodes_by_id;
         public Dictionary<string, ITensorOrOperation> _nodes_by_name;
         private Dictionary<string, int> _names_in_use;
@@ -121,10 +120,6 @@ namespace Tensorflow
             _graph_key = $"grap-key-{ops.uid()}/";
         }

-        public void __enter__()
-        {
-        }
-
         public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true)
         {
             return _as_graph_element_locked(obj, allow_tensor, allow_operation);
@@ -443,14 +438,15 @@ namespace Tensorflow
             _unfetchable_ops.Add(op);
         }

-        public void Dispose()
-        {
-            /*if (_handle != IntPtr.Zero)
-                c_api.TF_DeleteGraph(_handle);
-            _handle = IntPtr.Zero;
-            GC.SuppressFinalize(this);*/
+        protected override void DisposeManagedState()
+        {
+            ops.default_graph_stack.remove(this);
+        }
+
+        protected override void DisposeUnManagedState(IntPtr handle)
+        {
+            Console.WriteLine($"Destroy graph {handle}");
+            c_api.TF_DeleteGraph(handle);
         }

         /// <summary>
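For orientation, a rough sketch of the dispose pattern these overrides imply; the actual DisposableObject base class in the repo may differ in naming and details:

    using System;

    public abstract class DisposableObject : IDisposable
    {
        protected IntPtr _handle;   // owned native handle, e.g. the TF_Graph*

        protected virtual void DisposeManagedState() { }
        protected abstract void DisposeUnManagedState(IntPtr handle);

        public void Dispose()
        {
            DisposeManagedState();                  // managed cleanup first
            if (_handle != IntPtr.Zero)
            {
                DisposeUnManagedState(_handle);     // then release the native handle
                _handle = IntPtr.Zero;
            }
            GC.SuppressFinalize(this);
        }
    }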
@@ -481,17 +477,19 @@ namespace Tensorflow
             return new TensorShape(dims.Select(x => (int)x).ToArray());
         }

+        string debugString = string.Empty;
         public override string ToString()
         {
-            int len = 0;
-            return c_api.TF_GraphDebugString(_handle, out len);
+            return $"{graph_key}, ({_handle})";
+            /*if (string.IsNullOrEmpty(debugString))
+            {
+                int len = 0;
+                debugString = c_api.TF_GraphDebugString(_handle, out len);
+            }
+
+            return debugString;*/
         }

-        public void __exit__()
-        {
-        }
-
         private IEnumerable<Operation> GetEnumerable()
             => c_api_util.tf_operations(this);
@@ -84,7 +84,7 @@ namespace Tensorflow
             // Dict mapping op name to file and line information for op colocation
             // context managers.
-            _control_flow_context = graph._get_control_flow_context();
+            _control_flow_context = _graph._get_control_flow_context();
             // Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor.
         }
@@ -357,6 +357,20 @@ namespace Tensorflow
             return _op.outputs[0];
         }

+        public static Tensor logical_not(Tensor x, string name = null)
+        {
+            var _op = _op_def_lib._apply_op_helper("LogicalNot", name, args: new { x });
+
+            return _op.outputs[0];
+        }
+
+        public static Tensor logical_or(Tensor x, Tensor y, string name = null)
+        {
+            var _op = _op_def_lib._apply_op_helper("LogicalOr", name, args: new { x, y });
+
+            return _op.outputs[0];
+        }
+
         public static Tensor squared_difference(Tensor x, Tensor y, string name = null)
         {
             var _op = _op_def_lib._apply_op_helper("SquaredDifference", name, args: new { x, y, name });
@@ -31,7 +31,6 @@ namespace Tensorflow
         protected bool _closed;
         protected int _current_version;
         protected byte[] _target;
-        protected IntPtr _session;
         public Graph graph => _graph;

         public BaseSession(string target = "", Graph g = null, SessionOptions opts = null)
@@ -46,7 +45,7 @@ namespace Tensorflow
             var status = new Status();

-            _session = c_api.TF_NewSession(_graph, opts ?? newOpts, status);
+            _handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status);

             status.Check(true);
         }
@@ -212,7 +211,7 @@ namespace Tensorflow
             var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray();

-            c_api.TF_SessionRun(_session,
+            c_api.TF_SessionRun(_handle,
                 run_options: null,
                 inputs: feed_dict.Select(f => f.Key).ToArray(),
                 input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(),
@@ -30,7 +30,7 @@ namespace Tensorflow
         public Session(IntPtr handle, Graph g = null)
             : base("", g, null)
         {
-            _session = handle;
+            _handle = handle;
         }

         public Session(Graph g, SessionOptions opts = null, Status s = null)
@@ -73,7 +73,7 @@ namespace Tensorflow
             return new Session(sess, g: new Graph(graph).as_default());
         }

-        public static implicit operator IntPtr(Session session) => session._session;
+        public static implicit operator IntPtr(Session session) => session._handle;
         public static implicit operator Session(IntPtr handle) => new Session(handle);

         public void __enter__()
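The implicit conversions mean existing call sites keep working against the shared handle; for example:

    IntPtr handle = session;       // hands the native session pointer to C-API calls
    Session restored = handle;     // wraps an already-created session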
@@ -506,7 +506,7 @@ namespace Tensorflow
             IsMemoryOwner = true;
         }

-        private unsafe IntPtr Allocate(NDArray nd, TF_DataType? tensorDType = null)
+        private unsafe IntPtr AllocateWithMemoryCopy(NDArray nd, TF_DataType? tensorDType = null)
         {
             IntPtr dotHandle = IntPtr.Zero;
             int buffersize = 0;
@@ -520,30 +520,30 @@
             var dataType = ToTFDataType(nd.dtype);
             // shape
             var dims = nd.shape.Select(x => (long)x).ToArray();
-            var nd1 = nd.ravel();
+            // var nd1 = nd.ravel();
             switch (nd.dtype.Name)
             {
                 case "Boolean":
-                    var boolVals = Array.ConvertAll(nd1.Data<bool>(), x => Convert.ToByte(x));
+                    var boolVals = Array.ConvertAll(nd.Data<bool>(), x => Convert.ToByte(x));
                     Marshal.Copy(boolVals, 0, dotHandle, nd.size);
                     break;
                 case "Int16":
-                    Marshal.Copy(nd1.Data<short>(), 0, dotHandle, nd.size);
+                    Marshal.Copy(nd.Data<short>(), 0, dotHandle, nd.size);
                     break;
                 case "Int32":
-                    Marshal.Copy(nd1.Data<int>(), 0, dotHandle, nd.size);
+                    Marshal.Copy(nd.Data<int>(), 0, dotHandle, nd.size);
                     break;
                 case "Int64":
-                    Marshal.Copy(nd1.Data<long>(), 0, dotHandle, nd.size);
+                    Marshal.Copy(nd.Data<long>(), 0, dotHandle, nd.size);
                     break;
                 case "Single":
-                    Marshal.Copy(nd1.Data<float>(), 0, dotHandle, nd.size);
+                    Marshal.Copy(nd.Data<float>(), 0, dotHandle, nd.size);
                     break;
                 case "Double":
-                    Marshal.Copy(nd1.Data<double>(), 0, dotHandle, nd.size);
+                    Marshal.Copy(nd.Data<double>(), 0, dotHandle, nd.size);
                     break;
                 case "Byte":
-                    Marshal.Copy(nd1.Data<byte>(), 0, dotHandle, nd.size);
+                    Marshal.Copy(nd.Data<byte>(), 0, dotHandle, nd.size);
                     break;
                 case "String":
                     return new Tensor(UTF8Encoding.UTF8.GetBytes(nd.Data<string>(0)), TF_DataType.TF_STRING);
@@ -559,6 +559,132 @@ namespace Tensorflow
                 ref _deallocatorArgs);

             return tfHandle;
+        }
+
+        private unsafe IntPtr Allocate(NDArray nd, TF_DataType? tensorDType = null)
+        {
+            IntPtr dotHandle = IntPtr.Zero;
+            IntPtr tfHandle = IntPtr.Zero;
+            int buffersize = nd.size * nd.dtypesize;
+            var dataType = ToTFDataType(nd.dtype);
+            // shape
+            var dims = nd.shape.Select(x => (long)x).ToArray();
+
+            switch (nd.dtype.Name)
+            {
+                case "Boolean":
+                    {
+                        var boolVals = Array.ConvertAll(nd.Data<bool>(), x => Convert.ToByte(x));
+                        fixed (byte* h = &boolVals[0])
+                        {
+                            tfHandle = c_api.TF_NewTensor(dataType,
+                                dims,
+                                dims.Length,
+                                new IntPtr(h),
+                                (UIntPtr)buffersize,
+                                _nothingDeallocator,
+                                ref _deallocatorArgs);
+                        }
+                    }
+                    break;
+                case "Int16":
+                    {
+                        var array = nd.Data<short>();
+                        fixed (short* h = &array[0])
+                        {
+                            tfHandle = c_api.TF_NewTensor(dataType,
+                                dims,
+                                dims.Length,
+                                new IntPtr(h),
+                                (UIntPtr)buffersize,
+                                _nothingDeallocator,
+                                ref _deallocatorArgs);
+                        }
+                    }
+                    break;
+                case "Int32":
+                    {
+                        var array = nd.Data<int>();
+                        fixed (int* h = &array[0])
+                        {
+                            tfHandle = c_api.TF_NewTensor(dataType,
+                                dims,
+                                dims.Length,
+                                new IntPtr(h),
+                                (UIntPtr)buffersize,
+                                _nothingDeallocator,
+                                ref _deallocatorArgs);
+                        }
+                    }
+                    break;
+                case "Int64":
+                    {
+                        var array = nd.Data<long>();
+                        fixed (long* h = &array[0])
+                        {
+                            tfHandle = c_api.TF_NewTensor(dataType,
+                                dims,
+                                dims.Length,
+                                new IntPtr(h),
+                                (UIntPtr)buffersize,
+                                _nothingDeallocator,
+                                ref _deallocatorArgs);
+                        }
+                    }
+                    break;
+                case "Single":
+                    {
+                        var array = nd.Data<float>();
+                        fixed (float* h = &array[0])
+                        {
+                            tfHandle = c_api.TF_NewTensor(dataType,
+                                dims,
+                                dims.Length,
+                                new IntPtr(h),
+                                (UIntPtr)buffersize,
+                                _nothingDeallocator,
+                                ref _deallocatorArgs);
+                        }
+                    }
+                    break;
+                case "Double":
+                    {
+                        var array = nd.Data<double>();
+                        fixed (double* h = &array[0])
+                        {
+                            tfHandle = c_api.TF_NewTensor(dataType,
+                                dims,
+                                dims.Length,
+                                new IntPtr(h),
+                                (UIntPtr)buffersize,
+                                _nothingDeallocator,
+                                ref _deallocatorArgs);
+                        }
+                    }
+                    break;
+                case "Byte":
+                    {
+                        var array = nd.Data<byte>();
+                        fixed (byte* h = &array[0])
+                        {
+                            tfHandle = c_api.TF_NewTensor(dataType,
+                                dims,
+                                dims.Length,
+                                new IntPtr(h),
+                                (UIntPtr)buffersize,
+                                _nothingDeallocator,
+                                ref _deallocatorArgs);
+                        }
+                    }
+                    break;
+                case "String":
+                    return new Tensor(UTF8Encoding.UTF8.GetBytes(nd.Data<string>(0)), TF_DataType.TF_STRING);
+                default:
+                    throw new NotImplementedException($"Marshal.Copy failed for {nd.dtype.Name}.");
+            }
+
+            return tfHandle;
         }

         public unsafe Tensor(byte[][] buffer, long[] shape)
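One caveat worth flagging (not part of this diff): fixed pins the managed array only for the duration of the fixed block, while the pointer is handed to TF_NewTensor together with a no-op deallocator. Whether TF_NewTensor copies the buffer or keeps referencing it determines if a longer-lived pin is needed; a generic sketch of such a pin using GCHandle:

    using System;
    using System.Runtime.InteropServices;

    // Pin a managed buffer for as long as native code may read it.
    var data = new float[] { 1f, 2f, 3f };
    var pin = GCHandle.Alloc(data, GCHandleType.Pinned);
    try
    {
        IntPtr ptr = pin.AddrOfPinnedObject();  // stable address while pinned
        // ... pass ptr to the native API ...
    }
    finally
    {
        pin.Free();                             // unpin once native code is done
    }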
@@ -70,7 +70,8 @@ namespace TensorFlowNET.Examples
                 OneHot = true,
                 TrainSize = train_size,
                 ValidationSize = validation_size,
-                TestSize = test_size
+                TestSize = test_size,
+                ShowProgressInConsole = true
             };

             mnist = loader.LoadAsync(setting).Result;
@@ -124,7 +124,7 @@ namespace TensorFlowNET.Examples
         public void PrepareData()
         {
-            mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: train_size, validationSize: validation_size, testSize: test_size).Result;
+            mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: train_size, validationSize: validation_size, testSize: test_size, showProgressInConsole: true).Result;
         }

         public void SaveModel(Session sess)
@@ -84,7 +84,7 @@ namespace TensorFlowNET.Examples
         public void PrepareData()
         {
-            mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: TrainSize, validationSize: ValidationSize, testSize: TestSize).Result;
+            mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: TrainSize, validationSize: ValidationSize, testSize: TestSize, showProgressInConsole: true).Result;

             // In this example, we limit mnist data
             (Xtr, Ytr) = mnist.Train.GetNextBatch(TrainSize == null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates)
             (Xte, Yte) = mnist.Test.GetNextBatch(TestSize == null ? 200 : TestSize.Value / 100); // 200 for testing
@@ -310,7 +310,7 @@ namespace TensorFlowNET.Examples
         public void PrepareData()
         {
-            mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result;
+            mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, showProgressInConsole: true).Result;

             (x_train, y_train) = Reformat(mnist.Train.Data, mnist.Train.Labels);
             (x_valid, y_valid) = Reformat(mnist.Validation.Data, mnist.Validation.Labels);
             (x_test, y_test) = Reformat(mnist.Test.Data, mnist.Test.Labels);
@@ -121,7 +121,7 @@ namespace TensorFlowNET.Examples
         public void PrepareData()
         {
-            mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result;
+            mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, showProgressInConsole: true).Result;
         }

         public void Train(Session sess)
@@ -143,7 +143,7 @@ namespace TensorFlowNET.Examples
         public void PrepareData()
         {
-            mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result;
+            mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, showProgressInConsole: true).Result;

             (x_train, y_train) = (mnist.Train.Data, mnist.Train.Labels);
             (x_valid, y_valid) = (mnist.Validation.Data, mnist.Validation.Labels);
             (x_test, y_test) = (mnist.Test.Data, mnist.Test.Labels);
@@ -52,7 +52,8 @@ namespace TensorFlowNET.Examples
         // The location where variable checkpoints will be stored.
         string CHECKPOINT_NAME = Path.Join(data_dir, "_retrain_checkpoint");
         string tfhub_module = "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/3";
-        string final_tensor_name = "final_result";
+        string input_tensor_name = "Placeholder";
+        string final_tensor_name = "Score";
         float testing_percentage = 0.1f;
         float validation_percentage = 0.1f;
         float learning_rate = 0.01f;
@@ -81,13 +82,13 @@ namespace TensorFlowNET.Examples
             PrepareData();

             #region For debug purpose
             // predict images
             // Predict(null);

             // load saved pb and test new images.
             // Test(null);
             #endregion

             var graph = IsImportingGraph ? ImportGraph() : BuildGraph();
@@ -276,16 +277,13 @@ namespace TensorFlowNET.Examples
         private (Graph, Tensor, Tensor, bool) create_module_graph()
         {
             var (height, width) = (299, 299);

-            return tf_with(tf.Graph().as_default(), graph =>
-            {
-                tf.train.import_meta_graph("graph/InceptionV3.meta");
-                Tensor resized_input_tensor = graph.OperationByName("Placeholder"); //tf.placeholder(tf.float32, new TensorShape(-1, height, width, 3));
-                // var m = hub.Module(module_spec);
-                Tensor bottleneck_tensor = graph.OperationByName("module_apply_default/hub_output/feature_vector/SpatialSqueeze");// m(resized_input_tensor);
-                var wants_quantization = false;
-                return (graph, bottleneck_tensor, resized_input_tensor, wants_quantization);
-            });
+            var graph = tf.Graph().as_default();
+            tf.train.import_meta_graph("graph/InceptionV3.meta");
+            Tensor resized_input_tensor = graph.OperationByName(input_tensor_name); //tf.placeholder(tf.float32, new TensorShape(-1, height, width, 3));
+            // var m = hub.Module(module_spec);
+            Tensor bottleneck_tensor = graph.OperationByName("module_apply_default/hub_output/feature_vector/SpatialSqueeze");// m(resized_input_tensor);
+            var wants_quantization = false;
+            return (graph, bottleneck_tensor, resized_input_tensor, wants_quantization);
         }

         private (NDArray, long[], string[]) get_random_cached_bottlenecks(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
@@ -594,13 +592,10 @@ namespace TensorFlowNET.Examples
                 create_module_graph();

             // Add the new layer that we'll be training.
-            tf_with(graph.as_default(), delegate
-            {
-                (train_step, cross_entropy, bottleneck_input,
-                 ground_truth_input, final_tensor) = add_final_retrain_ops(
-                     class_count, final_tensor_name, bottleneck_tensor,
-                     wants_quantization, is_training: true);
-            });
+            (train_step, cross_entropy, bottleneck_input,
+             ground_truth_input, final_tensor) = add_final_retrain_ops(
+                 class_count, final_tensor_name, bottleneck_tensor,
+                 wants_quantization, is_training: true);

             return graph;
         }
@@ -734,15 +729,15 @@ namespace TensorFlowNET.Examples
             var labels = File.ReadAllLines(output_labels);

             // predict image
-            var img_path = Path.Join(image_dir, "roses", "12240303_80d87f77a3_n.jpg");
+            var img_path = Path.Join(image_dir, "daisy", "5547758_eea9edfd54_n.jpg");
             var fileBytes = ReadTensorFromImageFile(img_path);

             // import graph and variables
             var graph = new Graph();
             graph.Import(output_graph, "");

-            Tensor input = graph.OperationByName("Placeholder");
-            Tensor output = graph.OperationByName("final_result");
+            Tensor input = graph.OperationByName(input_tensor_name);
+            Tensor output = graph.OperationByName(final_tensor_name);

             using (var sess = tf.Session(graph))
             {
@@ -7,12 +7,13 @@ namespace TensorFlowNET.UnitTest
     [TestClass]
     public class NameScopeTest
     {
-        Graph g = ops.get_default_graph();
         string name = "";

         [TestMethod]
         public void NestedNameScope()
         {
+            Graph g = tf.Graph().as_default();
+
             tf_with(new ops.NameScope("scope1"), scope1 =>
             {
                 name = scope1;
@@ -37,6 +38,8 @@ namespace TensorFlowNET.UnitTest
                 Assert.AreEqual("scope1/Const_1:0", const3.name);
             });

+            g.Dispose();
+
             Assert.AreEqual("", g._name_stack);
         }
     }
@@ -131,7 +131,7 @@ namespace TensorFlowNET.UnitTest
         }

         [TestMethod]
-        public void logicalAndTest()
+        public void logicalOpsTest()
         {
             var a = tf.constant(new[] {1f, 2f, 3f, 4f, -4f, -3f, -2f, -1f});
             var b = tf.less(a, 0f);
@@ -144,6 +144,24 @@ namespace TensorFlowNET.UnitTest
                 var o = sess.run(d);
                 Assert.IsTrue(o.array_equal(check));
             }
+
+            d = tf.cast(tf.logical_not(b), tf.int32);
+            check = np.array(new[] { 1, 1, 1, 1, 0, 0, 0, 0 });
+            using (var sess = tf.Session())
+            {
+                var o = sess.run(d);
+                Assert.IsTrue(o.array_equal(check));
+            }
+
+            d = tf.cast(tf.logical_or(b, c), tf.int32);
+            check = np.array(new[] { 1, 1, 1, 1, 1, 1, 1, 1 });
+            using (var sess = tf.Session())
+            {
+                var o = sess.run(d);
+                Assert.IsTrue(o.array_equal(check));
+            }
         }

         [TestMethod]