diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs index dbc901f2..b8d9a625 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs @@ -155,39 +155,34 @@ namespace Tensorflow protected unsafe NDArray GetNDArray(TF_DataType dtype) { - UnmanagedStorage storage; + if (dtype == TF_DataType.TF_STRING) + return np.array(StringData()); + + var count = Convert.ToInt64(size); + IUnmanagedMemoryBlock mem; switch (dtype) { case TF_DataType.TF_BOOL: - storage = new UnmanagedStorage(NPTypeCode.Boolean); - break; - case TF_DataType.TF_STRING: - var nd = np.array(StringData()); - return nd; - case TF_DataType.TF_UINT8: - storage = new UnmanagedStorage(NPTypeCode.Byte); + mem = new UnmanagedMemoryBlock((bool*)buffer, count); break; case TF_DataType.TF_INT32: - storage = new UnmanagedStorage(NPTypeCode.Int32); + mem = new UnmanagedMemoryBlock((int*)buffer, count); break; case TF_DataType.TF_INT64: - storage = new UnmanagedStorage(NPTypeCode.Int64); + mem = new UnmanagedMemoryBlock((long*)buffer, count); break; case TF_DataType.TF_FLOAT: - storage = new UnmanagedStorage(NPTypeCode.Float); + mem = new UnmanagedMemoryBlock((float*)buffer, count); break; case TF_DataType.TF_DOUBLE: - storage = new UnmanagedStorage(NPTypeCode.Double); + mem = new UnmanagedMemoryBlock((double*)buffer, count); break; default: - return BufferToArray(); + mem = new UnmanagedMemoryBlock((byte*)buffer, count); + break; } - storage.Allocate(new Shape(shape)); - - System.Buffer.MemoryCopy(buffer.ToPointer(), storage.Address, bytesize, bytesize); - - return new NDArray(storage); + return new NDArray(ArraySlice.FromMemoryBlock(mem, copy: true), new Shape(shape)); } /// diff --git a/src/TensorFlowNET.Keras/Utils/data_utils.cs b/src/TensorFlowNET.Keras/Utils/data_utils.cs index fda3a545..5b84c601 100644 --- a/src/TensorFlowNET.Keras/Utils/data_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/data_utils.cs @@ -18,6 +18,8 @@ namespace Tensorflow.Keras.Utils string archive_format = "auto", string cache_dir = null) { + if (string.IsNullOrEmpty(cache_dir)) + cache_dir = Path.GetTempPath(); var datadir_base = cache_dir; Directory.CreateDirectory(datadir_base); @@ -26,10 +28,14 @@ namespace Tensorflow.Keras.Utils Web.Download(origin, datadir, fname); + var archive = Path.Combine(datadir, fname); + if (untar) - Compress.ExtractTGZ(Path.Combine(datadir_base, fname), datadir_base); - else if (extract) - Compress.ExtractGZip(Path.Combine(datadir_base, fname), datadir_base); + Compress.ExtractTGZ(archive, datadir); + else if (extract && fname.EndsWith(".gz")) + Compress.ExtractGZip(archive, datadir); + else if (extract && fname.EndsWith(".zip")) + Compress.UnZip(archive, datadir); return datadir; }