| @@ -213,7 +213,7 @@ namespace Tensorflow.Eager | |||||
| if (add_type_attr && !string.IsNullOrEmpty(input_arg.TypeAttr)) | if (add_type_attr && !string.IsNullOrEmpty(input_arg.TypeAttr)) | ||||
| { | { | ||||
| var dtype = c_api.TFE_TensorHandleDataType(tensor.EagerTensorHandle); | |||||
| var dtype = tensor.dtype; | |||||
| c_api.TFE_OpSetAttrType(op, input_arg.TypeAttr, dtype); | c_api.TFE_OpSetAttrType(op, input_arg.TypeAttr, dtype); | ||||
| flattened_attrs.Add(input_arg.TypeAttr); | flattened_attrs.Add(input_arg.TypeAttr); | ||||
| flattened_attrs.Add(dtype); | flattened_attrs.Add(dtype); | ||||
| @@ -7,16 +7,10 @@ namespace Tensorflow.Eager | |||||
| { | { | ||||
| public partial class EagerTensor | public partial class EagerTensor | ||||
| { | { | ||||
| public EagerTensor(SafeTensorHandle handle) | |||||
| { | |||||
| NewEagerTensorHandle(handle); | |||||
| } | |||||
| public EagerTensor(SafeEagerTensorHandle handle) | public EagerTensor(SafeEagerTensorHandle handle) | ||||
| { | { | ||||
| _id = ops.uid(); | _id = ops.uid(); | ||||
| _eagerTensorHandle = handle; | _eagerTensorHandle = handle; | ||||
| Resolve(); | |||||
| } | } | ||||
| #region scalar eager tensor | #region scalar eager tensor | ||||
| @@ -67,8 +61,10 @@ namespace Tensorflow.Eager | |||||
| tf.Status.Check(true); | tf.Status.Check(true); | ||||
| } | } | ||||
| private void Resolve() | |||||
| public void Resolve() | |||||
| { | { | ||||
| if (_handle != null) | |||||
| return; | |||||
| _handle = c_api.TFE_TensorHandleResolve(_eagerTensorHandle, tf.Status.Handle); | _handle = c_api.TFE_TensorHandleResolve(_eagerTensorHandle, tf.Status.Handle); | ||||
| tf.Status.Check(true); | tf.Status.Check(true); | ||||
| } | } | ||||
| @@ -6,11 +6,47 @@ namespace Tensorflow.Eager | |||||
| { | { | ||||
| public partial class EagerTensor : Tensor | public partial class EagerTensor : Tensor | ||||
| { | { | ||||
| public override SafeTensorHandle Handle | |||||
| { | |||||
| get | |||||
| { | |||||
| Resolve(); | |||||
| return _handle; | |||||
| } | |||||
| } | |||||
| public override IntPtr buffer | |||||
| { | |||||
| get | |||||
| { | |||||
| Resolve(); | |||||
| return base.buffer; | |||||
| } | |||||
| } | |||||
| public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(_eagerTensorHandle, tf.Status.Handle)); | public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(_eagerTensorHandle, tf.Status.Handle)); | ||||
| public override TF_DataType dtype => c_api.TFE_TensorHandleDataType(_eagerTensorHandle); | public override TF_DataType dtype => c_api.TFE_TensorHandleDataType(_eagerTensorHandle); | ||||
| public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.Status.Handle); | public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.Status.Handle); | ||||
| public override ulong bytesize | |||||
| { | |||||
| get | |||||
| { | |||||
| Resolve(); | |||||
| return base.bytesize; | |||||
| } | |||||
| } | |||||
| public override IntPtr TensorDataPointer | |||||
| { | |||||
| get | |||||
| { | |||||
| Resolve(); | |||||
| return base.TensorDataPointer; | |||||
| } | |||||
| } | |||||
| protected override Shape GetShapeInternal() | protected override Shape GetShapeInternal() | ||||
| { | { | ||||
| var dims = new int[c_api.TFE_TensorHandleNumDims(_eagerTensorHandle, tf.Status.Handle)]; | var dims = new int[c_api.TFE_TensorHandleNumDims(_eagerTensorHandle, tf.Status.Handle)]; | ||||
| @@ -19,6 +55,12 @@ namespace Tensorflow.Eager | |||||
| return dims; | return dims; | ||||
| } | } | ||||
| protected override void SetShapeInternal(Shape value) | |||||
| { | |||||
| if (!shape.is_compatible_with(value)) | |||||
| throw new ValueError($"Tensor's shape is not compatible."); | |||||
| } | |||||
| public static int GetRank(IntPtr handle) | public static int GetRank(IntPtr handle) | ||||
| { | { | ||||
| var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); | var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); | ||||
| @@ -33,5 +75,11 @@ namespace Tensorflow.Eager | |||||
| dims[i] = c_api.TFE_TensorHandleDim(tfe_tensor_handle, i, tf.Status.Handle); | dims[i] = c_api.TFE_TensorHandleDim(tfe_tensor_handle, i, tf.Status.Handle); | ||||
| return dims; | return dims; | ||||
| } | } | ||||
| public override T[] ToArray<T>() | |||||
| { | |||||
| Resolve(); | |||||
| return base.ToArray<T>(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -38,8 +38,8 @@ namespace Tensorflow.NumPy | |||||
| tensor = tf.defaultSession.eval(tensor); | tensor = tf.defaultSession.eval(tensor); | ||||
| _handle = tensor.Handle; | _handle = tensor.Handle; | ||||
| } | } | ||||
| NewEagerTensorHandle(); | |||||
| NewEagerTensorHandle(); | |||||
| } | } | ||||
| public static NDArray Scalar<T>(T value) where T : unmanaged | public static NDArray Scalar<T>(T value) where T : unmanaged | ||||
| @@ -57,7 +57,9 @@ namespace Tensorflow.NumPy | |||||
| void NewEagerTensorHandle() | void NewEagerTensorHandle() | ||||
| { | { | ||||
| if (_handle is not null) | if (_handle is not null) | ||||
| _eagerTensorHandle = new EagerTensor(_handle).EagerTensorHandle; | |||||
| { | |||||
| _eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -28,7 +28,7 @@ namespace Tensorflow | |||||
| [SuppressMessage("ReSharper", "InvokeAsExtensionMethod")] | [SuppressMessage("ReSharper", "InvokeAsExtensionMethod")] | ||||
| public partial class Tensor | public partial class Tensor | ||||
| { | { | ||||
| public IntPtr TensorDataPointer => _handle == null ? IntPtr.Zero : TF_TensorData(_handle); | |||||
| public virtual IntPtr TensorDataPointer => _handle == null ? IntPtr.Zero : TF_TensorData(_handle); | |||||
| public Tensor() | public Tensor() | ||||
| { | { | ||||
| @@ -5,124 +5,88 @@ namespace Tensorflow | |||||
| { | { | ||||
| public partial class Tensor | public partial class Tensor | ||||
| { | { | ||||
| public static explicit operator bool(Tensor tensor) | |||||
| public unsafe static explicit operator bool(Tensor tensor) | |||||
| { | { | ||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_BOOL); | |||||
| return *(bool*)tensor.buffer; | |||||
| } | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_BOOL); | |||||
| return *(bool*)tensor.buffer; | |||||
| } | } | ||||
| public static explicit operator sbyte(Tensor tensor) | |||||
| public unsafe static explicit operator sbyte(Tensor tensor) | |||||
| { | { | ||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_INT8); | |||||
| return *(sbyte*)tensor.buffer; | |||||
| } | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_INT8); | |||||
| return *(sbyte*)tensor.buffer; | |||||
| } | } | ||||
| public static explicit operator byte(Tensor tensor) | |||||
| public unsafe static explicit operator byte(Tensor tensor) | |||||
| { | { | ||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_UINT8); | |||||
| return *(byte*)tensor.buffer; | |||||
| } | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_UINT8); | |||||
| return *(byte*)tensor.buffer; | |||||
| } | } | ||||
| public static explicit operator ushort(Tensor tensor) | |||||
| public unsafe static explicit operator ushort(Tensor tensor) | |||||
| { | { | ||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_UINT16); | |||||
| return *(ushort*)tensor.buffer; | |||||
| } | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_UINT16); | |||||
| return *(ushort*)tensor.buffer; | |||||
| } | } | ||||
| public static explicit operator short(Tensor tensor) | |||||
| public unsafe static explicit operator short(Tensor tensor) | |||||
| { | { | ||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_INT16); | |||||
| return *(short*)tensor.buffer; | |||||
| } | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_INT16); | |||||
| return *(short*)tensor.buffer; | |||||
| } | } | ||||
| public static explicit operator int(Tensor tensor) | |||||
| public unsafe static explicit operator int(Tensor tensor) | |||||
| { | { | ||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_INT32); | |||||
| return *(int*)tensor.buffer; | |||||
| } | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_INT32); | |||||
| return *(int*)tensor.buffer; | |||||
| } | } | ||||
| public static explicit operator uint(Tensor tensor) | |||||
| public unsafe static explicit operator uint(Tensor tensor) | |||||
| { | { | ||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_UINT32); | |||||
| return *(uint*)tensor.buffer; | |||||
| } | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_UINT32); | |||||
| return *(uint*)tensor.buffer; | |||||
| } | } | ||||
| public static explicit operator long(Tensor tensor) | |||||
| public unsafe static explicit operator long(Tensor tensor) | |||||
| { | { | ||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_INT64); | |||||
| return *(long*)tensor.buffer; | |||||
| } | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_INT64); | |||||
| return *(long*)tensor.buffer; | |||||
| } | } | ||||
| public static explicit operator ulong(Tensor tensor) | |||||
| public unsafe static explicit operator ulong(Tensor tensor) | |||||
| { | { | ||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_UINT64); | |||||
| return *(ulong*)tensor.buffer; | |||||
| } | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_UINT64); | |||||
| return *(ulong*)tensor.buffer; | |||||
| } | } | ||||
| public static explicit operator float(Tensor tensor) | |||||
| public unsafe static explicit operator float(Tensor tensor) | |||||
| { | { | ||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_FLOAT); | |||||
| return *(float*)tensor.buffer; | |||||
| } | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_FLOAT); | |||||
| return *(float*)tensor.buffer; | |||||
| } | } | ||||
| public static explicit operator double(Tensor tensor) | |||||
| public unsafe static explicit operator double(Tensor tensor) | |||||
| { | { | ||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_DOUBLE); | |||||
| return *(double*)tensor.buffer; | |||||
| } | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_DOUBLE); | |||||
| return *(double*)tensor.buffer; | |||||
| } | } | ||||
| public static explicit operator string(Tensor tensor) | |||||
| public unsafe static explicit operator string(Tensor tensor) | |||||
| { | { | ||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_STRING); | |||||
| return new string((char*)tensor.buffer, 0, (int)tensor.size); | |||||
| } | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_STRING); | |||||
| return new string((char*)tensor.buffer, 0, (int)tensor.size); | |||||
| } | } | ||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | [MethodImpl(MethodImplOptions.AggressiveInlining)] | ||||
| @@ -12,7 +12,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| /// <typeparam name="T"></typeparam> | /// <typeparam name="T"></typeparam> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public unsafe T[] ToArray<T>() where T : unmanaged | |||||
| public virtual unsafe T[] ToArray<T>() where T : unmanaged | |||||
| { | { | ||||
| //Are the types matching? | //Are the types matching? | ||||
| if (typeof(T).as_tf_dtype() != dtype) | if (typeof(T).as_tf_dtype() != dtype) | ||||
| @@ -68,10 +68,10 @@ namespace Tensorflow | |||||
| /// The DType of elements in this tensor. | /// The DType of elements in this tensor. | ||||
| /// </summary> | /// </summary> | ||||
| public virtual TF_DataType dtype => _handle == null ? _override_dtype : c_api.TF_TensorType(_handle); | public virtual TF_DataType dtype => _handle == null ? _override_dtype : c_api.TF_TensorType(_handle); | ||||
| public ulong bytesize => _handle == null ? 0 : c_api.TF_TensorByteSize(_handle); | |||||
| public virtual ulong bytesize => _handle == null ? 0 : c_api.TF_TensorByteSize(_handle); | |||||
| public ulong dtypesize => (ulong)dtype.get_datatype_size(); | public ulong dtypesize => (ulong)dtype.get_datatype_size(); | ||||
| public ulong size => _handle == null ? 0 : bytesize / dtypesize; | public ulong size => _handle == null ? 0 : bytesize / dtypesize; | ||||
| public IntPtr buffer => _handle == null ? IntPtr.Zero : c_api.TF_TensorData(_handle); | |||||
| public virtual IntPtr buffer => _handle == null ? IntPtr.Zero : c_api.TF_TensorData(_handle); | |||||
| public int num_consumers(TF_Output oper_out) => _handle == null ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); | public int num_consumers(TF_Output oper_out) => _handle == null ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); | ||||
| public int ndim => rank; | public int ndim => rank; | ||||
| @@ -86,7 +86,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public object Tag { get; set; } | public object Tag { get; set; } | ||||
| protected new SafeTensorHandle _handle; | protected new SafeTensorHandle _handle; | ||||
| public SafeTensorHandle Handle => _handle; | |||||
| public virtual SafeTensorHandle Handle => _handle; | |||||
| protected SafeEagerTensorHandle _eagerTensorHandle; | protected SafeEagerTensorHandle _eagerTensorHandle; | ||||
| /// <summary> | /// <summary> | ||||
| @@ -114,18 +114,7 @@ namespace Tensorflow | |||||
| set | set | ||||
| { | { | ||||
| if (this is EagerTensor) | |||||
| { | |||||
| if(!shape.is_compatible_with(value)) | |||||
| throw new ValueError($"Tensor's shape is not compatible."); | |||||
| return; | |||||
| } | |||||
| if (value == null) | |||||
| c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.Status.Handle); | |||||
| else | |||||
| c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.dims, value.ndim, tf.Status.Handle); | |||||
| SetShapeInternal(value); | |||||
| tf.Status.Check(true); | tf.Status.Check(true); | ||||
| } | } | ||||
| } | } | ||||
| @@ -147,6 +136,14 @@ namespace Tensorflow | |||||
| return dims; | return dims; | ||||
| } | } | ||||
| protected virtual void SetShapeInternal(Shape value) | |||||
| { | |||||
| if (value == null) | |||||
| c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.Status.Handle); | |||||
| else | |||||
| c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.dims, value.ndim, tf.Status.Handle); | |||||
| } | |||||
| public int[] _shape_tuple() | public int[] _shape_tuple() | ||||
| { | { | ||||
| return rank < 0 ? null : shape.dims.Select(x => (int)x).ToArray(); | return rank < 0 ? null : shape.dims.Select(x => (int)x).ToArray(); | ||||
| @@ -233,9 +233,9 @@ namespace Tensorflow | |||||
| return false; | return false; | ||||
| } | } | ||||
| if (tensor.GetType() == typeof(EagerTensor)) | |||||
| if (tensor is EagerTensor eagerTensor) | |||||
| { | { | ||||
| if(tensor.dtype == TF_DataType.TF_INT64) | |||||
| if(tensor.dtype == tf.int64) | |||||
| return new Shape(tensor.ToArray<long>()); | return new Shape(tensor.ToArray<long>()); | ||||
| else | else | ||||
| return new Shape(tensor.ToArray<int>()); | return new Shape(tensor.ToArray<int>()); | ||||