From 1bf44ece63f1bce58477856e2bfa2eb0f61f768d Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Wed, 7 Aug 2019 09:03:08 -0500 Subject: [PATCH] Logistic Regression can run but has accuracy and performance. --- src/TensorFlowHub/MnistDataSet.cs | 2 +- .../Sessions/_FetchMapper.cs | 21 +++++++++++++++---- src/TensorFlowNET.Core/Status/Status.cs | 2 +- src/TensorFlowNET.Core/Tensors/tensor_util.cs | 2 +- 4 files changed, 20 insertions(+), 7 deletions(-) diff --git a/src/TensorFlowHub/MnistDataSet.cs b/src/TensorFlowHub/MnistDataSet.cs index accc57e1..8ad0e687 100644 --- a/src/TensorFlowHub/MnistDataSet.cs +++ b/src/TensorFlowHub/MnistDataSet.cs @@ -24,7 +24,7 @@ namespace Tensorflow.Hub images = np.multiply(images, 1.0f / 255.0f); Data = images; - labels.astype(dataType); + labels = labels.astype(dataType); Labels = labels; } diff --git a/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs index 69c5c204..eb465998 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs @@ -15,7 +15,9 @@ ******************************************************************************/ using NumSharp; +using System; using System.Collections.Generic; +using System.Linq; namespace Tensorflow { @@ -37,10 +39,21 @@ namespace Tensorflow public virtual NDArray build_results(List values) { - var type = values[0].GetType(); - var nd = new NDArray(type, values.Count); - nd.ReplaceData(values.ToArray()); - return nd; + // if they're all scalar value + bool isAllScalars = values.Count(x => x.ndim == 0) == values.Count; + if (isAllScalars) + { + var type = values[0].dtype; + switch(Type.GetTypeCode(type)) + { + case TypeCode.Single: + return np.array(values.Select(x => x.GetSingle(0)).ToArray()); + default: + throw new NotImplementedException("build_results"); + } + } + + return np.stack(values.ToArray()); } public virtual List unique_fetches() diff --git a/src/TensorFlowNET.Core/Status/Status.cs b/src/TensorFlowNET.Core/Status/Status.cs index 7eb2d7e3..cc4fdff1 100644 --- a/src/TensorFlowNET.Core/Status/Status.cs +++ b/src/TensorFlowNET.Core/Status/Status.cs @@ -55,7 +55,7 @@ namespace Tensorflow Console.WriteLine(Message); if (throwException) { - throw new Exception(Message); + // throw new Exception(Message); } } } diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 2487e9d9..f7089a8e 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -57,7 +57,7 @@ namespace Tensorflow public static NDArray MakeNdarray(TensorProto tensor) { var shape = tensor.TensorShape.Dim.Select(x => (int)x.Size).ToArray(); - int num_elements = shape.Length == 0 ? NDArray.Scalar(1) : np.prod(shape); + int num_elements = np.prod(shape); var tensor_dtype = tensor.Dtype.as_numpy_dtype(); if (tensor.TensorContent.Length > 0)