Browse Source

Support save and load in numpy.lib.format.

tags/v0.100.4-load-saved-model
Haiping Chen 2 years ago
parent
commit
55cc4d0b78
11 changed files with 608 additions and 281 deletions
  1. +3
    -0
      src/TensorFlowNET.Core/GlobalUsing.cs
  2. +3
    -16
      src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.load.cs
  3. +60
    -0
      src/TensorFlowNET.Core/NumPy/Numpy.Persistence.cs
  4. +95
    -0
      src/TensorFlowNET.Core/NumPy/Persistence/NpyFormat.cs
  5. +180
    -0
      src/TensorFlowNET.Core/NumPy/Persistence/NpzDictionary.cs
  6. +138
    -0
      src/TensorFlowNET.Core/NumPy/Persistence/NpzDictionaryArray.cs
  7. +37
    -0
      src/TensorFlowNET.Core/NumPy/Persistence/NpzFormat.cs
  8. +0
    -206
      src/TensorFlowNET.Core/Numpy/NpzDictionary.cs
  9. +3
    -5
      src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs
  10. +47
    -54
      src/TensorFlowNET.Core/Numpy/Numpy.cs
  11. +42
    -0
      test/TensorFlowNET.UnitTest/NumPy/Persistence.Test.cs

+ 3
- 0
src/TensorFlowNET.Core/GlobalUsing.cs View File

@@ -1,3 +1,6 @@
global using System; global using System;
global using System.Collections.Generic; global using System.Collections.Generic;
global using System.Text; global using System.Text;
global using System.Collections;
global using System.Data;
global using System.Linq;

+ 3
- 16
src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.load.cs View File

@@ -1,11 +1,4 @@
 using System;
using System.Collections;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Text;
using Tensorflow.Util;
using System.IO;


