Browse Source

fix NDArray creation in graph mode.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
9c2d5c4897
12 changed files with 171 additions and 91 deletions
  1. +27
    -1
      src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
  2. +41
    -1
      src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs
  3. +0
    -57
      src/TensorFlowNET.Core/Numpy/NDArray.cs
  4. +25
    -0
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  5. +1
    -4
      src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
  6. +12
    -10
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  7. +1
    -1
      src/TensorFlowNET.Core/Tensors/TensorShape.Convert.cs
  8. +2
    -1
      src/TensorFlowNET.Core/Tensors/c_api.tensor.cs
  9. +1
    -0
      src/TensorFlowNET.Core/Tensors/constant_op.cs
  10. +52
    -0
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  11. +7
    -9
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  12. +2
    -7
      src/TensorFlowNET.Keras/Sequence.cs

+ 27
- 1
src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs View File

@@ -14,16 +14,42 @@ namespace Tensorflow.Eager
Resolve(); Resolve();
} }


#region scalar eager tensor
public EagerTensor(bool value) : base(value)
=> NewEagerTensorHandle(_handle);
public EagerTensor(byte value) : base(value)
=> NewEagerTensorHandle(_handle);
public EagerTensor(sbyte value) : base(value)
=> NewEagerTensorHandle(_handle);
public EagerTensor(short value) : base(value)
=> NewEagerTensorHandle(_handle);
public EagerTensor(int value) : base(value)
=> NewEagerTensorHandle(_handle);
public EagerTensor(uint value) : base(value)
=> NewEagerTensorHandle(_handle);
public EagerTensor(long value) : base(value)
=> NewEagerTensorHandle(_handle);
public EagerTensor(ulong value) : base(value)
=> NewEagerTensorHandle(_handle);
public EagerTensor(float value) : base(value)
=> NewEagerTensorHandle(_handle);
public EagerTensor(double value) : base(value)
=> NewEagerTensorHandle(_handle);
#endregion

public EagerTensor(object value,string device_name, TF_DataType dtype = TF_DataType.TF_UINT8) : base((float[])value) public EagerTensor(object value,string device_name, TF_DataType dtype = TF_DataType.TF_UINT8) : base((float[])value)
{ {
throw new NotImplementedException(""); throw new NotImplementedException("");
} }


public EagerTensor(object value, Shape shape = null, string device_name = null, TF_DataType dtype = TF_DataType.TF_UINT8) : base((float[])value)
public EagerTensor(object value, Shape? shape = null, string device_name = null, TF_DataType dtype = TF_DataType.TF_UINT8) : base((float[])value)
{ {
NewEagerTensorHandle(_handle); NewEagerTensorHandle(_handle);
} }


public EagerTensor(Shape shape, TF_DataType dtype) : base(shape, dtype)
=> NewEagerTensorHandle(_handle);

internal unsafe EagerTensor(string value) : base(value) internal unsafe EagerTensor(string value) : base(value)
=> NewEagerTensorHandle(_handle); => NewEagerTensorHandle(_handle);




+ 41
- 1
src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs View File

@@ -1,15 +1,55 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using Tensorflow.Eager;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow.NumPy namespace Tensorflow.NumPy
{ {
public partial class NDArray public partial class NDArray
{ {
public NDArray(bool value) => _tensor = new EagerTensor(value);
public NDArray(byte value) => _tensor = new EagerTensor(value);
public NDArray(short value) => _tensor = new EagerTensor(value);
public NDArray(int value) => _tensor = new EagerTensor(value);
public NDArray(long value) => _tensor = new EagerTensor(value);
public NDArray(float value) => _tensor = new EagerTensor(value);
public NDArray(double value) => _tensor = new EagerTensor(value);

public NDArray(Array value, Shape? shape = null) => _tensor = new EagerTensor(value, shape);

public NDArray(Shape shape, NumpyDType dtype = NumpyDType.Float)
{
Initialize(shape, dtype: dtype);
}

public NDArray(Tensor value, Shape? shape = null)
{
if (shape is not null)
_tensor = tf.reshape(value, shape);
else
_tensor = value;

if (_tensor.TensorDataPointer == IntPtr.Zero)
_tensor = tf.get_default_session().eval(_tensor);
}

public static NDArray Scalar<T>(T value) where T : unmanaged
{
return value switch
{
bool val => new NDArray(val),
int val => new NDArray(val),
float val => new NDArray(val),
double val => new NDArray(val),
_ => throw new NotImplementedException("")
};
}

void Initialize(Shape shape, NumpyDType dtype = NumpyDType.Float) void Initialize(Shape shape, NumpyDType dtype = NumpyDType.Float)
{ {
_tensor = tf.zeros(shape, dtype: dtype.as_tf_dtype());
// _tensor = tf.zeros(shape, dtype: dtype.as_tf_dtype());
_tensor = new EagerTensor(shape, dtype: dtype.as_tf_dtype());
} }
} }
} }

