| @@ -8,6 +8,7 @@ namespace Tensorflow.NumPy | |||||
| { | { | ||||
| public partial class NDArray | public partial class NDArray | ||||
| { | { | ||||
| protected NDArray() { } | |||||
| public NDArray(bool value) : base(value) => NewEagerTensorHandle(); | public NDArray(bool value) : base(value) => NewEagerTensorHandle(); | ||||
| public NDArray(byte value) : base(value) => NewEagerTensorHandle(); | public NDArray(byte value) : base(value) => NewEagerTensorHandle(); | ||||
| public NDArray(short value) : base(value) => NewEagerTensorHandle(); | public NDArray(short value) : base(value) => NewEagerTensorHandle(); | ||||
| @@ -57,6 +58,20 @@ namespace Tensorflow.NumPy | |||||
| _ => throw new NotImplementedException("") | _ => throw new NotImplementedException("") | ||||
| }; | }; | ||||
| /// <summary> | |||||
| /// Reuse the existing memory instead of copying it. | |||||
| /// </summary> | |||||
| /// <param name="data_ptr"></param> | |||||
| /// <param name="shape"></param> | |||||
| /// <param name="dtype"></param> | |||||
| /// <param name="deallocator"></param> | |||||
| protected void InitWithExistingMemory(IntPtr data_ptr, Shape shape, TF_DataType dtype, c_api.DeallocatorV2 deallocator) | |||||
| { | |||||
| _handle = c_api.TF_NewTensor(TF_DataType.TF_STRING, shape.dims, shape.ndim, data_ptr, (ulong)(shape.size * dtype.get_datatype_size()), deallocator, IntPtr.Zero); | |||||
| tensor_util.DangerousManuallySetTensorDType(_handle, dtype); | |||||
| NewEagerTensorHandle(); | |||||
| } | |||||
| void NewEagerTensorHandle() | void NewEagerTensorHandle() | ||||
| { | { | ||||
| if (_handle is not null) | if (_handle is not null) | ||||
| @@ -417,7 +417,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| TF_DataType.TF_DOUBLE => constant(1.0d), | TF_DataType.TF_DOUBLE => constant(1.0d), | ||||
| TF_DataType.TF_FLOAT => constant(1.0f), | TF_DataType.TF_FLOAT => constant(1.0f), | ||||
| _ => constant(1) | |||||
| _ => constant(1, dtype) | |||||
| }; | }; | ||||
| if (shape.ndim == 0) | if (shape.ndim == 0) | ||||
| @@ -71,7 +71,7 @@ namespace Tensorflow | |||||
| /// <param name="deallocator_arg"></param> | /// <param name="deallocator_arg"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern SafeTensorHandle TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, ulong len, Deallocator deallocator, IntPtr deallocator_arg); | |||||
| public static extern SafeTensorHandle TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, ulong len, DeallocatorV2 deallocator, IntPtr deallocator_arg); | |||||
| public static unsafe SafeTensorHandle TF_NewTensor(byte[] data, Shape shape, TF_DataType dtype) | public static unsafe SafeTensorHandle TF_NewTensor(byte[] data, Shape shape, TF_DataType dtype) | ||||
| { | { | ||||
| @@ -147,6 +147,15 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern TF_DataType TF_TensorType(SafeTensorHandle tensor); | public static extern TF_DataType TF_TensorType(SafeTensorHandle tensor); | ||||
| /// <summary> | |||||
| /// Set a new shape for the Tensor. Note that this API only works after tf2.11. | |||||
| /// </summary> | |||||
| /// <param name="tensor"></param> | |||||
| /// <param name="dims"></param> | |||||
| /// <param name="num_dims"></param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TF_SetShape(SafeTensorHandle tensor, long[] dims, int num_dims); | |||||
| /// <summary> | /// <summary> | ||||
| /// Return the size in bytes required to encode a string `len` bytes long into a | /// Return the size in bytes required to encode a string `len` bytes long into a | ||||
| /// TF_STRING tensor. | /// TF_STRING tensor. | ||||
| @@ -22,6 +22,7 @@ using System.Text; | |||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using Tensorflow.Graphs; | using Tensorflow.Graphs; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using System.Diagnostics; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -649,5 +650,24 @@ would not be rank 1.", tensor.op.get_attr("axis"))); | |||||
| NewAxisMask = new_axis_mask | NewAxisMask = new_axis_mask | ||||
| }; | }; | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Warning: this method is an extremely dangerous method. It directly changes the dtype inside the tensor | |||||
| /// and security is not guaranteed at all. Currently this method is only used for some conditions to reuse | |||||
| /// the existing memory. Any other usage should be prevented. If you are sure you want to use it when | |||||
| /// developing tensorflow.net, please ask @Oceanic2018 or @AsakusaRinne first. | |||||
| /// </summary> | |||||
| /// <param name="handle"></param> | |||||
| /// <param name="dtype"></param> | |||||
| internal static unsafe void DangerousManuallySetTensorDType(SafeTensorHandle handle, TF_DataType dtype) | |||||
| { | |||||
| long tf_tensor_address = handle.DangerousGetHandle().ToInt64(); | |||||
| long interface_address = *(long*)(tf_tensor_address); | |||||
| long tensor_shape_address = interface_address + 8; | |||||
| long tensor_dtype_address = tensor_shape_address + 13; | |||||
| byte* dtype_pointer = (byte*)tensor_dtype_address; | |||||
| *dtype_pointer = (byte)dtype; | |||||
| Debug.Assert(c_api.TF_TensorType(handle) == dtype); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||