Browse Source

Fix allocate tensor for string #171

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
2579afc84a
2 changed files with 10 additions and 18 deletions
  1. +8
    -16
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  2. +2
    -2
      src/TensorFlowNET.Core/Train/Saving/Saver.cs

+ 8
- 16
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -36,6 +36,10 @@ namespace Tensorflow
size = (ulong)(nd.size * nd.dtypesize); size = (ulong)(nd.size * nd.dtypesize);
} }


var dataType = ToTFDataType(nd.dtype);
// shape
var dims = nd.shape.Select(x => (long)x).ToArray();

switch (nd.dtype.Name) switch (nd.dtype.Name)
{ {
case "Int16": case "Int16":
@@ -51,17 +55,8 @@ namespace Tensorflow
Marshal.Copy(nd.Data<double>(), 0, dotHandle, nd.size); Marshal.Copy(nd.Data<double>(), 0, dotHandle, nd.size);
break; break;
case "String": case "String":
/*var value = nd.Data<string>()[0];
var bytes = Encoding.UTF8.GetBytes(value);
dotHandle = Marshal.AllocHGlobal(bytes.Length + 1);
Marshal.Copy(bytes, 0, dotHandle, bytes.Length);
size = (ulong)bytes.Length;*/

var str = nd.Data<string>()[0]; var str = nd.Data<string>()[0];
ulong dst_len = c_api.TF_StringEncodedSize((ulong)str.Length); ulong dst_len = c_api.TF_StringEncodedSize((ulong)str.Length);
//dotHandle = Marshal.AllocHGlobal((int)dst_len);
//size = c_api.TF_StringEncode(str, (ulong)str.Length, dotHandle, dst_len, status);

var dataType1 = ToTFDataType(nd.dtype); var dataType1 = ToTFDataType(nd.dtype);
// shape // shape
var dims1 = nd.shape.Select(x => (long)x).ToArray(); var dims1 = nd.shape.Select(x => (long)x).ToArray();
@@ -69,19 +64,16 @@ namespace Tensorflow
var tfHandle1 = c_api.TF_AllocateTensor(dataType1, var tfHandle1 = c_api.TF_AllocateTensor(dataType1,
dims1, dims1,
nd.ndim, nd.ndim,
dst_len);
dst_len + sizeof(Int64));


dotHandle = c_api.TF_TensorData(tfHandle1); dotHandle = c_api.TF_TensorData(tfHandle1);
c_api.TF_StringEncode(str, (ulong)str.Length, dotHandle, dst_len, status);
Marshal.WriteInt64(dotHandle, 0);
c_api.TF_StringEncode(str, (ulong)str.Length, dotHandle + sizeof(Int64), dst_len, status);
return tfHandle1; return tfHandle1;
break;
default: default:
throw new NotImplementedException("Marshal.Copy failed."); throw new NotImplementedException("Marshal.Copy failed.");
} }

var dataType = ToTFDataType(nd.dtype);
// shape
var dims = nd.shape.Select(x => (long)x).ToArray();
// Free the original buffer and set flag // Free the original buffer and set flag
Deallocator deallocator = (IntPtr values, IntPtr len, ref bool closure) => Deallocator deallocator = (IntPtr values, IntPtr len, ref bool closure) =>
{ {


+ 2
- 2
src/TensorFlowNET.Core/Train/Saving/Saver.cs View File

@@ -162,12 +162,12 @@ namespace Tensorflow


if (!_is_empty) if (!_is_empty)
{ {
var model_checkpoint_path1 = sess.run(_saver_def.SaveTensorName, new FeedItem[] {
model_checkpoint_path = sess.run(_saver_def.SaveTensorName, new FeedItem[] {
new FeedItem(_saver_def.FilenameTensorName, checkpoint_file) new FeedItem(_saver_def.FilenameTensorName, checkpoint_file)
}); });
} }


throw new NotImplementedException("");
throw new NotImplementedException("Saver.save");


return model_checkpoint_path; return model_checkpoint_path;
} }


Loading…
Cancel
Save