+ 0
- 57
src/TensorFlowNET.Core/Numpy/NDArray.cs View File

@@ -17,63 +17,6 @@ namespace Tensorflow.NumPy
public Shape shape => _tensor.shape; public Shape shape => _tensor.shape;
public IntPtr data => _tensor.TensorDataPointer; public IntPtr data => _tensor.TensorDataPointer;


public NDArray(bool value)
{
_tensor = ops.convert_to_tensor(value);
}

public NDArray(byte value)
{
_tensor = ops.convert_to_tensor(value);
}

public NDArray(int value)
{
_tensor = ops.convert_to_tensor(value);
}

public NDArray(float value)
{
_tensor = ops.convert_to_tensor(value);
}

public NDArray(double value)
{
_tensor = ops.convert_to_tensor(value);
}

public NDArray(Array value, Shape shape = null)
{
_tensor = ops.convert_to_tensor(value);
}

public NDArray(Type dtype, Shape shape)
{

}

public NDArray(Shape shape, NumpyDType dtype = NumpyDType.Float)
{
Initialize(shape, dtype: dtype);
}

public NDArray(Tensor value, Shape? shape = null)
{
if (shape is not null)
_tensor = tf.reshape(value, shape);
else
_tensor = value;
}

public static NDArray Scalar<T>(T value) where T : unmanaged
{
return value switch
{
bool b => new NDArray(b),
_ => throw new NotImplementedException("")
};
}

public T GetValue<T>(int index) where T : unmanaged public T GetValue<T>(int index) where T : unmanaged
=> _tensor.ToArray<T>()[index]; => _tensor.ToArray<T>()[index];
public T GetAtIndex<T>(int index) where T : unmanaged public T GetAtIndex<T>(int index) where T : unmanaged


+ 25
- 0
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -245,6 +245,31 @@ namespace Tensorflow
return result; return result;
} }


public unsafe Tensor eval(Tensor tensor)
{
var status = tf.Status;

var output_values = new IntPtr[1];
var fetch_list = new[] { tensor._as_tf_output() };

c_api.TF_SessionRun(_handle,
run_options: null,
inputs: new TF_Output[0],
input_values: new IntPtr[0],
ninputs: 0,
outputs: fetch_list,
output_values: output_values,
noutputs: 1,
target_opers: new IntPtr[0],
ntargets: 0,
run_metadata: IntPtr.Zero,
status: status.Handle);

status.Check(true);

return new Tensor(output_values[0]);
}

private static unsafe NDArray fetchValue(IntPtr output) private static unsafe NDArray fetchValue(IntPtr output)
{ {
var tensor = new Tensor(output); var tensor = new Tensor(output);


+ 1
- 4
src/TensorFlowNET.Core/Sessions/_FetchHandler.cs View File

@@ -78,10 +78,7 @@ namespace Tensorflow
{ {
var value = tensor_values[j]; var value = tensor_values[j];
j += 1; j += 1;
if (value.ndim == 0)
full_values.Add(value);
else
full_values.Add(value[np.arange(0, (int)value.dims[0])]);
full_values.Add(value);
} }
i += 1; i += 1;
} }


+ 12
- 10
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -64,6 +64,9 @@ namespace Tensorflow
#endif #endif
} }


unsafe internal Tensor(Shape shape, TF_DataType dtype)
=> _handle = TF_NewTensor(shape, dtype, null);

internal Tensor(Array array, Shape? shape = null) internal Tensor(Array array, Shape? shape = null)
=> InitTensor(array, shape); => InitTensor(array, shape);