namespace Tensorflow.NumPy namespace Tensorflow.NumPy
{ {
@@ -15,10 +8,7 @@ namespace Tensorflow.NumPy
{ {
using var stream = new FileStream(file, FileMode.Open); using var stream = new FileStream(file, FileMode.Open);
using var reader = new BinaryReader(stream, Encoding.ASCII, leaveOpen: true); using var reader = new BinaryReader(stream, Encoding.ASCII, leaveOpen: true);
int bytes;
Type type;
int[] shape;
if (!ParseReader(reader, out bytes, out type, out shape))
if (!ParseReader(reader, out var bytes, out var type, out var shape))
throw new FormatException(); throw new FormatException();


Array array = Create(type, shape.Aggregate((dims, dim) => dims * dim)); Array array = Create(type, shape.Aggregate((dims, dim) => dims * dim));
@@ -31,10 +21,7 @@ namespace Tensorflow.NumPy
{ {
using (var reader = new BinaryReader(stream, System.Text.Encoding.ASCII, leaveOpen: true)) using (var reader = new BinaryReader(stream, System.Text.Encoding.ASCII, leaveOpen: true))
{ {
int bytes;
Type type;
int[] shape;
if (!ParseReader(reader, out bytes, out type, out shape))
if (!ParseReader(reader, out var bytes, out var type, out var shape))
throw new FormatException(); throw new FormatException();


Array matrix = Array.CreateInstance(type, shape); Array matrix = Array.CreateInstance(type, shape);


+ 60
- 0
src/TensorFlowNET.Core/NumPy/Numpy.Persistence.cs View File

@@ -0,0 +1,60 @@
/*****************************************************************************
Copyright 2023 Haiping Chen. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System.IO;
using System.IO.Compression;

namespace Tensorflow.NumPy;

public partial class np
{
[AutoNumPy]
public static NpzDictionary loadz(string file)
{
using var stream = new FileStream(file, FileMode.Open);
return new NpzDictionary(stream);
}

public static void save(string file, NDArray nd)
{
using var stream = new FileStream(file, FileMode.Create);
NpyFormat.Save(nd, stream);
}

public static void savez(string file, params NDArray[] nds)
{
using var stream = new FileStream(file, FileMode.Create);
NpzFormat.Save(nds, stream);
}

public static void savez(string file, object nds)
{
using var stream = new FileStream(file, FileMode.Create);
NpzFormat.Save(nds, stream);
}

public static void savez_compressed(string file, params NDArray[] nds)
{
using var stream = new FileStream(file, FileMode.Create);
NpzFormat.Save(nds, stream, CompressionLevel.Fastest);
}

public static void savez_compressed(string file, object nds)
{
using var stream = new FileStream(file, FileMode.Create);
NpzFormat.Save(nds, stream, CompressionLevel.Fastest);
}
}

+ 95
- 0
src/TensorFlowNET.Core/NumPy/Persistence/NpyFormat.cs View File

@@ -0,0 +1,95 @@
using System.IO;
using System.Runtime.InteropServices;

namespace Tensorflow.NumPy;

public class NpyFormat
{
public static void Save(NDArray array, Stream stream, bool leaveOpen = true)
{
using var writer = new BinaryWriter(stream, Encoding.ASCII, leaveOpen: leaveOpen);

string dtype = GetDtypeName(array, out var type, out var maxLength);
int[] shape = array.shape.as_int_list();
var bytesWritten = (ulong)writeHeader(writer, dtype, shape);
stream.Write(array.ToByteArray(), 0, (int)array.bytesize);
}

private static int writeHeader(BinaryWriter writer, string dtype, int[] shape)
{
// The first 6 bytes are a magic string: exactly "x93NUMPY"

char[] magic = { 'N', 'U', 'M', 'P', 'Y' };
writer.Write((byte)147);
writer.Write(magic);
writer.Write((byte)1); // major
writer.Write((byte)0); // minor;

string tuple = shape.Length == 1 ? $"{shape[0]}," : String.Join(", ", shape.Select(i => i.ToString()).ToArray());
string header = "{{'descr': '{0}', 'fortran_order': False, 'shape': ({1}), }}";
header = string.Format(header, dtype, tuple);
int preamble = 10; // magic string (6) + 4

int len = header.Length + 1; // the 1 is to account for the missing \n at the end
int headerSize = len + preamble;

int pad = 16 - (headerSize % 16);
header = header.PadRight(header.Length + pad);
header += "\n";
headerSize = header.Length + preamble;

if (headerSize % 16 != 0)
throw new Exception("");

writer.Write((ushort)header.Length);
for (int i = 0; i < header.Length; i++)
writer.Write((byte)header[i]);

return headerSize;
}

private static string GetDtypeName(NDArray array, out Type type, out int bytes)
{
type = array.dtype.as_system_dtype();

bytes = 1;

if (type == typeof(string))
{
throw new NotSupportedException("");
}
else if (type == typeof(bool))
{
bytes = 1;
}
else
{
bytes = Marshal.SizeOf(type);
}

if (type == typeof(bool))
return "|b1";
else if (type == typeof(byte))
return "|i1";
else if (type == typeof(short))
return "<i2";
else if (type == typeof(int))
return "<i4";
else if (type == typeof(long))
return "<i8";
else if (type == typeof(ushort))
return "<u2";
else if (type == typeof(uint))
return "<u4";
else if (type == typeof(ulong))
return "<u8";
else if (type == typeof(float))
return "<f4";
else if (type == typeof(double))
return "<f8";
else if (type == typeof(string))
return "|S" + bytes;
else
throw new NotSupportedException();
}
}

+ 180
- 0
src/TensorFlowNET.Core/NumPy/Persistence/NpzDictionary.cs View File

@@ -0,0 +1,180 @@
using System.IO;
using System.IO.Compression;

namespace Tensorflow.NumPy;

public class NpzDictionary<T> : IDisposable, IReadOnlyDictionary<string, T>, ICollection<T>
where T : class,
ICloneable, IList, ICollection, IEnumerable, IStructuralComparable, IStructuralEquatable
{
Stream stream;
ZipArchive archive;

bool disposedValue = false;

Dictionary<string, ZipArchiveEntry> entries;
Dictionary<string, T> arrays;


public NpzDictionary(Stream stream)
{
this.stream = stream;
this.archive = new ZipArchive(stream, ZipArchiveMode.Read, leaveOpen: true);

this.entries = new Dictionary<string, ZipArchiveEntry>();
foreach (var entry in archive.Entries)
this.entries[entry.FullName] = entry;

this.arrays = new Dictionary<string, T>();
}


public IEnumerable<string> Keys
{
get { return entries.Keys; }
}


public IEnumerable<T> Values
{
get { return entries.Values.Select(OpenEntry); }
}

public int Count
{
get { return entries.Count; }
}


public object SyncRoot
{
get { return ((ICollection)entries).SyncRoot; }
}


public bool IsSynchronized
{
get { return ((ICollection)entries).IsSynchronized; }
}

public bool IsReadOnly
{
get { return true; }
}

public T this[string key]
{
get { return OpenEntry(entries[key]); }
}

private T OpenEntry(ZipArchiveEntry entry)
{
T array;
if (arrays.TryGetValue(entry.FullName, out array))
return array;

using (Stream s = entry.Open())
{
array = Load_Npz(s);
arrays[entry.FullName] = array;
return array;
}
}

protected virtual T Load_Npz(Stream s)
{
return np.Load<T>(s);
}

public bool ContainsKey(string key)
{
return entries.ContainsKey(key);
}

public bool TryGetValue(string key, out T value)
{
value = default(T);
ZipArchiveEntry entry;
if (!entries.TryGetValue(key, out entry))
return false;
value = OpenEntry(entry);
return true;
}

public IEnumerator<KeyValuePair<string, T>> GetEnumerator()
{
foreach (var entry in archive.Entries)
yield return new KeyValuePair<string, T>(entry.FullName, OpenEntry(entry));
}

IEnumerator IEnumerable.GetEnumerator()
{
foreach (var entry in archive.Entries)
yield return new KeyValuePair<string, T>(entry.FullName, OpenEntry(entry));
}

IEnumerator<T> IEnumerable<T>.GetEnumerator()
{
foreach (var entry in archive.Entries)
yield return OpenEntry(entry);
}

public void CopyTo(Array array, int arrayIndex)
{
foreach (var v in this)
array.SetValue(v, arrayIndex++);
}

public void CopyTo(T[] array, int arrayIndex)
{
foreach (var v in this)
array.SetValue(v, arrayIndex++);
}

public void Add(T item)
{
throw new ReadOnlyException();
}

public void Clear()
{
throw new ReadOnlyException();
}

public bool Contains(T item)
{
foreach (var v in this)
if (Object.Equals(v.Value, item))
return true;
return false;
}

public bool Remove(T item)
{
throw new ReadOnlyException();
}

protected virtual void Dispose(bool disposing)
{
if (!disposedValue)
{
if (disposing)
{
archive.Dispose();
stream.Dispose();
}

archive = null;
stream = null;
entries = null;
arrays = null;

disposedValue = true;
}
}

public void Dispose()
{
Dispose(true);
}
}

+ 138
- 0
src/TensorFlowNET.Core/NumPy/Persistence/NpzDictionaryArray.cs View File

@@ -0,0 +1,138 @@
using System.IO;
using System.IO.Compression;

namespace Tensorflow.NumPy;

public class NpzDictionary
{
Dictionary<string, NDArray> arrays = new Dictionary<string, NDArray>();

public NDArray this[string key] => arrays[key];

public NpzDictionary(Stream stream)
{
using var archive = new ZipArchive(stream, ZipArchiveMode.Read, leaveOpen: false);

foreach (var entry in archive.Entries)
{
arrays[entry.FullName] = OpenEntry(entry);
}
}

private NDArray OpenEntry(ZipArchiveEntry entry)
{
if (arrays.TryGetValue(entry.FullName, out var array))
return array;

using var s = entry.Open();
return LoadMatrix(s);
}

public Array LoadMatrix(Stream stream)
{
using var reader = new BinaryReader(stream, System.Text.Encoding.ASCII, leaveOpen: false);

if (!ParseReader(reader, out var bytes, out var type, out var shape))
throw new FormatException();

Array matrix = Array.CreateInstance(type, shape);

return ReadMatrix(reader, matrix, bytes, type, shape);
}

bool ParseReader(BinaryReader reader, out int bytes, out Type t, out int[] shape)
{
bytes = 0;
t = null;
shape = null;

// The first 6 bytes are a magic string: exactly "x93NUMPY"
if (reader.ReadChar() != 63) return false;
if (reader.ReadChar() != 'N') return false;
if (reader.ReadChar() != 'U') return false;
if (reader.ReadChar() != 'M') return false;
if (reader.ReadChar() != 'P') return false;
if (reader.ReadChar() != 'Y') return false;

byte major = reader.ReadByte(); // 1
byte minor = reader.ReadByte(); // 0

if (major != 1 || minor != 0)
throw new NotSupportedException();

ushort len = reader.ReadUInt16();

string header = new string(reader.ReadChars(len));
string mark = "'descr': '";
int s = header.IndexOf(mark) + mark.Length;
int e = header.IndexOf("'", s + 1);
string type = header.Substring(s, e - s);
bool? isLittleEndian;
t = GetType(type, out bytes, out isLittleEndian);

if (isLittleEndian.HasValue && isLittleEndian.Value == false)
throw new Exception();

mark = "'fortran_order': ";
s = header.IndexOf(mark) + mark.Length;
e = header.IndexOf(",", s + 1);
bool fortran = bool.Parse(header.Substring(s, e - s));

if (fortran)
throw new Exception();

mark = "'shape': (";
s = header.IndexOf(mark) + mark.Length;
e = header.IndexOf(")", s + 1);
shape = header.Substring(s, e - s).Split(',').Where(v => !String.IsNullOrEmpty(v)).Select(Int32.Parse).ToArray();

return true;
}

Type GetType(string dtype, out int bytes, out bool? isLittleEndian)
{
isLittleEndian = IsLittleEndian(dtype);
bytes = int.Parse(dtype.Substring(2));

string typeCode = dtype.Substring(1);
return typeCode switch
{
"b1" => typeof(bool),
"i1" => typeof(byte),
"i2" => typeof(short),
"i4" => typeof(int),
"i8" => typeof(long),
"u1" => typeof(byte),
"u2" => typeof(ushort),
"u4" => typeof(uint),
"u8" => typeof(ulong),
"f4" => typeof(float),
"f8" => typeof(double),
// typeCode.StartsWith("S") => typeof(string),
_ => throw new NotSupportedException()
};
}

bool? IsLittleEndian(string type)
{
return type[0] switch
{
'<' => true,
'>' => false,
'|' => null,
_ => throw new Exception()
};
}

Array ReadMatrix(BinaryReader reader, Array matrix, int bytes, Type type, int[] shape)
{
int total = 1;
for (int i = 0; i < shape.Length; i++)
total *= shape[i];

var buffer = reader.ReadBytes(bytes * total);
System.Buffer.BlockCopy(buffer, 0, matrix, 0, buffer.Length);

return matrix;
}
}

+ 37
- 0
src/TensorFlowNET.Core/NumPy/Persistence/NpzFormat.cs View File

@@ -0,0 +1,37 @@
using System.IO.Compression;
using System.IO;
using System;

namespace Tensorflow.NumPy;

public class NpzFormat
{
public static void Save(NDArray[] arrays, Stream stream, CompressionLevel compression = CompressionLevel.NoCompression, bool leaveOpen = false)
{
using var zip = new ZipArchive(stream, ZipArchiveMode.Create, leaveOpen: leaveOpen);
for (int i = 0; i < arrays.Length; i++)
{
var entry = zip.CreateEntry($"arr_{i}", compression);
NpyFormat.Save(arrays[i], entry.Open(), leaveOpen);
}
}

public static void Save(object arrays, Stream stream, CompressionLevel compression = CompressionLevel.NoCompression, bool leaveOpen = false)
{
var properties = arrays.GetType().GetProperties();
using var zip = new ZipArchive(stream, ZipArchiveMode.Create, leaveOpen: leaveOpen);
for (int i = 0; i < properties.Length; i++)
{
var entry = zip.CreateEntry(properties[i].Name, compression);
var value = properties[i].GetValue(arrays);
if (value is NDArray nd)
{
NpyFormat.Save(nd, entry.Open(), leaveOpen);
}
else
{
throw new NotSupportedException("Please pass in NDArray.");
}
}
}
}

+ 0
- 206
src/TensorFlowNET.Core/Numpy/NpzDictionary.cs View File

@@ -1,206 +0,0 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Data;
using System.IO;
using System.IO.Compression;
using System.Linq;
using System.Text;
using Tensorflow.Util;

namespace Tensorflow.NumPy
{
public class NpzDictionary<T> : IDisposable, IReadOnlyDictionary<string, T>, ICollection<T>
where T : class,
ICloneable, IList, ICollection, IEnumerable, IStructuralComparable, IStructuralEquatable
{
Stream stream;
ZipArchive archive;

bool disposedValue = false;

Dictionary<string, ZipArchiveEntry> entries;
Dictionary<string, T> arrays;


public NpzDictionary(Stream stream)
{
this.stream = stream;
this.archive = new ZipArchive(stream, ZipArchiveMode.Read, leaveOpen: true);

this.entries = new Dictionary<string, ZipArchiveEntry>();
foreach (var entry in archive.Entries)
this.entries[entry.FullName] = entry;

this.arrays = new Dictionary<string, T>();
}


public IEnumerable<string> Keys
{
get { return entries.Keys; }
}


public IEnumerable<T> Values
{
get { return entries.Values.Select(OpenEntry); }
}

public int Count
{
get { return entries.Count; }
}


public object SyncRoot
{
get { return ((ICollection)entries).SyncRoot; }
}


public bool IsSynchronized
{
get { return ((ICollection)entries).IsSynchronized; }
}

public bool IsReadOnly
{
get { return true; }
}

public T this[string key]
{
get { return OpenEntry(entries[key]); }
}

private T OpenEntry(ZipArchiveEntry entry)
{
T array;
if (arrays.TryGetValue(entry.FullName, out array))
return array;

using (Stream s = entry.Open())
{
array = Load_Npz(s);
arrays[entry.FullName] = array;
return array;
}
}

protected virtual T Load_Npz(Stream s)
{
return np.Load<T>(s);
}

public bool ContainsKey(string key)
{
return entries.ContainsKey(key);
}

public bool TryGetValue(string key, out T value)
{
value = default(T);
ZipArchiveEntry entry;
if (!entries.TryGetValue(key, out entry))
return false;
value = OpenEntry(entry);
return true;
}

public IEnumerator<KeyValuePair<string, T>> GetEnumerator()
{
foreach (var entry in archive.Entries)
yield return new KeyValuePair<string, T>(entry.FullName, OpenEntry(entry));
}

IEnumerator IEnumerable.GetEnumerator()
{
foreach (var entry in archive.Entries)
yield return new KeyValuePair<string, T>(entry.FullName, OpenEntry(entry));
}

IEnumerator<T> IEnumerable<T>.GetEnumerator()
{
foreach (var entry in archive.Entries)
yield return OpenEntry(entry);
}

public void CopyTo(Array array, int arrayIndex)
{
foreach (var v in this)
array.SetValue(v, arrayIndex++);
}

public void CopyTo(T[] array, int arrayIndex)
{
foreach (var v in this)
array.SetValue(v, arrayIndex++);
}

public void Add(T item)
{
throw new ReadOnlyException();
}

public void Clear()
{
throw new ReadOnlyException();
}

public bool Contains(T item)
{
foreach (var v in this)
if (Object.Equals(v.Value, item))
return true;
return false;
}

public bool Remove(T item)
{
throw new ReadOnlyException();
}

protected virtual void Dispose(bool disposing)
{
if (!disposedValue)
{
if (disposing)
{
archive.Dispose();
stream.Dispose();
}

archive = null;
stream = null;
entries = null;
arrays = null;

disposedValue = true;
}
}

public void Dispose()
{
Dispose(true);
}
}

public class NpzDictionary : NpzDictionary<Array>
{
bool jagged;

public NpzDictionary(Stream stream, bool jagged)
: base(stream)
{
this.jagged = jagged;
}

protected override Array Load_Npz(Stream s)
{
//if (jagged)
//return np.LoadJagged(s);
return np.LoadMatrix(s);
}
}
}

+ 3
- 5
src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs View File

@@ -1,8 +1,4 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.IO;
using System.Text;
using System.IO;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow.NumPy namespace Tensorflow.NumPy
@@ -65,6 +61,7 @@ namespace Tensorflow.NumPy
[AutoNumPy] [AutoNumPy]
public static NDArray load(string file) => tf.numpy.load(file); public static NDArray load(string file) => tf.numpy.load(file);


[AutoNumPy]
public static T Load<T>(string path) public static T Load<T>(string path)
where T : class, ICloneable, IList, ICollection, IEnumerable, IStructuralComparable, IStructuralEquatable where T : class, ICloneable, IList, ICollection, IEnumerable, IStructuralComparable, IStructuralEquatable
{ {
@@ -102,6 +99,7 @@ namespace Tensorflow.NumPy
public static NDArray ones(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) public static NDArray ones(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE)
=> new NDArray(tf.ones(shape, dtype: dtype)); => new NDArray(tf.ones(shape, dtype: dtype));


[AutoNumPy]
public static NDArray ones_like(NDArray a, TF_DataType dtype = TF_DataType.DtInvalid) public static NDArray ones_like(NDArray a, TF_DataType dtype = TF_DataType.DtInvalid)
=> new NDArray(tf.ones_like(a, dtype: dtype)); => new NDArray(tf.ones_like(a, dtype: dtype));




+ 47
- 54
src/TensorFlowNET.Core/Numpy/Numpy.cs View File

@@ -14,65 +14,58 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System;
using System.Collections;
using System.Collections.Generic;
using System.Numerics;
using System.Text;
namespace Tensorflow.NumPy;


namespace Tensorflow.NumPy
public partial class np
{ {
public partial class np
{
/// <summary>
/// A convenient alias for None, useful for indexing arrays.
/// </summary>
/// <remarks>https://docs.scipy.org/doc/numpy-1.17.0/reference/arrays.indexing.html<br></br><br></br>https://stackoverflow.com/questions/42190783/what-does-three-dots-in-python-mean-when-indexing-what-looks-like-a-number</remarks>
public static readonly Slice newaxis = new Slice(null, null, 1) { IsNewAxis = true };
/// <summary>
/// A convenient alias for None, useful for indexing arrays.
/// </summary>
/// <remarks>https://docs.scipy.org/doc/numpy-1.17.0/reference/arrays.indexing.html<br></br><br></br>https://stackoverflow.com/questions/42190783/what-does-three-dots-in-python-mean-when-indexing-what-looks-like-a-number</remarks>
public static readonly Slice newaxis = new Slice(null, null, 1) { IsNewAxis = true };


// https://docs.scipy.org/doc/numpy-1.16.0/user/basics.types.html
#region data type
public static readonly TF_DataType @bool = TF_DataType.TF_BOOL;
public static readonly TF_DataType @char = TF_DataType.TF_INT8;
public static readonly TF_DataType @byte = TF_DataType.TF_INT8;
public static readonly TF_DataType uint8 = TF_DataType.TF_UINT8;
public static readonly TF_DataType ubyte = TF_DataType.TF_UINT8;
public static readonly TF_DataType int16 = TF_DataType.TF_INT16;
public static readonly TF_DataType uint16 = TF_DataType.TF_UINT16;
public static readonly TF_DataType int32 = TF_DataType.TF_INT32;
public static readonly TF_DataType uint32 = TF_DataType.TF_UINT32;
public static readonly TF_DataType int64 = TF_DataType.TF_INT64;
public static readonly TF_DataType uint64 = TF_DataType.TF_UINT64;
public static readonly TF_DataType float32 = TF_DataType.TF_FLOAT;
public static readonly TF_DataType float64 = TF_DataType.TF_DOUBLE;
public static readonly TF_DataType @double = TF_DataType.TF_DOUBLE;
public static readonly TF_DataType @decimal = TF_DataType.TF_DOUBLE;
public static readonly TF_DataType complex_ = TF_DataType.TF_COMPLEX;
public static readonly TF_DataType complex64 = TF_DataType.TF_COMPLEX64;
public static readonly TF_DataType complex128 = TF_DataType.TF_COMPLEX128;
#endregion
// https://docs.scipy.org/doc/numpy-1.16.0/user/basics.types.html
#region data type
public static readonly TF_DataType @bool = TF_DataType.TF_BOOL;
public static readonly TF_DataType @char = TF_DataType.TF_INT8;
public static readonly TF_DataType @byte = TF_DataType.TF_INT8;
public static readonly TF_DataType uint8 = TF_DataType.TF_UINT8;
public static readonly TF_DataType ubyte = TF_DataType.TF_UINT8;
public static readonly TF_DataType int16 = TF_DataType.TF_INT16;
public static readonly TF_DataType uint16 = TF_DataType.TF_UINT16;
public static readonly TF_DataType int32 = TF_DataType.TF_INT32;
public static readonly TF_DataType uint32 = TF_DataType.TF_UINT32;
public static readonly TF_DataType int64 = TF_DataType.TF_INT64;
public static readonly TF_DataType uint64 = TF_DataType.TF_UINT64;
public static readonly TF_DataType float32 = TF_DataType.TF_FLOAT;
public static readonly TF_DataType float64 = TF_DataType.TF_DOUBLE;
public static readonly TF_DataType @double = TF_DataType.TF_DOUBLE;
public static readonly TF_DataType @decimal = TF_DataType.TF_DOUBLE;
public static readonly TF_DataType complex_ = TF_DataType.TF_COMPLEX;
public static readonly TF_DataType complex64 = TF_DataType.TF_COMPLEX64;
public static readonly TF_DataType complex128 = TF_DataType.TF_COMPLEX128;
#endregion


public static double nan => double.NaN;
public static double NAN => double.NaN;
public static double NaN => double.NaN;
public static double pi => Math.PI;
public static double e => Math.E;
public static double euler_gamma => 0.57721566490153286060651209008240243d;
public static double inf => double.PositiveInfinity;
public static double infty => double.PositiveInfinity;
public static double Inf => double.PositiveInfinity;
public static double NINF => double.NegativeInfinity;
public static double PINF => double.PositiveInfinity;
public static double Infinity => double.PositiveInfinity;
public static double infinity => double.PositiveInfinity;
public static double nan => double.NaN;
public static double NAN => double.NaN;
public static double NaN => double.NaN;
public static double pi => Math.PI;
public static double e => Math.E;
public static double euler_gamma => 0.57721566490153286060651209008240243d;
public static double inf => double.PositiveInfinity;
public static double infty => double.PositiveInfinity;
public static double Inf => double.PositiveInfinity;
public static double NINF => double.NegativeInfinity;
public static double PINF => double.PositiveInfinity;
public static double Infinity => double.PositiveInfinity;
public static double infinity => double.PositiveInfinity;


public static bool array_equal(NDArray a, NDArray b)
=> a.Equals(b);
public static bool array_equal(NDArray a, NDArray b)
=> a.Equals(b);


public static bool allclose(NDArray a, NDArray b, double rtol = 1.0E-5, double atol = 1.0E-8,
bool equal_nan = false) => throw new NotImplementedException("");
public static bool allclose(NDArray a, NDArray b, double rtol = 1.0E-5, double atol = 1.0E-8,
bool equal_nan = false) => throw new NotImplementedException("");


public static RandomizedImpl random = new RandomizedImpl();
public static LinearAlgebraImpl linalg = new LinearAlgebraImpl();
}
public static RandomizedImpl random = new RandomizedImpl();
public static LinearAlgebraImpl linalg = new LinearAlgebraImpl();
} }

+ 42
- 0
test/TensorFlowNET.UnitTest/NumPy/Persistence.Test.cs View File

@@ -0,0 +1,42 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.NumPy;

namespace TensorFlowNET.UnitTest.NumPy;

/// <summary>
/// https://numpy.org/doc/stable/reference/generated/numpy.save.html
/// </summary>
[TestClass]
public class PersistenceTest : EagerModeTestBase
{
[TestMethod]
public void SaveNpy()
{
var x = np.arange(10f).reshape((2, 5));
np.save("arange.npy", x);

var x2 = np.load("arange.npy");
Assert.AreEqual(x.shape, x2.shape);
}

[TestMethod]
public void SaveNpz()
{
var x = np.arange(10f).reshape((2, 5));
var y = np.arange(10f).reshape((5, 2));

np.savez("arange.npz", x, y);
var z = np.loadz("arange.npz");

np.savez("arange_named.npz", new { x, y });
z = np.loadz("arange_named.npz");
Assert.AreEqual(z["x"].shape, x.shape);
Assert.AreEqual(z["y"].shape, y.shape);

np.savez_compressed("arange_compressed.npz", x, y);
np.savez_compressed("arange_compressed_named.npz", new { x, y });
z = np.loadz("arange_compressed_named.npz");
Assert.AreEqual(z["x"].shape, x.shape);
Assert.AreEqual(z["y"].shape, y.shape);
}
}

Loading…
Cancel
Save