Browse Source

release v0.40.1.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
1b7e25cd8f
2 changed files with 22 additions and 21 deletions
  1. +13
    -18
      src/TensorFlowNET.Core/Tensors/Tensor.Value.cs
  2. +9
    -3
      src/TensorFlowNET.Keras/Utils/data_utils.cs

+ 13
- 18
src/TensorFlowNET.Core/Tensors/Tensor.Value.cs View File

@@ -155,39 +155,34 @@ namespace Tensorflow


protected unsafe NDArray GetNDArray(TF_DataType dtype) protected unsafe NDArray GetNDArray(TF_DataType dtype)
{ {
UnmanagedStorage storage;
if (dtype == TF_DataType.TF_STRING)
return np.array(StringData());

var count = Convert.ToInt64(size);
IUnmanagedMemoryBlock mem;
switch (dtype) switch (dtype)
{ {
case TF_DataType.TF_BOOL: case TF_DataType.TF_BOOL:
storage = new UnmanagedStorage(NPTypeCode.Boolean);
break;
case TF_DataType.TF_STRING:
var nd = np.array(StringData());
return nd;
case TF_DataType.TF_UINT8:
storage = new UnmanagedStorage(NPTypeCode.Byte);
mem = new UnmanagedMemoryBlock<bool>((bool*)buffer, count);
break; break;
case TF_DataType.TF_INT32: case TF_DataType.TF_INT32:
storage = new UnmanagedStorage(NPTypeCode.Int32);
mem = new UnmanagedMemoryBlock<int>((int*)buffer, count);
break; break;
case TF_DataType.TF_INT64: case TF_DataType.TF_INT64:
storage = new UnmanagedStorage(NPTypeCode.Int64);
mem = new UnmanagedMemoryBlock<long>((long*)buffer, count);
break; break;
case TF_DataType.TF_FLOAT: case TF_DataType.TF_FLOAT:
storage = new UnmanagedStorage(NPTypeCode.Float);
mem = new UnmanagedMemoryBlock<float>((float*)buffer, count);
break; break;
case TF_DataType.TF_DOUBLE: case TF_DataType.TF_DOUBLE:
storage = new UnmanagedStorage(NPTypeCode.Double);
mem = new UnmanagedMemoryBlock<double>((double*)buffer, count);
break; break;
default: default:
return BufferToArray();
mem = new UnmanagedMemoryBlock<byte>((byte*)buffer, count);
break;
} }


storage.Allocate(new Shape(shape));

System.Buffer.MemoryCopy(buffer.ToPointer(), storage.Address, bytesize, bytesize);

return new NDArray(storage);
return new NDArray(ArraySlice.FromMemoryBlock(mem, copy: true), new Shape(shape));
} }


/// <summary> /// <summary>


+ 9
- 3
src/TensorFlowNET.Keras/Utils/data_utils.cs View File

@@ -18,6 +18,8 @@ namespace Tensorflow.Keras.Utils
string archive_format = "auto", string archive_format = "auto",
string cache_dir = null) string cache_dir = null)
{ {
if (string.IsNullOrEmpty(cache_dir))
cache_dir = Path.GetTempPath();
var datadir_base = cache_dir; var datadir_base = cache_dir;
Directory.CreateDirectory(datadir_base); Directory.CreateDirectory(datadir_base);


@@ -26,10 +28,14 @@ namespace Tensorflow.Keras.Utils


Web.Download(origin, datadir, fname); Web.Download(origin, datadir, fname);


var archive = Path.Combine(datadir, fname);

if (untar) if (untar)
Compress.ExtractTGZ(Path.Combine(datadir_base, fname), datadir_base);
else if (extract)
Compress.ExtractGZip(Path.Combine(datadir_base, fname), datadir_base);
Compress.ExtractTGZ(archive, datadir);
else if (extract && fname.EndsWith(".gz"))
Compress.ExtractGZip(archive, datadir);
else if (extract && fname.EndsWith(".zip"))
Compress.UnZip(archive, datadir);


return datadir; return datadir;
} }


Loading…
Cancel
Save