@@ -71,41 +74,40 @@ namespace Tensorflow
{ {
shape = shape ?? array.GetShape(); shape = shape ?? array.GetShape();
var dtype = array.GetType().GetElementType().as_tf_dtype(); var dtype = array.GetType().GetElementType().as_tf_dtype();
var length = (ulong)(array.Length * dtype.get_datatype_size());


switch (array) switch (array)
{ {
case bool[] val: case bool[] val:
fixed (void* addr = &val[0]) fixed (void* addr = &val[0])
_handle = TF_NewTensor(shape, dtype, addr, length);
_handle = TF_NewTensor(shape, dtype, addr);
break; break;
case int[] val: case int[] val:
fixed (void* addr = &val[0]) fixed (void* addr = &val[0])
_handle = TF_NewTensor(shape, dtype, addr, length);
_handle = TF_NewTensor(shape, dtype, addr);
break; break;
case int[,] val: case int[,] val:
fixed (void* addr = &val[0, 0]) fixed (void* addr = &val[0, 0])
_handle = TF_NewTensor(shape, dtype, addr, length);
_handle = TF_NewTensor(shape, dtype, addr);
break; break;
case long[] val: case long[] val:
fixed (void* addr = &val[0]) fixed (void* addr = &val[0])
_handle = TF_NewTensor(shape, dtype, addr, length);
_handle = TF_NewTensor(shape, dtype, addr);
break; break;
case float[] val: case float[] val:
fixed (void* addr = &val[0]) fixed (void* addr = &val[0])
_handle = TF_NewTensor(shape, dtype, addr, length);
_handle = TF_NewTensor(shape, dtype, addr);
break; break;
case float[,] val: case float[,] val:
fixed (void* addr = &val[0, 0]) fixed (void* addr = &val[0, 0])
_handle = TF_NewTensor(shape, dtype, addr, length);
_handle = TF_NewTensor(shape, dtype, addr);
break; break;
case double[] val: case double[] val:
fixed (void* addr = &val[0]) fixed (void* addr = &val[0])
_handle = TF_NewTensor(shape, dtype, addr, length);
_handle = TF_NewTensor(shape, dtype, addr);
break; break;
case double[,] val: case double[,] val:
fixed (void* addr = &val[0, 0]) fixed (void* addr = &val[0, 0])
_handle = TF_NewTensor(shape, dtype, addr, length);
_handle = TF_NewTensor(shape, dtype, addr);
break; break;
default: default:
throw new NotImplementedException(""); throw new NotImplementedException("");
@@ -131,7 +133,7 @@ namespace Tensorflow
} }


public unsafe Tensor(NDArray nd) public unsafe Tensor(NDArray nd)
=> _handle = TF_NewTensor(nd.shape, nd.dtype.as_tf_dtype(), nd.data.ToPointer(), nd.size * nd.dtypesize);
=> _handle = TF_NewTensor(nd.shape, nd.dtype.as_tf_dtype(), nd.data.ToPointer());


#region scala #region scala
public Tensor(bool value) => _handle = TF_NewTensor(value); public Tensor(bool value) => _handle = TF_NewTensor(value);


+ 1
- 1
src/TensorFlowNET.Core/Tensors/TensorShape.Convert.cs View File

@@ -11,7 +11,7 @@ namespace Tensorflow
} }


public static implicit operator TensorShape(Shape shape) => new TensorShape((long[])shape.dims.Clone()); public static implicit operator TensorShape(Shape shape) => new TensorShape((long[])shape.dims.Clone());
public static implicit operator Shape(TensorShape shape) => new Shape((long[])shape.dims.Clone());
public static implicit operator Shape(TensorShape shape) => shape == null ? null : new Shape((long[])shape.dims.Clone());


public static implicit operator int[](TensorShape shape) => shape == null ? null : (int[])shape.dims.Clone(); //we clone to avoid any changes public static implicit operator int[](TensorShape shape) => shape == null ? null : (int[])shape.dims.Clone(); //we clone to avoid any changes
public static implicit operator TensorShape(int[] dims) => dims == null ? null : new TensorShape(dims); public static implicit operator TensorShape(int[] dims) => dims == null ? null : new TensorShape(dims);


+ 2
- 1
src/TensorFlowNET.Core/Tensors/c_api.tensor.cs View File

@@ -104,8 +104,9 @@ namespace Tensorflow
return c_api.TF_NewTensor(dataType, dims, num_dims, data, len, EmptyDeallocator, DeallocatorArgs.Empty); return c_api.TF_NewTensor(dataType, dims, num_dims, data, len, EmptyDeallocator, DeallocatorArgs.Empty);
} }


public static unsafe IntPtr TF_NewTensor(Shape shape, TF_DataType dtype, void* data, ulong length)
public static unsafe IntPtr TF_NewTensor(Shape shape, TF_DataType dtype, void* data)
{ {
var length = shape.size * (ulong)dtype.get_datatype_size();
var handle = TF_AllocateTensor(dtype, shape.dims, shape.ndim, length); var handle = TF_AllocateTensor(dtype, shape.dims, shape.ndim, length);
var tensor = TF_TensorData(handle); var tensor = TF_TensorData(handle);
System.Buffer.MemoryCopy(data, tensor.ToPointer(), length, length); System.Buffer.MemoryCopy(data, tensor.ToPointer(), length, length);


+ 1
- 0
src/TensorFlowNET.Core/Tensors/constant_op.cs View File

@@ -98,6 +98,7 @@ namespace Tensorflow
attrs: attrs, attrs: attrs,
name: name); name: name);


var o = op.outputs;
return op.outputs[0]; return op.outputs[0];
} }




+ 52
- 0
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -182,6 +182,58 @@ namespace Tensorflow
return dtype.Value; return dtype.Value;
} }


