Browse Source

Logistic Regression can run but has accuracy and performance.

tags/v0.12
Oceania2018 6 years ago
parent
commit
1bf44ece63
4 changed files with 20 additions and 7 deletions
  1. +1
    -1
      src/TensorFlowHub/MnistDataSet.cs
  2. +17
    -4
      src/TensorFlowNET.Core/Sessions/_FetchMapper.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Status/Status.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Tensors/tensor_util.cs

+ 1
- 1
src/TensorFlowHub/MnistDataSet.cs View File

@@ -24,7 +24,7 @@ namespace Tensorflow.Hub
images = np.multiply(images, 1.0f / 255.0f); images = np.multiply(images, 1.0f / 255.0f);
Data = images; Data = images;


labels.astype(dataType);
labels = labels.astype(dataType);
Labels = labels; Labels = labels;
} }




+ 17
- 4
src/TensorFlowNET.Core/Sessions/_FetchMapper.cs View File

@@ -15,7 +15,9 @@
******************************************************************************/ ******************************************************************************/


using NumSharp; using NumSharp;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;


namespace Tensorflow namespace Tensorflow
{ {
@@ -37,10 +39,21 @@ namespace Tensorflow


public virtual NDArray build_results(List<NDArray> values) public virtual NDArray build_results(List<NDArray> values)
{ {
var type = values[0].GetType();
var nd = new NDArray(type, values.Count);
nd.ReplaceData(values.ToArray());
return nd;
// if they're all scalar value
bool isAllScalars = values.Count(x => x.ndim == 0) == values.Count;
if (isAllScalars)
{
var type = values[0].dtype;
switch(Type.GetTypeCode(type))
{
case TypeCode.Single:
return np.array(values.Select(x => x.GetSingle(0)).ToArray());
default:
throw new NotImplementedException("build_results");
}
}

return np.stack(values.ToArray());
} }


public virtual List<ITensorOrOperation> unique_fetches() public virtual List<ITensorOrOperation> unique_fetches()


+ 1
- 1
src/TensorFlowNET.Core/Status/Status.cs View File

@@ -55,7 +55,7 @@ namespace Tensorflow
Console.WriteLine(Message); Console.WriteLine(Message);
if (throwException) if (throwException)
{ {
throw new Exception(Message);
// throw new Exception(Message);
} }
} }
} }


+ 1
- 1
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -57,7 +57,7 @@ namespace Tensorflow
public static NDArray MakeNdarray(TensorProto tensor) public static NDArray MakeNdarray(TensorProto tensor)
{ {
var shape = tensor.TensorShape.Dim.Select(x => (int)x.Size).ToArray(); var shape = tensor.TensorShape.Dim.Select(x => (int)x.Size).ToArray();
int num_elements = shape.Length == 0 ? NDArray.Scalar(1) : np.prod(shape);
int num_elements = np.prod(shape);
var tensor_dtype = tensor.Dtype.as_numpy_dtype(); var tensor_dtype = tensor.Dtype.as_numpy_dtype();


if (tensor.TensorContent.Length > 0) if (tensor.TensorContent.Length > 0)


Loading…
Cancel
Save