Browse Source

Shape as_int_list

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
0b6e855439
9 changed files with 30 additions and 12 deletions
  1. +2
    -2
      src/TensorFlowNET.Core/Data/MnistModelLoader.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Gradients/gradients_util.cs
  3. +5
    -0
      src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Creation.cs
  4. +6
    -0
      src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs
  5. +3
    -0
      src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs
  6. +0
    -3
      src/TensorFlowNET.Core/Numpy/Numpy.cs
  7. +5
    -0
      src/TensorFlowNET.Core/Numpy/Shape.cs
  8. +4
    -2
      src/TensorFlowNET.Core/Operations/array_ops.cs
  9. +4
    -4
      src/TensorFlowNET.Core/Tensors/tensor_util.cs

+ 2
- 2
src/TensorFlowNET.Core/Data/MnistModelLoader.cs View File

@@ -123,7 +123,7 @@ namespace Tensorflow


bytestream.Read(buf, 0, buf.Length); bytestream.Read(buf, 0, buf.Length);


var data = np.frombuffer(buf, np.@byte.as_system_dtype());
var data = np.frombuffer(buf, np.@byte);
data = data.reshape((num_images, rows, cols, 1)); data = data.reshape((num_images, rows, cols, 1));


return data; return data;
@@ -148,7 +148,7 @@ namespace Tensorflow


bytestream.Read(buf, 0, buf.Length); bytestream.Read(buf, 0, buf.Length);


var labels = np.frombuffer(buf, np.uint8.as_system_dtype());
var labels = np.frombuffer(buf, np.uint8);


if (one_hot) if (one_hot)
return DenseToOneHot(labels, num_classes); return DenseToOneHot(labels, num_classes);


+ 1
- 1
src/TensorFlowNET.Core/Gradients/gradients_util.cs View File

@@ -275,7 +275,7 @@ namespace Tensorflow
if (y.dtype.is_complex()) if (y.dtype.is_complex())
throw new TypeAccessException($"Gradients of complex tensors must set grad_ys (y.dtype = {y.dtype})"); throw new TypeAccessException($"Gradients of complex tensors must set grad_ys (y.dtype = {y.dtype})");
var shape = array_ops.shape(y); var shape = array_ops.shape(y);
var constant = constant_op.constant(y.dtype == TF_DataType.TF_DOUBLE ? (object)1.0 : (object)1.0f, name: $"grad_ys_{i}");
var constant = constant_op.constant(1, y.dtype, name: $"grad_ys_{i}");
var fill = gen_array_ops.fill(shape, constant); var fill = gen_array_ops.fill(shape, constant);
new_grad_ys.append(fill); new_grad_ys.append(fill);
continue; continue;


+ 5
- 0
src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Creation.cs View File

@@ -33,6 +33,11 @@ namespace Tensorflow.NumPy
return new NDArray(tensor); return new NDArray(tensor);
} }


public NDArray frombuffer(byte[] bytes, TF_DataType dtype)
{
throw new NotImplementedException("");
}

public NDArray linspace<T>(T start, T stop, int num = 50, bool endpoint = true, bool retstep = false, public NDArray linspace<T>(T start, T stop, int num = 50, bool endpoint = true, bool retstep = false,
TF_DataType dtype = TF_DataType.TF_DOUBLE, int axis = 0) TF_DataType dtype = TF_DataType.TF_DOUBLE, int axis = 0)
{ {


+ 6
- 0
src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs View File

@@ -18,6 +18,7 @@ namespace Tensorflow.NumPy
public NDArray(Array value, Shape? shape = null) => Init(value, shape); public NDArray(Array value, Shape? shape = null) => Init(value, shape);
public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) => Init(shape, dtype: dtype); public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) => Init(shape, dtype: dtype);
public NDArray(Tensor value, Shape? shape = null) => Init(value, shape); public NDArray(Tensor value, Shape? shape = null) => Init(value, shape);
public NDArray(byte[] bytes, TF_DataType dtype) => Init(bytes, dtype);


public static NDArray Scalar<T>(T value) where T : unmanaged public static NDArray Scalar<T>(T value) where T : unmanaged
=> value switch => value switch
@@ -68,5 +69,10 @@ namespace Tensorflow.NumPy
_tensor = new Tensor(value.TensorDataPointer, shape ?? value.shape, value.dtype); _tensor = new Tensor(value.TensorDataPointer, shape ?? value.shape, value.dtype);
_tensor.SetReferencedByNDArray(); _tensor.SetReferencedByNDArray();
} }