public static TF_DataType tf_dtype_from_name(string name)
{
TF_DataType dtype = TF_DataType.DtInvalid;
switch (name.ToLower())
{
case "char":
dtype = TF_DataType.TF_UINT8;
break;
case "boolean":
dtype = TF_DataType.TF_BOOL;
break;
case "sbyte":
dtype = TF_DataType.TF_INT8;
break;
case "byte":
dtype = TF_DataType.TF_UINT8;
break;
case "int16":
dtype = TF_DataType.TF_INT16;
break;
case "uint16":
dtype = TF_DataType.TF_UINT16;
break;
case "int32":
dtype = TF_DataType.TF_INT32;
break;
case "uint32":
dtype = TF_DataType.TF_UINT32;
break;
case "int64":
dtype = TF_DataType.TF_INT64;
break;
case "uint64":
dtype = TF_DataType.TF_UINT64;
break;
case "single":
dtype = TF_DataType.TF_FLOAT;
break;
case "double":
dtype = TF_DataType.TF_DOUBLE;
break;
case "complex":
dtype = TF_DataType.TF_COMPLEX128;
break;
case "string":
dtype = TF_DataType.TF_STRING;
break;
}

return dtype;
}

public static DataType as_datatype_enum(this TF_DataType type) public static DataType as_datatype_enum(this TF_DataType type)
{ {
return (DataType)type; return (DataType)type;


+ 7
- 9
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -101,7 +101,7 @@ namespace Tensorflow
/// <param name="verify_shape"></param> /// <param name="verify_shape"></param>
/// <param name="allow_broadcast"></param> /// <param name="allow_broadcast"></param>
/// <returns></returns> /// <returns></returns>
public static TensorProto make_tensor_proto(object values, TF_DataType dtype = TF_DataType.DtInvalid, int[]? shape = null, bool verify_shape = false, bool allow_broadcast = false)
public static TensorProto make_tensor_proto(object values, TF_DataType dtype = TF_DataType.DtInvalid, Shape? shape = null, bool verify_shape = false, bool allow_broadcast = false)
{ {
if (allow_broadcast && verify_shape) if (allow_broadcast && verify_shape)
throw new ValueError("allow_broadcast and verify_shape are not both allowed."); throw new ValueError("allow_broadcast and verify_shape are not both allowed.");
@@ -109,10 +109,11 @@ namespace Tensorflow
return tp; return tp;


dtype = values.GetType().as_tf_dtype(); dtype = values.GetType().as_tf_dtype();
shape = shape ?? values.GetShape();
var tensor_proto = new TensorProto var tensor_proto = new TensorProto
{ {
Dtype = dtype.as_datatype_enum(), Dtype = dtype.as_datatype_enum(),
TensorShape = values.GetShape().as_shape_proto()
TensorShape = shape.as_shape_proto()
}; };


// scalar // scalar
@@ -141,8 +142,6 @@ namespace Tensorflow
default: default:
throw new Exception("make_tensor_proto Not Implemented"); throw new Exception("make_tensor_proto Not Implemented");
} }

return tensor_proto;
} }
else if (dtype == TF_DataType.TF_STRING && !(values is NDArray)) else if (dtype == TF_DataType.TF_STRING && !(values is NDArray))
{ {
@@ -154,15 +153,14 @@ namespace Tensorflow
tensor_proto.StringVal.AddRange(str_values.Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x))); tensor_proto.StringVal.AddRange(str_values.Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x)));
else if (values is byte[] byte_values) else if (values is byte[] byte_values)
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(byte_values); tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(byte_values);

return tensor_proto;
} }
else if(values is Array array) else if(values is Array array)
{ {
// array // array
/*byte[] bytes = array.ToByteArray();
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes.ToArray());
return tensor_proto;*/
var len = dtype.get_datatype_size() * (int)shape.size;
byte[] bytes = new byte[len];
System.Buffer.BlockCopy(array, 0, bytes, 0, len);
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes);
} }


return tensor_proto; return tensor_proto;


+ 2
- 7
src/TensorFlowNET.Keras/Sequence.cs View File

@@ -54,8 +54,8 @@ namespace Tensorflow.Keras
if (value == null) if (value == null)
value = 0f; value = 0f;


var type = getNPType(dtype);
var nd = new NDArray(type, new Shape(length.Count(), maxlen.Value));
var type = dtypes.tf_dtype_from_name(dtype);
var nd = new NDArray((length.Count(), maxlen.Value), dtype: type.as_numpy_typecode());


for (int i = 0; i < nd.dims[0]; i++) for (int i = 0; i < nd.dims[0]; i++)
{ {
@@ -71,10 +71,5 @@ namespace Tensorflow.Keras


return nd; return nd;
} }

private Type getNPType(string typeName)
{
return System.Type.GetType("NumSharp.np,NumSharp").GetField(typeName).GetValue(null) as Type;
}
} }
} }

Loading…
Cancel
Save