From 100d769284696825e1b86fb34a01aff7297f3596 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 7 Mar 2020 07:43:16 -0600 Subject: [PATCH] TFE_Py_FastPathExecute --- src/TensorFlowNET.Core/APIs/tf.math.cs | 3 + src/TensorFlowNET.Core/Eager/EagerTensor.cs | 4 + src/TensorFlowNET.Core/Eager/c_api.eager.cs | 62 ++++- .../Eager/pywrap_tfe_src.cs | 92 +++++++ .../Graphs/Graph.Operation.cs | 11 + .../Operations/gen_math_ops.cs | 14 +- .../TensorFlow.Binding.csproj | 4 + .../Tensors/Tensor.Value.cs | 234 ++++++++++++++++++ src/TensorFlowNET.Core/Tensors/Tensor.cs | 213 ---------------- src/TensorFlowNET.Core/Tensors/TensorShape.cs | 2 +- src/TensorFlowNET.Core/Tensors/constant_op.cs | 2 + src/TensorFlowNET.Core/Tensors/dtypes.cs | 2 + src/TensorFlowNET.Core/Tensors/tf.constant.cs | 34 +-- 13 files changed, 423 insertions(+), 254 deletions(-) create mode 100644 src/TensorFlowNET.Core/Tensors/Tensor.Value.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index 6e093cb6..958f4e27 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -41,6 +41,9 @@ namespace Tensorflow public Tensor asin(Tensor x, string name = null) => gen_math_ops.asin(x, name); + public Tensor add(Tensor a, Tensor b, string name = null) + => gen_math_ops.add(a, b, name: name); + public Tensor add(Tx a, Ty b, string name = null) => gen_math_ops.add(a, b, name: name); diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.cs index 636e9520..a659e0b6 100644 --- a/src/TensorFlowNET.Core/Eager/EagerTensor.cs +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.cs @@ -14,6 +14,10 @@ namespace Tensorflow.Eager { } + public EagerTensor(int value, string device_name) : base(value) + { + } + public override string ToString() { switch (rank) diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs index 43f2e41d..15a872e0 100644 --- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs +++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs @@ -10,14 +10,14 @@ namespace Tensorflow /// /// TFE_ContextOptions* [DllImport(TensorFlowLibName)] - public static extern IntPtr TFE_NewContextOptions(); + internal static extern IntPtr TFE_NewContextOptions(); /// /// Destroy an options object. /// /// TFE_ContextOptions* [DllImport(TensorFlowLibName)] - public static extern void TFE_DeleteContextOptions(IntPtr options); + internal static extern void TFE_DeleteContextOptions(IntPtr options); /// /// @@ -26,14 +26,14 @@ namespace Tensorflow /// TF_Status* /// TFE_Context* [DllImport(TensorFlowLibName)] - public static extern IntPtr TFE_NewContext(IntPtr opts, IntPtr status); + internal static extern IntPtr TFE_NewContext(IntPtr opts, IntPtr status); /// /// /// /// TFE_Context* [DllImport(TensorFlowLibName)] - public static extern void TFE_DeleteContext(IntPtr ctx); + internal static extern void TFE_DeleteContext(IntPtr ctx); /// /// Execute the operation defined by 'op' and return handles to computed @@ -44,7 +44,7 @@ namespace Tensorflow /// int* /// TF_Status* [DllImport(TensorFlowLibName)] - public static extern void TFE_Execute(IntPtr op, IntPtr retvals, int[] num_retvals, IntPtr status); + internal static extern void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, IntPtr status); /// /// @@ -54,14 +54,14 @@ namespace Tensorflow /// TF_Status* /// [DllImport(TensorFlowLibName)] - public static extern IntPtr TFE_NewOp(IntPtr ctx, string op_or_function_name, IntPtr status); + internal static extern IntPtr TFE_NewOp(IntPtr ctx, string op_or_function_name, IntPtr status); /// /// /// /// TFE_Op* [DllImport(TensorFlowLibName)] - public static extern void TFE_DeleteOp(IntPtr op); + internal static extern void TFE_DeleteOp(IntPtr op); /// /// @@ -70,7 +70,10 @@ namespace Tensorflow /// const char* /// TF_DataType [DllImport(TensorFlowLibName)] - public static extern void TFE_OpSetAttrType(IntPtr op, string attr_name, TF_DataType value); + internal static extern void TFE_OpSetAttrType(IntPtr op, string attr_name, TF_DataType value); + + [DllImport(TensorFlowLibName)] + internal static extern void TFE_OpSetAttrInt(IntPtr op, string attr_name, long value); /// /// @@ -81,7 +84,7 @@ namespace Tensorflow /// const int /// TF_Status* [DllImport(TensorFlowLibName)] - public static extern void TFE_OpSetAttrShape(IntPtr op, string attr_name, long[] dims, int num_dims, Status out_status); + internal static extern void TFE_OpSetAttrShape(IntPtr op, string attr_name, long[] dims, int num_dims, Status out_status); /// /// @@ -91,7 +94,16 @@ namespace Tensorflow /// const void* /// size_t [DllImport(TensorFlowLibName)] - public static extern void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length); + internal static extern void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length); + + /// + /// + /// + /// + /// + /// + [DllImport(TensorFlowLibName)] + internal static extern void TFE_OpSetDevice(IntPtr op, string device_name, IntPtr status); /// /// @@ -100,7 +112,7 @@ namespace Tensorflow /// TFE_TensorHandle* /// TF_Status* [DllImport(TensorFlowLibName)] - public static extern void TFE_OpAddInput(IntPtr op, IntPtr h, IntPtr status); + internal static extern void TFE_OpAddInput(IntPtr op, IntPtr h, IntPtr status); /// /// @@ -108,6 +120,32 @@ namespace Tensorflow /// const tensorflow::Tensor& /// TFE_TensorHandle* [DllImport(TensorFlowLibName)] - public static extern IntPtr TFE_NewTensorHandle(IntPtr t); + internal static extern IntPtr TFE_NewTensorHandle(IntPtr t, IntPtr status); + + /// + /// + /// + /// + /// + /// + [DllImport(TensorFlowLibName)] + internal static extern IntPtr TFE_DeleteTensorHandle(IntPtr t, IntPtr status); + + /// + /// + /// + /// TFE_TensorHandle* + /// + [DllImport(TensorFlowLibName)] + internal static extern TF_DataType TFE_TensorHandleDataType(IntPtr h); + + /// + /// This function will block till the operation that produces `h` has completed. + /// + /// TFE_TensorHandle* + /// TF_Status* + /// + [DllImport(TensorFlowLibName)] + internal static extern int TFE_TensorHandleNumDims(IntPtr h, IntPtr status); } } diff --git a/src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs b/src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs index caf9dc57..e1ea4744 100644 --- a/src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs +++ b/src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs @@ -1,5 +1,7 @@ using System.Collections.Generic; using System.Linq; +using System; +using static Tensorflow.OpDef.Types; namespace Tensorflow.Eager { @@ -8,6 +10,96 @@ namespace Tensorflow.Eager /// public class pywrap_tfe_src { + public static EagerTensor TFE_Py_FastPathExecute(Context ctx, + string device_name, + string opName, + string name, + params Tensor[] inputs) + { + IntPtr op = IntPtr.Zero; + var attr_list_sizes = new Dictionary(); + using (var status = new Status()) + { + op = c_api.TFE_NewOp(ctx, opName, status); + + var op_def = Graph.TFE_GetOpDef(opName); + + // SetOpAttrWithDefaults + c_api.TFE_OpSetDevice(op, "", status); + + for (int i = 0; i < op_def.InputArg.Count; i++) + { + var input_arg = op_def.InputArg[i]; + if (!string.IsNullOrEmpty(input_arg.NumberAttr)) + { + c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, 0); + attr_list_sizes[input_arg.NumberAttr] = 0; + } + else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) + { + + } + else + { + // The item is a single item. + AddInputToOp(inputs[i], true, input_arg, op, status); + } + } + + int num_retvals = 0; + for (int i = 0; i < op_def.OutputArg.Count; i++) + { + var output_arg = op_def.OutputArg[i]; + var delta = 1; + if (!string.IsNullOrEmpty(output_arg.NumberAttr)) + delta = attr_list_sizes[output_arg.NumberAttr]; + else if (!string.IsNullOrEmpty(output_arg.TypeListAttr)) + delta = attr_list_sizes[output_arg.TypeListAttr]; + if(delta < 0) + throw new RuntimeError("Attributes suggest that the size of an output list is less than 0"); + num_retvals += delta; + } + + var retVals = new IntPtr[num_retvals]; + c_api.TFE_Execute(op, retVals, ref num_retvals, status); + + var h = c_api.TFE_NewTensorHandle(retVals[0], status); + var data = new Tensor(h); + status.Check(true); + } + + throw new NotImplementedException(""); + } + + /// + /// Adds input and type attr to the op, and to the list of flattened + /// inputs/attrs. + /// + /// + /// + /// + /// + /// + /// + private static bool AddInputToOp(Tensor input, + bool add_type_attr, + ArgDef input_arg, + IntPtr op, + Status status) + { + var input_handle = c_api.TFE_NewTensorHandle(input, status); + + if(add_type_attr && !string.IsNullOrEmpty(input_arg.TypeAttr)) + { + var dtype = c_api.TFE_TensorHandleDataType(input_handle); + c_api.TFE_OpSetAttrType(op, input_arg.TypeAttr, dtype); + } + + c_api.TFE_OpAddInput(op, input_handle, status); + status.Check(true); + return true; + } + public static void RecordGradient(string op_name, Tensor[] inputs, Dictionary attrs, Tensor[] results, string name = null) { var input_ids = inputs.Select(x => x.Id).ToArray(); diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs index a826d2f6..cd86a7b3 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs @@ -35,6 +35,17 @@ namespace Tensorflow } } + public static OpDef TFE_GetOpDef(string type) + { + IntPtr handle = tf.get_default_graph(); + using (var buffer = new Buffer()) + using (var status = new Status()) + { + c_api.TF_GraphGetOpDef(handle, type, buffer, status); + return OpDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); + } + } + public OperationDescription NewOperation(string opType, string opName) { return c_api.TF_NewOperation(_handle, opType, opName); diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 02bc1ada..cb5b0eb6 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using Tensorflow.Eager; using static Tensorflow.Binding; namespace Tensorflow @@ -139,9 +140,20 @@ namespace Tensorflow return _op.outputs[0]; } + public static EagerTensor add(Tensor x, Tensor y, string name = null) + { + // _op_def_lib._apply_op_helper("Add", name, args: new { x, y }); + + if (tf.context.executing_eagerly()) + { + var _result = pywrap_tfe_src.TFE_Py_FastPathExecute(tf.context, "", "Add", name, new[] { x, y }); + } + + return null; + } + public static Tensor add(Tx x, Ty y, string name = null) { - // forward_compatible(2019, 6, 25): var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y }); return _op.output; diff --git a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj index 1518bccb..0fcb21a4 100644 --- a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj +++ b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj @@ -71,4 +71,8 @@ https://tensorflownet.readthedocs.io + + + + diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs new file mode 100644 index 00000000..40d10c4f --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs @@ -0,0 +1,234 @@ +using NumSharp; +using NumSharp.Backends.Unmanaged; +using NumSharp.Utilities; +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Text; + +namespace Tensorflow +{ + public partial class Tensor + { + [Obsolete("Please use ToArray() instead.", false)] + public T[] Data() where T : unmanaged + { + return ToArray(); + } + + /// + /// + /// + /// + /// + public T[] ToArray() where T : unmanaged + { + //Are the types matching? + if (typeof(T).as_dtype() == dtype) + { + if (NDims == 0 && size == 1) //is it a scalar? + { + unsafe + { + return new T[] { *(T*)buffer }; + } + } + + //types match, no need to perform cast + var ret = new T[size]; + unsafe + { + var len = (long)size; + fixed (T* dst = ret) + { + //T can only be unmanaged, I believe it is safe to say that MemoryCopy is valid for all cases this method can be called. + var src = (T*)buffer; + len *= ((long)itemsize); + System.Buffer.MemoryCopy(src, dst, len, len); + } + } + + return ret; + } + else + { + + //types do not match, need to perform cast + if (NDims == 0 && size == 1) //is it a scalar? + { + unsafe + { +#if _REGEN + #region Compute + switch (dtype.as_numpy_dtype().GetTypeCode()) + { + %foreach supported_dtypes,supported_dtypes_lowercase% + case NPTypeCode.#1: return new T[] {Converts.ChangeType(*(#2*) buffer)}; + % + case NPTypeCode.String: return new T[] {Converts.ChangeType((string)this)}; + default: + throw new NotSupportedException(); + } + #endregion +#else + #region Compute + switch (dtype.as_numpy_dtype().GetTypeCode()) + { + case NPTypeCode.Boolean: return new T[] { Converts.ChangeType(*(bool*)buffer) }; + case NPTypeCode.Byte: return new T[] { Converts.ChangeType(*(byte*)buffer) }; + case NPTypeCode.Int16: return new T[] { Converts.ChangeType(*(short*)buffer) }; + case NPTypeCode.UInt16: return new T[] { Converts.ChangeType(*(ushort*)buffer) }; + case NPTypeCode.Int32: return new T[] { Converts.ChangeType(*(int*)buffer) }; + case NPTypeCode.UInt32: return new T[] { Converts.ChangeType(*(uint*)buffer) }; + case NPTypeCode.Int64: return new T[] { Converts.ChangeType(*(long*)buffer) }; + case NPTypeCode.UInt64: return new T[] { Converts.ChangeType(*(ulong*)buffer) }; + case NPTypeCode.Char: return new T[] { Converts.ChangeType(*(char*)buffer) }; + case NPTypeCode.Double: return new T[] { Converts.ChangeType(*(double*)buffer) }; + case NPTypeCode.Single: return new T[] { Converts.ChangeType(*(float*)buffer) }; + case NPTypeCode.String: return new T[] { Converts.ChangeType((string)this) }; + default: + throw new NotSupportedException(); + } + #endregion +#endif + } + } + + var ret = new T[size]; + unsafe + { + var len = (long)size; + fixed (T* dstRet = ret) + { + T* dst = dstRet; //local stack copy + +#if _REGEN + #region Compute + switch (dtype.as_numpy_dtype().GetTypeCode()) + { + %foreach supported_dtypes,supported_dtypes_lowercase% + case NPTypeCode.#1: new UnmanagedMemoryBlock<#2>((#2*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + % + default: + throw new NotSupportedException(); + } + #endregion +#else + #region Compute + switch (dtype.as_numpy_dtype().GetTypeCode()) + { + case NPTypeCode.Boolean: new UnmanagedMemoryBlock((bool*)buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Byte: new UnmanagedMemoryBlock((byte*)buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Int16: new UnmanagedMemoryBlock((short*)buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.UInt16: new UnmanagedMemoryBlock((ushort*)buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Int32: new UnmanagedMemoryBlock((int*)buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.UInt32: new UnmanagedMemoryBlock((uint*)buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Int64: new UnmanagedMemoryBlock((long*)buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.UInt64: new UnmanagedMemoryBlock((ulong*)buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Char: new UnmanagedMemoryBlock((char*)buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Double: new UnmanagedMemoryBlock((double*)buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Single: new UnmanagedMemoryBlock((float*)buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.String: throw new NotSupportedException("Unable to convert from string to other dtypes"); //TODO! this should call Converts.To + default: + throw new NotSupportedException(); + } + #endregion +#endif + + } + } + + return ret; + } + } + + /// + /// Copy of the contents of this Tensor into a NumPy array or scalar. + /// + /// + /// A NumPy array of the same shape and dtype or a NumPy scalar, if this + /// Tensor has rank 0. + /// + public NDArray numpy() + { + if(NDims == 0) + { + return GetScalar(dtype); + } + else + { + throw new NotImplementedException("numpy not implemented when ndim > 0"); + } + } + + private unsafe NDArray GetScalar(TF_DataType dtype) + { + switch(dtype) + { + case TF_DataType.TF_STRING: + return StringData()[0]; + case TF_DataType.TF_INT32: + return *(int*)buffer; + default: + return BufferToArray(); + } + } + + /// + /// Copies the memory of current buffer onto newly allocated array. + /// + /// + public unsafe byte[] BufferToArray() + { + // ReSharper disable once LocalVariableHidesMember + var bytesize = (long)this.bytesize; + var data = new byte[bytesize]; + fixed (byte* dst = data) + System.Buffer.MemoryCopy(buffer.ToPointer(), dst, bytesize, bytesize); + + return data; + } + + /// + /// Extracts string array from current Tensor. + /// + /// When != TF_DataType.TF_STRING + public unsafe string[] StringData() + { + if (dtype != TF_DataType.TF_STRING) + throw new InvalidOperationException($"Unable to call StringData when dtype != TF_DataType.TF_STRING (dtype is {dtype})"); + + // + // TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes. + // [offset1, offset2,...,offsetn, s1size, s1bytes, s2size, s2bytes,...,snsize,snbytes] + // + long size = 1; + foreach (var s in TensorShape.dims) + size *= s; + + var buffer = new byte[size][]; + var src = c_api.TF_TensorData(_handle); + var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize); + src += (int)(size * 8); + using (var status = new Status()) + { + for (int i = 0; i < buffer.Length; i++) + { + IntPtr dst = IntPtr.Zero; + UIntPtr dstLen = UIntPtr.Zero; + var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, status); + status.Check(true); + buffer[i] = new byte[(int)dstLen]; + Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); + src += (int)read; + } + } + + var _str = new string[buffer.Length]; + for (int i = 0; i < _str.Length; i++) + _str[i] = Encoding.UTF8.GetString(buffer[i]); + + return _str; + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 3a681499..fbd9f7cd 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -20,13 +20,9 @@ using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.Linq; -using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Text; using System.Threading.Tasks; -using NumSharp.Backends; -using NumSharp.Backends.Unmanaged; -using NumSharp.Utilities; using Tensorflow.Framework; #if SERIALIZABLE using Newtonsoft.Json; @@ -249,215 +245,6 @@ namespace Tensorflow return _tf_output.Value; } - [Obsolete("Please use ToArray() instead.", false)] - public T[] Data() where T : unmanaged - { - return ToArray(); - } - - /// - /// - /// - /// - /// - public T[] ToArray() where T : unmanaged - { - //Are the types matching? - if (typeof(T).as_dtype() == dtype) - { - if (NDims == 0 && size == 1) //is it a scalar? - { - unsafe - { - return new T[] {*(T*) buffer}; - } - } - - //types match, no need to perform cast - var ret = new T[size]; - unsafe - { - var len = (long) size; - fixed (T* dst = ret) - { - //T can only be unmanaged, I believe it is safe to say that MemoryCopy is valid for all cases this method can be called. - var src = (T*) buffer; - len *= ((long) itemsize); - System.Buffer.MemoryCopy(src, dst, len, len); - } - } - - return ret; - } else - { - - //types do not match, need to perform cast - if (NDims == 0 && size == 1) //is it a scalar? - { - unsafe - { -#if _REGEN - #region Compute - switch (dtype.as_numpy_dtype().GetTypeCode()) - { - %foreach supported_dtypes,supported_dtypes_lowercase% - case NPTypeCode.#1: return new T[] {Converts.ChangeType(*(#2*) buffer)}; - % - case NPTypeCode.String: return new T[] {Converts.ChangeType((string)this)}; - default: - throw new NotSupportedException(); - } - #endregion -#else - #region Compute - switch (dtype.as_numpy_dtype().GetTypeCode()) - { - case NPTypeCode.Boolean: return new T[] {Converts.ChangeType(*(bool*) buffer)}; - case NPTypeCode.Byte: return new T[] {Converts.ChangeType(*(byte*) buffer)}; - case NPTypeCode.Int16: return new T[] {Converts.ChangeType(*(short*) buffer)}; - case NPTypeCode.UInt16: return new T[] {Converts.ChangeType(*(ushort*) buffer)}; - case NPTypeCode.Int32: return new T[] {Converts.ChangeType(*(int*) buffer)}; - case NPTypeCode.UInt32: return new T[] {Converts.ChangeType(*(uint*) buffer)}; - case NPTypeCode.Int64: return new T[] {Converts.ChangeType(*(long*) buffer)}; - case NPTypeCode.UInt64: return new T[] {Converts.ChangeType(*(ulong*) buffer)}; - case NPTypeCode.Char: return new T[] {Converts.ChangeType(*(char*) buffer)}; - case NPTypeCode.Double: return new T[] {Converts.ChangeType(*(double*) buffer)}; - case NPTypeCode.Single: return new T[] {Converts.ChangeType(*(float*) buffer)}; - case NPTypeCode.String: return new T[] {Converts.ChangeType((string)this)}; - default: - throw new NotSupportedException(); - } - #endregion -#endif - } - } - - var ret = new T[size]; - unsafe - { - var len = (long) size; - fixed (T* dstRet = ret) - { - T* dst = dstRet; //local stack copy - -#if _REGEN - #region Compute - switch (dtype.as_numpy_dtype().GetTypeCode()) - { - %foreach supported_dtypes,supported_dtypes_lowercase% - case NPTypeCode.#1: new UnmanagedMemoryBlock<#2>((#2*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; - % - default: - throw new NotSupportedException(); - } - #endregion -#else - #region Compute - switch (dtype.as_numpy_dtype().GetTypeCode()) - { - case NPTypeCode.Boolean: new UnmanagedMemoryBlock((bool*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; - case NPTypeCode.Byte: new UnmanagedMemoryBlock((byte*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; - case NPTypeCode.Int16: new UnmanagedMemoryBlock((short*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; - case NPTypeCode.UInt16: new UnmanagedMemoryBlock((ushort*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; - case NPTypeCode.Int32: new UnmanagedMemoryBlock((int*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; - case NPTypeCode.UInt32: new UnmanagedMemoryBlock((uint*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; - case NPTypeCode.Int64: new UnmanagedMemoryBlock((long*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; - case NPTypeCode.UInt64: new UnmanagedMemoryBlock((ulong*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; - case NPTypeCode.Char: new UnmanagedMemoryBlock((char*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; - case NPTypeCode.Double: new UnmanagedMemoryBlock((double*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; - case NPTypeCode.Single: new UnmanagedMemoryBlock((float*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; - case NPTypeCode.String: throw new NotSupportedException("Unable to convert from string to other dtypes"); //TODO! this should call Converts.To - default: - throw new NotSupportedException(); - } - #endregion -#endif - - } - } - - return ret; - } - } - - /// - /// Copy of the contents of this Tensor into a NumPy array or scalar. - /// - /// - /// A NumPy array of the same shape and dtype or a NumPy scalar, if this - /// Tensor has rank 0. - /// - public NDArray numpy() - { - switch (dtype) - { - case TF_DataType.TF_STRING: - return StringData()[0]; - default: - return BufferToArray(); - } - } - - /// - /// Copies the memory of current buffer onto newly allocated array. - /// - /// - public byte[] BufferToArray() - { - unsafe - { - // ReSharper disable once LocalVariableHidesMember - var bytesize = (long)this.bytesize; - var data = new byte[bytesize]; - fixed (byte* dst = data) - System.Buffer.MemoryCopy(buffer.ToPointer(), dst, bytesize, bytesize); - - return data; - } - } - - /// - /// Extracts string array from current Tensor. - /// - /// When != TF_DataType.TF_STRING - public unsafe string[] StringData() - { - if (dtype != TF_DataType.TF_STRING) - throw new InvalidOperationException($"Unable to call StringData when dtype != TF_DataType.TF_STRING (dtype is {dtype})"); - - // - // TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes. - // [offset1, offset2,...,offsetn, s1size, s1bytes, s2size, s2bytes,...,snsize,snbytes] - // - long size = 1; - foreach (var s in TensorShape.dims) - size *= s; - - var buffer = new byte[size][]; - var src = c_api.TF_TensorData(_handle); - var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize); - src += (int) (size * 8); - for (int i = 0; i < buffer.Length; i++) - { - using (var status = new Status()) - { - IntPtr dst = IntPtr.Zero; - UIntPtr dstLen = UIntPtr.Zero; - var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, status); - status.Check(true); - buffer[i] = new byte[(int) dstLen]; - Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); - src += (int) read; - } - } - - var _str = new string[buffer.Length]; - for (int i = 0; i < _str.Length; i++) - _str[i] = Encoding.UTF8.GetString(buffer[i]); - - return _str; - } - public Tensor MaybeMove() { var tensor = c_api.TF_TensorMaybeMove(_handle); diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs index 5b521ef4..45e76526 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs @@ -252,7 +252,7 @@ namespace Tensorflow public static implicit operator TensorShape(Shape shape) => new TensorShape((int[]) shape.Dimensions.Clone()); public static implicit operator Shape(TensorShape shape) => new Shape((int[]) shape.dims.Clone()); - public static implicit operator int[](TensorShape shape) => (int[])shape.dims.Clone(); //we clone to avoid any changes + public static implicit operator int[](TensorShape shape) => shape == null ? null : (int[])shape.dims.Clone(); //we clone to avoid any changes public static implicit operator TensorShape(int[] dims) => new TensorShape(dims); public static explicit operator int(TensorShape shape) => shape.size; diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs index 961b36ac..2635f1d4 100644 --- a/src/TensorFlowNET.Core/Tensors/constant_op.cs +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -86,6 +86,8 @@ namespace Tensorflow { case string str: return new EagerTensor(str, ctx.device_name); + case int int32: + return new EagerTensor(int32, ctx.device_name); default: throw new NotImplementedException($"convert_to_eager_tensor {value.GetType()}"); } diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index 65c8fb70..7d0f618d 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -201,6 +201,8 @@ namespace Tensorflow => type switch { TF_DataType.TF_STRING => "string", + TF_DataType.TF_INT32 => "int32", + TF_DataType.TF_FLOAT => "float32", _ => type.ToString() }; diff --git a/src/TensorFlowNET.Core/Tensors/tf.constant.cs b/src/TensorFlowNET.Core/Tensors/tf.constant.cs index 8ee1a531..8e30524b 100644 --- a/src/TensorFlowNET.Core/Tensors/tf.constant.cs +++ b/src/TensorFlowNET.Core/Tensors/tf.constant.cs @@ -31,33 +31,13 @@ namespace Tensorflow public Tensor constant(object value, TF_DataType dtype = TF_DataType.DtInvalid, TensorShape shape = null, - string name = "Const") - { - switch (value) - { - case string str: - return constant_op._constant_impl(str, - @string, - null, - name, - verify_shape: false, - allow_broadcast: true); - case float val: - return constant_op._constant_impl(value, - float32, - new int[] { (int)shape }, - name, - verify_shape: false, - allow_broadcast: true); - default: - return constant_op._constant_impl(value, - dtype, - shape, - name, - verify_shape: false, - allow_broadcast: true); - } - } + string name = "Const") + => constant_op._constant_impl(value, + dtype, + shape, + name, + verify_shape: false, + allow_broadcast: true); public Tensor zeros(TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) => array_ops.zeros(shape, dtype, name);