Browse Source

tf.zeros for dtype uint8

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
99fc01611e
2 changed files with 5 additions and 2 deletions
  1. +2
    -2
      src/TensorFlowNET.Core/Data/MnistModelLoader.cs
  2. +3
    -0
      src/TensorFlowNET.Core/Operations/array_ops.cs

+ 2
- 2
src/TensorFlowNET.Core/Data/MnistModelLoader.cs View File

@@ -123,7 +123,7 @@ namespace Tensorflow

bytestream.Read(buf, 0, buf.Length);

var data = np.frombuffer(buf, (num_images, rows * cols), np.@byte);
var data = np.frombuffer(buf, (num_images, rows * cols), np.uint8);
return data;
}
}
@@ -146,7 +146,7 @@ namespace Tensorflow

bytestream.Read(buf, 0, buf.Length);

var labels = np.frombuffer(buf, new Shape(num_items), np.@byte);
var labels = np.frombuffer(buf, new Shape(num_items), np.uint8);

if (one_hot)
return DenseToOneHot(labels, num_classes);


+ 3
- 0
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -91,6 +91,9 @@ namespace Tensorflow
zeros = constant(0f);
break;
case TF_DataType.TF_INT8:
zeros = constant((sbyte)0);
break;
case TF_DataType.TF_UINT8:
zeros = constant((byte)0);
break;
default:


Loading…
Cancel
Save