void Init(byte[] bytes, TF_DataType dtype)
{

}
} }
} }

+ 3
- 0
src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs View File

@@ -33,6 +33,9 @@ namespace Tensorflow.NumPy
public static NDArray full<T>(Shape shape, T fill_value) public static NDArray full<T>(Shape shape, T fill_value)
=> new NDArray(tf.fill(tf.constant(shape), fill_value)); => new NDArray(tf.fill(tf.constant(shape), fill_value));


public static NDArray frombuffer(byte[] bytes, TF_DataType dtype)
=> tf.numpy.frombuffer(bytes, dtype);

public static NDArray linspace<T>(T start, T stop, int num = 50, bool endpoint = true, bool retstep = false, public static NDArray linspace<T>(T start, T stop, int num = 50, bool endpoint = true, bool retstep = false,
TF_DataType dtype = TF_DataType.TF_DOUBLE, int axis = 0) where T : unmanaged TF_DataType dtype = TF_DataType.TF_DOUBLE, int axis = 0) where T : unmanaged
=> tf.numpy.linspace(start, stop, num: num, endpoint: endpoint, retstep: retstep, dtype: dtype, axis: axis); => tf.numpy.linspace(start, stop, num: num, endpoint: endpoint, retstep: retstep, dtype: dtype, axis: axis);


+ 0
- 3
src/TensorFlowNET.Core/Numpy/Numpy.cs View File

@@ -56,9 +56,6 @@ namespace Tensorflow.NumPy
public static NDArray concatenate(NDArray[] arrays, int axis = 0) public static NDArray concatenate(NDArray[] arrays, int axis = 0)
=> throw new NotImplementedException(""); => throw new NotImplementedException("");


public static NDArray frombuffer(byte[] bytes, Type dtype)
=> throw new NotImplementedException("");

public static NDArray frombuffer(byte[] bytes, string dtype) public static NDArray frombuffer(byte[] bytes, string dtype)
=> throw new NotImplementedException(""); => throw new NotImplementedException("");




+ 5
- 0
src/TensorFlowNET.Core/Numpy/Shape.cs View File

@@ -214,6 +214,11 @@ namespace Tensorflow
return new Shape(new_dims.ToArray()); return new Shape(new_dims.ToArray());
} }


public int[] as_int_list()
{
return _dims.Select(x => (int)x).ToArray();
}

public void assert_has_rank(int rank) public void assert_has_rank(int rank)
{ {
if (rank != ndim) if (rank != ndim)


+ 4
- 2
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -591,8 +591,10 @@ namespace Tensorflow
var input_shape = input.shape; var input_shape = input.shape;
if (optimize && input.ndim > -1 && input_shape.IsFullyDefined) if (optimize && input.ndim > -1 && input_shape.IsFullyDefined)
{ {
var nd = np.array(input.shape.dims).astype(out_type.as_system_dtype());
return constant_op.constant(nd, name: name);
if(out_type == TF_DataType.TF_INT32)
return constant_op.constant(input.shape.as_int_list(), name: name);
else
return constant_op.constant(input.shape.dims, name: name);
} }
} }




+ 4
- 4
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -60,11 +60,11 @@ namespace Tensorflow


public static NDArray MakeNdarray(TensorProto tensor) public static NDArray MakeNdarray(TensorProto tensor)
{ {
var shape = tensor.TensorShape.Dim.Select(x => (int)x.Size).ToArray();
int num_elements = np.prod(shape);
var tensor_dtype = tensor.Dtype.as_numpy_dtype();
var shape = new Shape(tensor.TensorShape.Dim.Select(x => x.Size).ToArray());
var num_elements = shape.size;
var tensor_dtype = tensor.Dtype.as_tf_dtype();


if (shape.Length > 0 && tensor.TensorContent.Length > 0)
if (shape.ndim > 0 && tensor.TensorContent.Length > 0)
{ {
return np.frombuffer(tensor.TensorContent.ToByteArray(), tensor_dtype).reshape(shape); return np.frombuffer(tensor.TensorContent.ToByteArray(), tensor_dtype).reshape(shape);
} }


Loading…
Cancel
Save