diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Flatten.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Flatten.cs new file mode 100644 index 00000000..5e729a14 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Flatten.cs @@ -0,0 +1,15 @@ +using NumSharp; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public partial class Tensor + { + public object[] Flatten() + { + return new Tensor[] { this }; + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 99a373c4..9f505419 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -39,7 +39,12 @@ namespace Tensorflow /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. /// [SuppressMessage("ReSharper", "ConvertToAutoProperty")] - public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike, ITensorOrTensorArray, IPackable + public partial class Tensor : DisposableObject, + ITensorOrOperation, + _TensorLike, + ITensorOrTensorArray, + IPackable, + ICanBeFlattened { private readonly int _id; private readonly Operation _op; diff --git a/src/TensorFlowNET.Core/Tensors/TensorArray.cs b/src/TensorFlowNET.Core/Tensors/TensorArray.cs index fe9e2d6d..369b9dc0 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorArray.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorArray.cs @@ -61,5 +61,11 @@ namespace Tensorflow public Tensor read(Tensor index, string name = null) => _implementation.read(index, name: name); + + public TensorArray write(Tensor index, Tensor value, string name = null) + => _implementation.write(index, value, name: name); + + public Tensor stack(string name = null) + => _implementation.stack(name: name); } } diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index 2a2a9bfa..59de20ac 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -33,6 +33,9 @@ namespace Tensorflow public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32? public static TF_DataType float16 = TF_DataType.TF_HALF; public static TF_DataType float64 = TF_DataType.TF_DOUBLE; + public static TF_DataType complex = TF_DataType.TF_COMPLEX; + public static TF_DataType complex64 = TF_DataType.TF_COMPLEX64; + public static TF_DataType complex128 = TF_DataType.TF_COMPLEX128; public static TF_DataType variant = TF_DataType.TF_VARIANT; public static TF_DataType resource = TF_DataType.TF_RESOURCE;