diff --git a/src/TensorFlowNET.Core/GlobalUsing.cs b/src/TensorFlowNET.Core/GlobalUsing.cs index fe77202c..2fd5b437 100644 --- a/src/TensorFlowNET.Core/GlobalUsing.cs +++ b/src/TensorFlowNET.Core/GlobalUsing.cs @@ -1,3 +1,6 @@ global using System; global using System.Collections.Generic; global using System.Text; +global using System.Collections; +global using System.Data; +global using System.Linq; \ No newline at end of file diff --git a/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.load.cs b/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.load.cs index 70a4245b..05f53d5e 100644 --- a/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.load.cs +++ b/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.load.cs @@ -1,11 +1,4 @@ - using System; -using System.Collections; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Reflection; -using System.Text; -using Tensorflow.Util; +using System.IO; namespace Tensorflow.NumPy { @@ -15,10 +8,7 @@ namespace Tensorflow.NumPy { using var stream = new FileStream(file, FileMode.Open); using var reader = new BinaryReader(stream, Encoding.ASCII, leaveOpen: true); - int bytes; - Type type; - int[] shape; - if (!ParseReader(reader, out bytes, out type, out shape)) + if (!ParseReader(reader, out var bytes, out var type, out var shape)) throw new FormatException(); Array array = Create(type, shape.Aggregate((dims, dim) => dims * dim)); @@ -31,10 +21,7 @@ namespace Tensorflow.NumPy { using (var reader = new BinaryReader(stream, System.Text.Encoding.ASCII, leaveOpen: true)) { - int bytes; - Type type; - int[] shape; - if (!ParseReader(reader, out bytes, out type, out shape)) + if (!ParseReader(reader, out var bytes, out var type, out var shape)) throw new FormatException(); Array matrix = Array.CreateInstance(type, shape); diff --git a/src/TensorFlowNET.Core/NumPy/Numpy.Persistence.cs b/src/TensorFlowNET.Core/NumPy/Numpy.Persistence.cs new file mode 100644 index 00000000..b349f522 --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/Numpy.Persistence.cs @@ -0,0 +1,60 @@ +/***************************************************************************** + Copyright 2023 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.IO; +using System.IO.Compression; + +namespace Tensorflow.NumPy; + +public partial class np +{ + [AutoNumPy] + public static NpzDictionary loadz(string file) + { + using var stream = new FileStream(file, FileMode.Open); + return new NpzDictionary(stream); + } + + public static void save(string file, NDArray nd) + { + using var stream = new FileStream(file, FileMode.Create); + NpyFormat.Save(nd, stream); + } + + public static void savez(string file, params NDArray[] nds) + { + using var stream = new FileStream(file, FileMode.Create); + NpzFormat.Save(nds, stream); + } + + public static void savez(string file, object nds) + { + using var stream = new FileStream(file, FileMode.Create); + NpzFormat.Save(nds, stream); + } + + public static void savez_compressed(string file, params NDArray[] nds) + { + using var stream = new FileStream(file, FileMode.Create); + NpzFormat.Save(nds, stream, CompressionLevel.Fastest); + } + + public static void savez_compressed(string file, object nds) + { + using var stream = new FileStream(file, FileMode.Create); + NpzFormat.Save(nds, stream, CompressionLevel.Fastest); + } +} diff --git a/src/TensorFlowNET.Core/NumPy/Persistence/NpyFormat.cs b/src/TensorFlowNET.Core/NumPy/Persistence/NpyFormat.cs new file mode 100644 index 00000000..1886e4b4 --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/Persistence/NpyFormat.cs @@ -0,0 +1,95 @@ +using System.IO; +using System.Runtime.InteropServices; + +namespace Tensorflow.NumPy; + +public class NpyFormat +{ + public static void Save(NDArray array, Stream stream, bool leaveOpen = true) + { + using var writer = new BinaryWriter(stream, Encoding.ASCII, leaveOpen: leaveOpen); + + string dtype = GetDtypeName(array, out var type, out var maxLength); + int[] shape = array.shape.as_int_list(); + var bytesWritten = (ulong)writeHeader(writer, dtype, shape); + stream.Write(array.ToByteArray(), 0, (int)array.bytesize); + } + + private static int writeHeader(BinaryWriter writer, string dtype, int[] shape) + { + // The first 6 bytes are a magic string: exactly "x93NUMPY" + + char[] magic = { 'N', 'U', 'M', 'P', 'Y' }; + writer.Write((byte)147); + writer.Write(magic); + writer.Write((byte)1); // major + writer.Write((byte)0); // minor; + + string tuple = shape.Length == 1 ? $"{shape[0]}," : String.Join(", ", shape.Select(i => i.ToString()).ToArray()); + string header = "{{'descr': '{0}', 'fortran_order': False, 'shape': ({1}), }}"; + header = string.Format(header, dtype, tuple); + int preamble = 10; // magic string (6) + 4 + + int len = header.Length + 1; // the 1 is to account for the missing \n at the end + int headerSize = len + preamble; + + int pad = 16 - (headerSize % 16); + header = header.PadRight(header.Length + pad); + header += "\n"; + headerSize = header.Length + preamble; + + if (headerSize % 16 != 0) + throw new Exception(""); + + writer.Write((ushort)header.Length); + for (int i = 0; i < header.Length; i++) + writer.Write((byte)header[i]); + + return headerSize; + } + + private static string GetDtypeName(NDArray array, out Type type, out int bytes) + { + type = array.dtype.as_system_dtype(); + + bytes = 1; + + if (type == typeof(string)) + { + throw new NotSupportedException(""); + } + else if (type == typeof(bool)) + { + bytes = 1; + } + else + { + bytes = Marshal.SizeOf(type); + } + + if (type == typeof(bool)) + return "|b1"; + else if (type == typeof(byte)) + return "|i1"; + else if (type == typeof(short)) + return " : IDisposable, IReadOnlyDictionary, ICollection + where T : class, + ICloneable, IList, ICollection, IEnumerable, IStructuralComparable, IStructuralEquatable +{ + Stream stream; + ZipArchive archive; + + bool disposedValue = false; + + Dictionary entries; + Dictionary arrays; + + + public NpzDictionary(Stream stream) + { + this.stream = stream; + this.archive = new ZipArchive(stream, ZipArchiveMode.Read, leaveOpen: true); + + this.entries = new Dictionary(); + foreach (var entry in archive.Entries) + this.entries[entry.FullName] = entry; + + this.arrays = new Dictionary(); + } + + + public IEnumerable Keys + { + get { return entries.Keys; } + } + + + public IEnumerable Values + { + get { return entries.Values.Select(OpenEntry); } + } + + public int Count + { + get { return entries.Count; } + } + + + public object SyncRoot + { + get { return ((ICollection)entries).SyncRoot; } + } + + + public bool IsSynchronized + { + get { return ((ICollection)entries).IsSynchronized; } + } + + public bool IsReadOnly + { + get { return true; } + } + + public T this[string key] + { + get { return OpenEntry(entries[key]); } + } + + private T OpenEntry(ZipArchiveEntry entry) + { + T array; + if (arrays.TryGetValue(entry.FullName, out array)) + return array; + + using (Stream s = entry.Open()) + { + array = Load_Npz(s); + arrays[entry.FullName] = array; + return array; + } + } + + protected virtual T Load_Npz(Stream s) + { + return np.Load(s); + } + + public bool ContainsKey(string key) + { + return entries.ContainsKey(key); + } + + public bool TryGetValue(string key, out T value) + { + value = default(T); + ZipArchiveEntry entry; + if (!entries.TryGetValue(key, out entry)) + return false; + value = OpenEntry(entry); + return true; + } + + public IEnumerator> GetEnumerator() + { + foreach (var entry in archive.Entries) + yield return new KeyValuePair(entry.FullName, OpenEntry(entry)); + } + + IEnumerator IEnumerable.GetEnumerator() + { + foreach (var entry in archive.Entries) + yield return new KeyValuePair(entry.FullName, OpenEntry(entry)); + } + + IEnumerator IEnumerable.GetEnumerator() + { + foreach (var entry in archive.Entries) + yield return OpenEntry(entry); + } + + public void CopyTo(Array array, int arrayIndex) + { + foreach (var v in this) + array.SetValue(v, arrayIndex++); + } + + public void CopyTo(T[] array, int arrayIndex) + { + foreach (var v in this) + array.SetValue(v, arrayIndex++); + } + + public void Add(T item) + { + throw new ReadOnlyException(); + } + + public void Clear() + { + throw new ReadOnlyException(); + } + + public bool Contains(T item) + { + foreach (var v in this) + if (Object.Equals(v.Value, item)) + return true; + return false; + } + + public bool Remove(T item) + { + throw new ReadOnlyException(); + } + + protected virtual void Dispose(bool disposing) + { + if (!disposedValue) + { + if (disposing) + { + archive.Dispose(); + stream.Dispose(); + } + + archive = null; + stream = null; + entries = null; + arrays = null; + + disposedValue = true; + } + } + + public void Dispose() + { + Dispose(true); + } +} diff --git a/src/TensorFlowNET.Core/NumPy/Persistence/NpzDictionaryArray.cs b/src/TensorFlowNET.Core/NumPy/Persistence/NpzDictionaryArray.cs new file mode 100644 index 00000000..6e81216e --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/Persistence/NpzDictionaryArray.cs @@ -0,0 +1,138 @@ +using System.IO; +using System.IO.Compression; + +namespace Tensorflow.NumPy; + +public class NpzDictionary +{ + Dictionary arrays = new Dictionary(); + + public NDArray this[string key] => arrays[key]; + + public NpzDictionary(Stream stream) + { + using var archive = new ZipArchive(stream, ZipArchiveMode.Read, leaveOpen: false); + + foreach (var entry in archive.Entries) + { + arrays[entry.FullName] = OpenEntry(entry); + } + } + + private NDArray OpenEntry(ZipArchiveEntry entry) + { + if (arrays.TryGetValue(entry.FullName, out var array)) + return array; + + using var s = entry.Open(); + return LoadMatrix(s); + } + + public Array LoadMatrix(Stream stream) + { + using var reader = new BinaryReader(stream, System.Text.Encoding.ASCII, leaveOpen: false); + + if (!ParseReader(reader, out var bytes, out var type, out var shape)) + throw new FormatException(); + + Array matrix = Array.CreateInstance(type, shape); + + return ReadMatrix(reader, matrix, bytes, type, shape); + } + + bool ParseReader(BinaryReader reader, out int bytes, out Type t, out int[] shape) + { + bytes = 0; + t = null; + shape = null; + + // The first 6 bytes are a magic string: exactly "x93NUMPY" + if (reader.ReadChar() != 63) return false; + if (reader.ReadChar() != 'N') return false; + if (reader.ReadChar() != 'U') return false; + if (reader.ReadChar() != 'M') return false; + if (reader.ReadChar() != 'P') return false; + if (reader.ReadChar() != 'Y') return false; + + byte major = reader.ReadByte(); // 1 + byte minor = reader.ReadByte(); // 0 + + if (major != 1 || minor != 0) + throw new NotSupportedException(); + + ushort len = reader.ReadUInt16(); + + string header = new string(reader.ReadChars(len)); + string mark = "'descr': '"; + int s = header.IndexOf(mark) + mark.Length; + int e = header.IndexOf("'", s + 1); + string type = header.Substring(s, e - s); + bool? isLittleEndian; + t = GetType(type, out bytes, out isLittleEndian); + + if (isLittleEndian.HasValue && isLittleEndian.Value == false) + throw new Exception(); + + mark = "'fortran_order': "; + s = header.IndexOf(mark) + mark.Length; + e = header.IndexOf(",", s + 1); + bool fortran = bool.Parse(header.Substring(s, e - s)); + + if (fortran) + throw new Exception(); + + mark = "'shape': ("; + s = header.IndexOf(mark) + mark.Length; + e = header.IndexOf(")", s + 1); + shape = header.Substring(s, e - s).Split(',').Where(v => !String.IsNullOrEmpty(v)).Select(Int32.Parse).ToArray(); + + return true; + } + + Type GetType(string dtype, out int bytes, out bool? isLittleEndian) + { + isLittleEndian = IsLittleEndian(dtype); + bytes = int.Parse(dtype.Substring(2)); + + string typeCode = dtype.Substring(1); + return typeCode switch + { + "b1" => typeof(bool), + "i1" => typeof(byte), + "i2" => typeof(short), + "i4" => typeof(int), + "i8" => typeof(long), + "u1" => typeof(byte), + "u2" => typeof(ushort), + "u4" => typeof(uint), + "u8" => typeof(ulong), + "f4" => typeof(float), + "f8" => typeof(double), + // typeCode.StartsWith("S") => typeof(string), + _ => throw new NotSupportedException() + }; + } + + bool? IsLittleEndian(string type) + { + return type[0] switch + { + '<' => true, + '>' => false, + '|' => null, + _ => throw new Exception() + }; + } + + Array ReadMatrix(BinaryReader reader, Array matrix, int bytes, Type type, int[] shape) + { + int total = 1; + for (int i = 0; i < shape.Length; i++) + total *= shape[i]; + + var buffer = reader.ReadBytes(bytes * total); + System.Buffer.BlockCopy(buffer, 0, matrix, 0, buffer.Length); + + return matrix; + } +} diff --git a/src/TensorFlowNET.Core/NumPy/Persistence/NpzFormat.cs b/src/TensorFlowNET.Core/NumPy/Persistence/NpzFormat.cs new file mode 100644 index 00000000..7470a1ea --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/Persistence/NpzFormat.cs @@ -0,0 +1,37 @@ +using System.IO.Compression; +using System.IO; +using System; + +namespace Tensorflow.NumPy; + +public class NpzFormat +{ + public static void Save(NDArray[] arrays, Stream stream, CompressionLevel compression = CompressionLevel.NoCompression, bool leaveOpen = false) + { + using var zip = new ZipArchive(stream, ZipArchiveMode.Create, leaveOpen: leaveOpen); + for (int i = 0; i < arrays.Length; i++) + { + var entry = zip.CreateEntry($"arr_{i}", compression); + NpyFormat.Save(arrays[i], entry.Open(), leaveOpen); + } + } + + public static void Save(object arrays, Stream stream, CompressionLevel compression = CompressionLevel.NoCompression, bool leaveOpen = false) + { + var properties = arrays.GetType().GetProperties(); + using var zip = new ZipArchive(stream, ZipArchiveMode.Create, leaveOpen: leaveOpen); + for (int i = 0; i < properties.Length; i++) + { + var entry = zip.CreateEntry(properties[i].Name, compression); + var value = properties[i].GetValue(arrays); + if (value is NDArray nd) + { + NpyFormat.Save(nd, entry.Open(), leaveOpen); + } + else + { + throw new NotSupportedException("Please pass in NDArray."); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Numpy/NpzDictionary.cs b/src/TensorFlowNET.Core/Numpy/NpzDictionary.cs deleted file mode 100644 index bb7ff693..00000000 --- a/src/TensorFlowNET.Core/Numpy/NpzDictionary.cs +++ /dev/null @@ -1,206 +0,0 @@ -using System; -using System.Collections; -using System.Collections.Generic; -using System.Data; -using System.IO; -using System.IO.Compression; -using System.Linq; -using System.Text; -using Tensorflow.Util; - -namespace Tensorflow.NumPy -{ - public class NpzDictionary : IDisposable, IReadOnlyDictionary, ICollection - where T : class, - ICloneable, IList, ICollection, IEnumerable, IStructuralComparable, IStructuralEquatable - { - Stream stream; - ZipArchive archive; - - bool disposedValue = false; - - Dictionary entries; - Dictionary arrays; - - - public NpzDictionary(Stream stream) - { - this.stream = stream; - this.archive = new ZipArchive(stream, ZipArchiveMode.Read, leaveOpen: true); - - this.entries = new Dictionary(); - foreach (var entry in archive.Entries) - this.entries[entry.FullName] = entry; - - this.arrays = new Dictionary(); - } - - - public IEnumerable Keys - { - get { return entries.Keys; } - } - - - public IEnumerable Values - { - get { return entries.Values.Select(OpenEntry); } - } - - public int Count - { - get { return entries.Count; } - } - - - public object SyncRoot - { - get { return ((ICollection)entries).SyncRoot; } - } - - - public bool IsSynchronized - { - get { return ((ICollection)entries).IsSynchronized; } - } - - public bool IsReadOnly - { - get { return true; } - } - - public T this[string key] - { - get { return OpenEntry(entries[key]); } - } - - private T OpenEntry(ZipArchiveEntry entry) - { - T array; - if (arrays.TryGetValue(entry.FullName, out array)) - return array; - - using (Stream s = entry.Open()) - { - array = Load_Npz(s); - arrays[entry.FullName] = array; - return array; - } - } - - protected virtual T Load_Npz(Stream s) - { - return np.Load(s); - } - - public bool ContainsKey(string key) - { - return entries.ContainsKey(key); - } - - public bool TryGetValue(string key, out T value) - { - value = default(T); - ZipArchiveEntry entry; - if (!entries.TryGetValue(key, out entry)) - return false; - value = OpenEntry(entry); - return true; - } - - public IEnumerator> GetEnumerator() - { - foreach (var entry in archive.Entries) - yield return new KeyValuePair(entry.FullName, OpenEntry(entry)); - } - - IEnumerator IEnumerable.GetEnumerator() - { - foreach (var entry in archive.Entries) - yield return new KeyValuePair(entry.FullName, OpenEntry(entry)); - } - - IEnumerator IEnumerable.GetEnumerator() - { - foreach (var entry in archive.Entries) - yield return OpenEntry(entry); - } - - public void CopyTo(Array array, int arrayIndex) - { - foreach (var v in this) - array.SetValue(v, arrayIndex++); - } - - public void CopyTo(T[] array, int arrayIndex) - { - foreach (var v in this) - array.SetValue(v, arrayIndex++); - } - - public void Add(T item) - { - throw new ReadOnlyException(); - } - - public void Clear() - { - throw new ReadOnlyException(); - } - - public bool Contains(T item) - { - foreach (var v in this) - if (Object.Equals(v.Value, item)) - return true; - return false; - } - - public bool Remove(T item) - { - throw new ReadOnlyException(); - } - - protected virtual void Dispose(bool disposing) - { - if (!disposedValue) - { - if (disposing) - { - archive.Dispose(); - stream.Dispose(); - } - - archive = null; - stream = null; - entries = null; - arrays = null; - - disposedValue = true; - } - } - - public void Dispose() - { - Dispose(true); - } - } - - public class NpzDictionary : NpzDictionary - { - bool jagged; - - public NpzDictionary(Stream stream, bool jagged) - : base(stream) - { - this.jagged = jagged; - } - - protected override Array Load_Npz(Stream s) - { - //if (jagged) - //return np.LoadJagged(s); - return np.LoadMatrix(s); - } - } -} diff --git a/src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs b/src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs index 9604392c..409e5e31 100644 --- a/src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs +++ b/src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs @@ -1,8 +1,4 @@ -using System; -using System.Collections; -using System.Collections.Generic; -using System.IO; -using System.Text; +using System.IO; using static Tensorflow.Binding; namespace Tensorflow.NumPy @@ -65,6 +61,7 @@ namespace Tensorflow.NumPy [AutoNumPy] public static NDArray load(string file) => tf.numpy.load(file); + [AutoNumPy] public static T Load(string path) where T : class, ICloneable, IList, ICollection, IEnumerable, IStructuralComparable, IStructuralEquatable { @@ -102,6 +99,7 @@ namespace Tensorflow.NumPy public static NDArray ones(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) => new NDArray(tf.ones(shape, dtype: dtype)); + [AutoNumPy] public static NDArray ones_like(NDArray a, TF_DataType dtype = TF_DataType.DtInvalid) => new NDArray(tf.ones_like(a, dtype: dtype)); diff --git a/src/TensorFlowNET.Core/Numpy/Numpy.cs b/src/TensorFlowNET.Core/Numpy/Numpy.cs index cd9373d4..72d2e981 100644 --- a/src/TensorFlowNET.Core/Numpy/Numpy.cs +++ b/src/TensorFlowNET.Core/Numpy/Numpy.cs @@ -14,65 +14,58 @@ limitations under the License. ******************************************************************************/ -using System; -using System.Collections; -using System.Collections.Generic; -using System.Numerics; -using System.Text; +namespace Tensorflow.NumPy; -namespace Tensorflow.NumPy +public partial class np { - public partial class np - { - /// - /// A convenient alias for None, useful for indexing arrays. - /// - /// https://docs.scipy.org/doc/numpy-1.17.0/reference/arrays.indexing.html



https://stackoverflow.com/questions/42190783/what-does-three-dots-in-python-mean-when-indexing-what-looks-like-a-number
- public static readonly Slice newaxis = new Slice(null, null, 1) { IsNewAxis = true }; + /// + /// A convenient alias for None, useful for indexing arrays. + /// + /// https://docs.scipy.org/doc/numpy-1.17.0/reference/arrays.indexing.html



https://stackoverflow.com/questions/42190783/what-does-three-dots-in-python-mean-when-indexing-what-looks-like-a-number
+ public static readonly Slice newaxis = new Slice(null, null, 1) { IsNewAxis = true }; - // https://docs.scipy.org/doc/numpy-1.16.0/user/basics.types.html - #region data type - public static readonly TF_DataType @bool = TF_DataType.TF_BOOL; - public static readonly TF_DataType @char = TF_DataType.TF_INT8; - public static readonly TF_DataType @byte = TF_DataType.TF_INT8; - public static readonly TF_DataType uint8 = TF_DataType.TF_UINT8; - public static readonly TF_DataType ubyte = TF_DataType.TF_UINT8; - public static readonly TF_DataType int16 = TF_DataType.TF_INT16; - public static readonly TF_DataType uint16 = TF_DataType.TF_UINT16; - public static readonly TF_DataType int32 = TF_DataType.TF_INT32; - public static readonly TF_DataType uint32 = TF_DataType.TF_UINT32; - public static readonly TF_DataType int64 = TF_DataType.TF_INT64; - public static readonly TF_DataType uint64 = TF_DataType.TF_UINT64; - public static readonly TF_DataType float32 = TF_DataType.TF_FLOAT; - public static readonly TF_DataType float64 = TF_DataType.TF_DOUBLE; - public static readonly TF_DataType @double = TF_DataType.TF_DOUBLE; - public static readonly TF_DataType @decimal = TF_DataType.TF_DOUBLE; - public static readonly TF_DataType complex_ = TF_DataType.TF_COMPLEX; - public static readonly TF_DataType complex64 = TF_DataType.TF_COMPLEX64; - public static readonly TF_DataType complex128 = TF_DataType.TF_COMPLEX128; - #endregion + // https://docs.scipy.org/doc/numpy-1.16.0/user/basics.types.html + #region data type + public static readonly TF_DataType @bool = TF_DataType.TF_BOOL; + public static readonly TF_DataType @char = TF_DataType.TF_INT8; + public static readonly TF_DataType @byte = TF_DataType.TF_INT8; + public static readonly TF_DataType uint8 = TF_DataType.TF_UINT8; + public static readonly TF_DataType ubyte = TF_DataType.TF_UINT8; + public static readonly TF_DataType int16 = TF_DataType.TF_INT16; + public static readonly TF_DataType uint16 = TF_DataType.TF_UINT16; + public static readonly TF_DataType int32 = TF_DataType.TF_INT32; + public static readonly TF_DataType uint32 = TF_DataType.TF_UINT32; + public static readonly TF_DataType int64 = TF_DataType.TF_INT64; + public static readonly TF_DataType uint64 = TF_DataType.TF_UINT64; + public static readonly TF_DataType float32 = TF_DataType.TF_FLOAT; + public static readonly TF_DataType float64 = TF_DataType.TF_DOUBLE; + public static readonly TF_DataType @double = TF_DataType.TF_DOUBLE; + public static readonly TF_DataType @decimal = TF_DataType.TF_DOUBLE; + public static readonly TF_DataType complex_ = TF_DataType.TF_COMPLEX; + public static readonly TF_DataType complex64 = TF_DataType.TF_COMPLEX64; + public static readonly TF_DataType complex128 = TF_DataType.TF_COMPLEX128; + #endregion - public static double nan => double.NaN; - public static double NAN => double.NaN; - public static double NaN => double.NaN; - public static double pi => Math.PI; - public static double e => Math.E; - public static double euler_gamma => 0.57721566490153286060651209008240243d; - public static double inf => double.PositiveInfinity; - public static double infty => double.PositiveInfinity; - public static double Inf => double.PositiveInfinity; - public static double NINF => double.NegativeInfinity; - public static double PINF => double.PositiveInfinity; - public static double Infinity => double.PositiveInfinity; - public static double infinity => double.PositiveInfinity; + public static double nan => double.NaN; + public static double NAN => double.NaN; + public static double NaN => double.NaN; + public static double pi => Math.PI; + public static double e => Math.E; + public static double euler_gamma => 0.57721566490153286060651209008240243d; + public static double inf => double.PositiveInfinity; + public static double infty => double.PositiveInfinity; + public static double Inf => double.PositiveInfinity; + public static double NINF => double.NegativeInfinity; + public static double PINF => double.PositiveInfinity; + public static double Infinity => double.PositiveInfinity; + public static double infinity => double.PositiveInfinity; - public static bool array_equal(NDArray a, NDArray b) - => a.Equals(b); + public static bool array_equal(NDArray a, NDArray b) + => a.Equals(b); - public static bool allclose(NDArray a, NDArray b, double rtol = 1.0E-5, double atol = 1.0E-8, - bool equal_nan = false) => throw new NotImplementedException(""); + public static bool allclose(NDArray a, NDArray b, double rtol = 1.0E-5, double atol = 1.0E-8, + bool equal_nan = false) => throw new NotImplementedException(""); - public static RandomizedImpl random = new RandomizedImpl(); - public static LinearAlgebraImpl linalg = new LinearAlgebraImpl(); - } + public static RandomizedImpl random = new RandomizedImpl(); + public static LinearAlgebraImpl linalg = new LinearAlgebraImpl(); } diff --git a/test/TensorFlowNET.UnitTest/NumPy/Persistence.Test.cs b/test/TensorFlowNET.UnitTest/NumPy/Persistence.Test.cs new file mode 100644 index 00000000..21db6acc --- /dev/null +++ b/test/TensorFlowNET.UnitTest/NumPy/Persistence.Test.cs @@ -0,0 +1,42 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; + +namespace TensorFlowNET.UnitTest.NumPy; + +/// +/// https://numpy.org/doc/stable/reference/generated/numpy.save.html +/// +[TestClass] +public class PersistenceTest : EagerModeTestBase +{ + [TestMethod] + public void SaveNpy() + { + var x = np.arange(10f).reshape((2, 5)); + np.save("arange.npy", x); + + var x2 = np.load("arange.npy"); + Assert.AreEqual(x.shape, x2.shape); + } + + [TestMethod] + public void SaveNpz() + { + var x = np.arange(10f).reshape((2, 5)); + var y = np.arange(10f).reshape((5, 2)); + + np.savez("arange.npz", x, y); + var z = np.loadz("arange.npz"); + + np.savez("arange_named.npz", new { x, y }); + z = np.loadz("arange_named.npz"); + Assert.AreEqual(z["x"].shape, x.shape); + Assert.AreEqual(z["y"].shape, y.shape); + + np.savez_compressed("arange_compressed.npz", x, y); + np.savez_compressed("arange_compressed_named.npz", new { x, y }); + z = np.loadz("arange_compressed_named.npz"); + Assert.AreEqual(z["x"].shape, x.shape); + Assert.AreEqual(z["y"].shape, y.shape); + } +}