diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 3d7ad4f4..2a49d1a6 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -114,22 +114,36 @@ namespace Tensorflow for (int i = 0; i < fetch_list.Length; i++) { var tensor = new Tensor(output_values[i]); - + Type type = tensor.dtype.as_numpy_datatype(); + var ndims = tensor.shape.Select(x => (int)x).ToArray(); + switch (tensor.dtype) { case TF_DataType.TF_STRING: - // wired, don't know why we have to start from offset 9. - var bytes = tensor.Data(); - result[i] = UTF8Encoding.Default.GetString(bytes, 9, bytes.Length - 9); + { + // wired, don't know why we have to start from offset 9. + var bytes = tensor.Data(); + var output = UTF8Encoding.Default.GetString(bytes, 9, bytes.Length - 9); + result[i] = tensor.NDims == 0 ? output : np.array(output).reshape(ndims); + } break; case TF_DataType.TF_FLOAT: - result[i] = *(float*)c_api.TF_TensorData(output_values[i]); + { + var output = *(float*)c_api.TF_TensorData(output_values[i]); + result[i] = tensor.NDims == 0 ? output : np.array(output).reshape(ndims); + } break; case TF_DataType.TF_INT16: - result[i] = *(short*)c_api.TF_TensorData(output_values[i]); + { + var output = *(short*)c_api.TF_TensorData(output_values[i]); + result[i] = tensor.NDims == 0 ? output : np.array(output).reshape(ndims); + } break; case TF_DataType.TF_INT32: - result[i] = *(int*)c_api.TF_TensorData(output_values[i]); + { + var output = *(int*)c_api.TF_TensorData(output_values[i]); + result[i] = tensor.NDims == 0 ? output : np.array(output).reshape(ndims); + } break; default: throw new NotImplementedException("can't get output"); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index e4907a81..4281e7f8 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -22,6 +22,8 @@ namespace Tensorflow public object value; public int value_index { get; } + private Status status = new Status(); + private TF_DataType _dtype = TF_DataType.DtInvalid; public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle); public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle); @@ -33,8 +35,17 @@ namespace Tensorflow get { var dims = new long[rank]; - for (int i = 0; i < rank; i++) - dims[i] = c_api.TF_Dim(_handle, i); + + if (_handle == IntPtr.Zero) + { + c_api.TF_GraphGetTensorShape(op.Graph, _as_tf_output(), dims, rank, status); + status.Check(); + } + else + { + for (int i = 0; i < rank; i++) + dims[i] = c_api.TF_Dim(_handle, i); + } return dims; } @@ -48,7 +59,22 @@ namespace Tensorflow /// 3 3-Tensor (cube of numbers) /// n n-Tensor (you get the idea) /// - public int rank => _handle == IntPtr.Zero ? 0 : c_api.TF_NumDims(_handle); + public int rank + { + get + { + if (_handle == IntPtr.Zero) + { + var output = _as_tf_output(); + return c_api.TF_GraphGetTensorNumDims(op.Graph, output, status); + } + else + { + return c_api.TF_NumDims(_handle); + } + } + } + public int NDims => rank; /// @@ -182,6 +208,7 @@ namespace Tensorflow public void Dispose() { c_api.TF_DeleteTensor(_handle); + status.Dispose(); } public static implicit operator IntPtr(Tensor tensor) diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index ff5eb5eb..e03f5732 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -6,6 +6,16 @@ namespace Tensorflow { public static class dtypes { + public static Type as_numpy_datatype(this TF_DataType type) + { + switch (type) + { + case TF_DataType.TF_INT32: + return typeof(int); + default: + throw new NotImplementedException("as_numpy_datatype failed"); + } + } public static TF_DataType as_dtype(Type type) { TF_DataType dtype = TF_DataType.DtInvalid; diff --git a/test/TensorFlowNET.Examples/BasicOperations.cs b/test/TensorFlowNET.Examples/BasicOperations.cs index b4bd92cf..2b782fc9 100644 --- a/test/TensorFlowNET.Examples/BasicOperations.cs +++ b/test/TensorFlowNET.Examples/BasicOperations.cs @@ -19,7 +19,7 @@ namespace TensorFlowNET.Examples // Basic constant operations // The value returned by the constructor represents the output // of the Constant op. - var a = tf.constant(2); + /*var a = tf.constant(2); var b = tf.constant(3); // Launch the default graph. @@ -50,7 +50,7 @@ namespace TensorFlowNET.Examples // Run every operation with variable input Console.WriteLine($"Addition with variables: {sess.run(add, feed_dict)}"); Console.WriteLine($"Multiplication with variables: {sess.run(mul, feed_dict)}"); - } + }*/ // ---------------- // More in details: @@ -61,7 +61,39 @@ namespace TensorFlowNET.Examples // // The value returned by the constructor represents the output // of the Constant op. + var nd1 = np.array(3, 3).reshape(1, 2); + var matrix1 = tf.constant(nd1); + + // Create another Constant that produces a 2x1 matrix. + var nd2 = np.array(2, 2).reshape(2, 1); + var matrix2 = tf.constant(nd2); + + // Create a Matmul op that takes 'matrix1' and 'matrix2' as inputs. + // The returned value, 'product', represents the result of the matrix + // multiplication. + var product = tf.matmul(matrix1, matrix2); + // To run the matmul op we call the session 'run()' method, passing 'product' + // which represents the output of the matmul op. This indicates to the call + // that we want to get the output of the matmul op back. + // + // All inputs needed by the op are run automatically by the session. They + // typically are run in parallel. + // + // The call 'run(product)' thus causes the execution of threes ops in the + // graph: the two constants and matmul. + // + // The output of the op is returned in 'result' as a numpy `ndarray` object. + using (sess = tf.Session()) + { + var result = sess.run(product); + Console.WriteLine(result); + if((result as NDArray).Data()[0] != 12) + { + throw new Exception("BasicOperations error"); + } + // ==> [[ 12.]] + } } } }