diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs
index 5c7641e0..91888e4b 100644
--- a/src/TensorFlowNET.Core/Binding.Util.cs
+++ b/src/TensorFlowNET.Core/Binding.Util.cs
@@ -155,7 +155,7 @@ namespace Tensorflow
switch (a)
{
case Tensor tensor:
- return tensor.shape[0];
+ return (int)tensor.shape[0];
case Tensors arr:
return arr.Length;
case Array arr:
diff --git a/src/TensorFlowNET.Core/Framework/tensor_shape.cs b/src/TensorFlowNET.Core/Framework/tensor_shape.cs
index 73cd7daf..c88fb876 100644
--- a/src/TensorFlowNET.Core/Framework/tensor_shape.cs
+++ b/src/TensorFlowNET.Core/Framework/tensor_shape.cs
@@ -10,7 +10,7 @@ namespace Tensorflow.Framework
{
public static void assert_is_compatible_with(this Tensor self, Tensor other)
{
- if (!self.is_compatible_with(other))
+ /*if (!self.is_compatible_with(other))
{
var selfDim = self.shape
.Aggregate(new StringBuilder("{"), (sb, i) => sb.Append(i).Append(", "), sb => sb.ToString())
@@ -21,7 +21,7 @@ namespace Tensorflow.Framework
.Replace(", }", "}");
throw new ArgumentException($"Dimensions {selfDim} and {otherDim} are not compatible");
- }
+ }*/
}
public static bool is_compatible_with(this Tensor self, Tensor other)
diff --git a/src/TensorFlowNET.Core/Gradients/image_grad.cs b/src/TensorFlowNET.Core/Gradients/image_grad.cs
index 08636298..fd7f098f 100644
--- a/src/TensorFlowNET.Core/Gradients/image_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/image_grad.cs
@@ -27,10 +27,10 @@ namespace Tensorflow.Gradients
{
var grad = grads[0];
var image = op.inputs[0];
- var shape = new TensorShape(image.shape.Skip(1).Take(2).ToArray());
+ var shape = new TensorShape(image.shape.dims.Skip(1).Take(2).ToArray());
Tensor image_shape = null;
if (shape.is_fully_defined())
- image_shape = constant_op.constant(image.shape.Skip(1).Take(2).ToArray());
+ image_shape = constant_op.constant(image.shape.dims.Skip(1).Take(2).ToArray());
else
image_shape = array_ops.shape(image)["1:3"];
diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs
index 34710f70..4eb1087e 100644
--- a/src/TensorFlowNET.Core/Gradients/math_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs
@@ -195,7 +195,7 @@ namespace Tensorflow.Gradients
if (op is EagerOperation op_eager &&
op_eager.SkipInputIndices.Contains(1) &&
- y.NDims == 0)
+ y.ndim == 0)
{
return new Tensor[]
{
@@ -759,7 +759,7 @@ namespace Tensorflow.Gradients
if (op is EagerOperation op_eager &&
op_eager.SkipInputIndices.Contains(1) &&
- y.NDims == 0)
+ y.ndim == 0)
{
x = math_ops.conj(x);
y = math_ops.conj(y);
diff --git a/src/TensorFlowNET.Core/NumPy/Axis.cs b/src/TensorFlowNET.Core/NumPy/Axis.cs
index 97629062..b170d90b 100644
--- a/src/TensorFlowNET.Core/NumPy/Axis.cs
+++ b/src/TensorFlowNET.Core/NumPy/Axis.cs
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
+using System.Linq;
using System.Text;
namespace Tensorflow
@@ -22,6 +23,12 @@ namespace Tensorflow
public static implicit operator Axis(int[] axis)
=> new Axis(axis);
+
+ public static implicit operator Axis(long[] shape)
+ => new Axis(shape.Select(x => (int)x).ToArray());
+
+ public static implicit operator Axis(Shape shape)
+ => new Axis(shape.dims.Select(x => (int)x).ToArray());
}
}
diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.cs b/src/TensorFlowNET.Core/Numpy/NDArray.cs
index 1cfc9b4e..719dba77 100644
--- a/src/TensorFlowNET.Core/Numpy/NDArray.cs
+++ b/src/TensorFlowNET.Core/Numpy/NDArray.cs
@@ -11,8 +11,9 @@ namespace Tensorflow.NumPy
Tensor _tensor;
public TF_DataType dtype => _tensor.dtype;
public ulong size => _tensor.size;
- public ulong dtypesize => _tensor.itemsize;
- public int ndim => _tensor.NDims;
+ public ulong dtypesize => _tensor.dtypesize;
+ public ulong bytesize => _tensor.bytesize;
+ public int ndim => _tensor.ndim;
public long[] dims => _tensor.dims.Select(x => Convert.ToInt64(x)).ToArray();
public Shape shape => _tensor.shape;
public IntPtr data => _tensor.TensorDataPointer;
diff --git a/src/TensorFlowNET.Core/Numpy/Shape.cs b/src/TensorFlowNET.Core/Numpy/Shape.cs
index 961955dd..c0b6048d 100644
--- a/src/TensorFlowNET.Core/Numpy/Shape.cs
+++ b/src/TensorFlowNET.Core/Numpy/Shape.cs
@@ -48,6 +48,12 @@ namespace Tensorflow
public static implicit operator Shape((long, long, long, long) dims)
=> new Shape(dims.Item1, dims.Item2, dims.Item3, dims.Item4);
+ public static implicit operator int[](Shape shape)
+ => shape.dims.Select(x => (int)x).ToArray();
+
+ public static implicit operator long[](Shape shape)
+ => shape.dims;
+
public bool IsEmpty => size == 0;
public bool IsScalar => ndim == 0;
@@ -55,6 +61,8 @@ namespace Tensorflow
public static Shape Scalar
=> new Shape(new long[0]);
+ public long this[int n] => dims[n];
+
///
/// Returns the size this shape represents.
///
@@ -81,6 +89,25 @@ namespace Tensorflow
}
}
+ public bool is_fully_defined()
+ {
+ return ndim > -1 && dims != null && dims.Count(x => x < 1) == 0;
+ }
+
+ public bool is_compatible_with(TensorShape shape2)
+ {
+ if (dims != null && shape2.dims != null)
+ {
+ if (dims.Contains(-1) || shape2.dims.Contains(-1))
+ return true;
+
+ if (size != (ulong)shape2.size)
+ return false;
+ }
+
+ return true;
+ }
+
public override bool Equals(object obj)
{
if(obj is Shape shape)
diff --git a/src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs b/src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs
index 3e185c49..a73bbcc0 100644
--- a/src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs
+++ b/src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs
@@ -92,7 +92,7 @@ namespace Tensorflow
public Tensor _batch_shape()
{
- return array_ops.broadcast_static_shape(new Tensor(_loc.shape), new Tensor(_scale.shape));
+ return array_ops.broadcast_static_shape(new Tensor(_loc.shape.dims), new Tensor(_scale.shape.dims));
}
protected override Tensor _log_prob(Tensor x)
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs
index cf5f1ce0..c76d768d 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs
@@ -27,9 +27,9 @@ namespace Tensorflow.Operations
{
var p = prefix;
var p_static = tensor_util.constant_value(prefix);
- if (p.NDims == 0)
+ if (p.ndim == 0)
p = array_ops.expand_dims(p, 0);
- else if (p.NDims != 1)
+ else if (p.ndim != 1)
throw new ValueError($"prefix tensor must be either a scalar or vector, but saw tensor: {p}");
var s_tensor_shape = new TensorShape(suffix);
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs
index 9e7290ed..13db8194 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.cs
@@ -186,7 +186,7 @@ namespace Tensorflow
private static Tensor _constant_if_small(int value, Tensor shape)
{
- return shape < 1000L;
+ return shape < 1000UL;
}
private static Tensor _constant_if_small(T value, TensorShape shape, TF_DataType dtype, string name)
@@ -330,7 +330,7 @@ namespace Tensorflow
{
name = scope;
var input_tensor = ops.convert_to_tensor(inputs);
- return constant_op.constant(input_tensor.NDims, dtype: tf.int32, name: name);
+ return constant_op.constant(input_tensor.ndim, dtype: tf.int32, name: name);
});
}
@@ -340,7 +340,7 @@ namespace Tensorflow
{
name = scope;
var input_tensor = ops.convert_to_tensor(input);
- var input_shape = tensor_util.to_shape(input_tensor.shape);
+ var input_shape = input_tensor.shape;
if (optimize && input_shape.ndim > 0)
return constant_op.constant(input_shape.ndim, dtype: tf.int32, name: name);
else
@@ -364,7 +364,7 @@ namespace Tensorflow
tensor = ops.convert_to_tensor(tensor, name: "tensor");
// is_fully_defined return unexpected value.
- if (optimize && tensor_util.to_shape(tensor.shape).is_fully_defined() && dtype != TF_DataType.TF_VARIANT)
+ if (optimize && tensor.shape.is_fully_defined() && dtype != TF_DataType.TF_VARIANT)
{
}
@@ -589,9 +589,9 @@ namespace Tensorflow
if (!tf.Context.executing_eagerly())
{
var input_shape = input.TensorShape;
- if (optimize && input.NDims > -1 && input_shape.is_fully_defined())
+ if (optimize && input.ndim > -1 && input_shape.is_fully_defined())
{
- var nd = np.array(input.shape).astype(out_type.as_system_dtype());
+ var nd = np.array(input.shape.dims).astype(out_type.as_system_dtype());
return constant_op.constant(nd, name: name);
}
}
@@ -607,7 +607,7 @@ namespace Tensorflow
name = scope;
var input_tensor = ops.convert_to_tensor(input);
- var input_shape = tensor_util.to_shape(input_tensor.shape);
+ var input_shape = input_tensor.shape;
if (optimize)
{
if (input_shape.is_fully_defined())
@@ -633,7 +633,7 @@ namespace Tensorflow
tensor = ops.convert_to_tensor(tensor, name: "tensor");
// is_fully_defined return unexpected value.
- if (optimize && tensor_util.to_shape(tensor.shape).is_fully_defined() && dtype != TF_DataType.TF_VARIANT)
+ if (optimize && tensor.shape.is_fully_defined() && dtype != TF_DataType.TF_VARIANT)
{
}
@@ -933,7 +933,7 @@ namespace Tensorflow
string name = "split")
{
if (num == -1)
- num = size_splits.shape[0];
+ num = (int)size_splits.shape[0];
return gen_array_ops.split_v(value, size_splits, axis, num, name: name);
}
diff --git a/src/TensorFlowNET.Core/Operations/functional_ops.cs b/src/TensorFlowNET.Core/Operations/functional_ops.cs
index 67450c74..003b93da 100644
--- a/src/TensorFlowNET.Core/Operations/functional_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/functional_ops.cs
@@ -91,7 +91,7 @@ namespace Tensorflow
elem.dtype,
size: tf.constant(n),
dynamic_size: false,
- element_shape: elem.shape.Skip(1).ToArray(),
+ element_shape: elem.shape.dims.Skip(1).ToArray(),
infer_shape: true)).ToList();
for (int index = 0; index < elems_ta.Count; index++)
diff --git a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs
index 849a93c8..917dbd6b 100644
--- a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs
+++ b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs
@@ -341,14 +341,14 @@ or rank = 4. Had rank = {0}", rank));
{
h = _get_dim(image, 0); // img_h == h[0], dynamic_h == h[1]
w = _get_dim(image, 1);
- d = image.shape[3];
+ d = (int)image.shape[3];
}
else
{
- bs = image.shape[0];
+ bs = (int)image.shape[0];
h = _get_dim(image, 1);
w = _get_dim(image, 2);
- d = image.shape[3];
+ d = (int)image.shape[3];
}
object hd, bbox_h_start;
@@ -1115,7 +1115,7 @@ new_height, new_width");
array_ops.expand_dims(tf.constant(3), 0));
var multiples = array_ops.concat(new Tensor[] { shape_list }, 0);
var rgb = array_ops.tile(images, multiples, name: name);
- int[] rgb_temp = images.shape.Take(images.shape.Length - 1).ToArray();
+ int[] rgb_temp = images.shape.dims.Take(images.shape.ndim - 1).Select(x => (int)x).ToArray();
rgb.set_shape(array_ops.concat(new Tensor[] { ops.convert_to_tensor(rgb_temp) }, 3));
return rgb;
});
@@ -1459,7 +1459,7 @@ new_height, new_width");
// shape takes an int, python code passes size, a Tensor. NDims is the only int type
// i could think of a Tensor having. it might be incorrect tho, so keep that in mind.
- return array_ops.reshape(g, shape: new int[] { size.NDims, size.NDims, 1, 1 });
+ return array_ops.reshape(g, shape: new int[] { size.ndim, size.ndim, 1, 1 });
}
internal static (Tensor, Tensor) _ssim_per_channel(Tensor img1, Tensor img2, float max_val = 1f,
@@ -1487,7 +1487,7 @@ new_height, new_width");
img1 = array_ops.identity(img1);
var kernel = _fspecial_gauss(filter_size_tensor, filter_sigma_tensor);
- kernel = array_ops.tile(kernel, multiples: new Tensor(new int[] { 1, 1, shape1_tensor.dims[shape1_tensor.dims.Length - 2], 1 }));
+ kernel = array_ops.tile(kernel, multiples: new Tensor(new int[] { 1, 1, (int)shape1_tensor.dims[shape1_tensor.dims.Length - 2], 1 }));
float compensation = 1.0f;
@@ -1503,8 +1503,8 @@ new_height, new_width");
(Tensor luminance, Tensor cs) = _ssim_helper(img1, img2, reducer, max_val, compensation, k1, k2);
var axes = constant_op.constant(new[] { -3, -2 }, dtype: dtypes.int32);
- var ssim_val = math_ops.reduce_mean(luminance * cs, new(axes.dims));
- cs = math_ops.reduce_mean(cs, new(axes.dims));
+ var ssim_val = math_ops.reduce_mean(luminance * cs, axes.dims);
+ cs = math_ops.reduce_mean(cs, axes.dims);
return (ssim_val, cs);
}
@@ -1685,7 +1685,7 @@ new_height, new_width");
var kernels_tf = constant_op.constant(kernels, dtype: image.dtype);
kernels_tf = array_ops.tile(
- kernels_tf, new Tensor(new int[] { 1, 1, image_shape.dims[image_shape.dims.Length - 2], 1 }), name: "sobel_filters");
+ kernels_tf, new Tensor(new int[] { 1, 1, (int)image_shape.dims[image_shape.dims.Length - 2], 1 }), name: "sobel_filters");
var pad_sizes = new int[,] { { 0, 0 }, { 1, 1 }, { 1, 1 }, { 0, 0 } };
var padded = array_ops.pad(image, new Tensor(pad_sizes), mode: "reflect");
@@ -1966,8 +1966,8 @@ new_height, new_width");
Tensor index_offsets, indices, sorted_scores, sorted_boxes, sorted_scores_indices;
using (ops.name_scope("sort_scores_and_boxes"))
{
- batch_size = array_ops.shape(boxes).dims[0];
- num_boxes = array_ops.shape(boxes).dims[1];
+ batch_size = (int)array_ops.shape(boxes).dims[0];
+ num_boxes = (int)array_ops.shape(boxes).dims[1];
sorted_scores_indices = null; /*sort_ops.argsort(
scores, axis: 1, direction: "DESCENDING); */
index_offsets = math_ops.range(batch_size) * num_boxes;
diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs
index ef50f69f..6d69a55f 100644
--- a/src/TensorFlowNET.Core/Operations/nn_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs
@@ -178,7 +178,7 @@ namespace Tensorflow
logits = ops.convert_to_tensor(logits);
var shape = logits.shape;
- bool is_last_dim = dim == -1 || dim == shape.Length - 1;
+ bool is_last_dim = dim == -1 || dim == shape.ndim - 1;
if (is_last_dim)
return compute_op(logits, name);
diff --git a/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs b/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs
index 6a52397a..b1dbf586 100644
--- a/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs
+++ b/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs
@@ -37,7 +37,7 @@ namespace Tensorflow
{
get
{
- return _row_splits.shape[0] - 1;
+ return (int)_row_splits.shape[0] - 1;
}
}
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
index 1f839ee7..d4419073 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
@@ -145,6 +145,10 @@ namespace Tensorflow
byte[,] val => InitTensor(val, shape, dtype),
byte[,,] val => InitTensor(val, shape, dtype),
byte[,,,] val => InitTensor(val, shape, dtype),
+ short[] val => InitTensor(val, shape, dtype),
+ short[,] val => InitTensor(val, shape, dtype),
+ short[,,] val => InitTensor(val, shape, dtype),
+ short[,,,] val => InitTensor(val, shape, dtype),
int[] val => InitTensor(val, shape, dtype),
int[,] val => InitTensor(val, shape, dtype),
int[,,] val => InitTensor(val, shape, dtype),
@@ -153,6 +157,10 @@ namespace Tensorflow
long[,] val => InitTensor(val, shape, dtype),
long[,,] val => InitTensor(val, shape, dtype),
long[,,,] val => InitTensor(val, shape, dtype),
+ ulong[] val => InitTensor(val, shape, dtype),
+ ulong[,] val => InitTensor(val, shape, dtype),
+ ulong[,,] val => InitTensor(val, shape, dtype),
+ ulong[,,,] val => InitTensor(val, shape, dtype),
float[] val => InitTensor(val, shape, dtype),
float[,] val => InitTensor(val, shape, dtype),
float[,,] val => InitTensor(val, shape, dtype),
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs
index ed72d9aa..dd7b8ad6 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs
@@ -18,7 +18,7 @@ namespace Tensorflow
if (typeof(T).as_tf_dtype() != dtype)
throw new ArrayTypeMismatchException($"dtype {dtype} mismatch.");
- if (NDims == 0 && size == 1) //is it a scalar?
+ if (ndim == 0 && size == 1) //is it a scalar?
{
unsafe
{
@@ -28,7 +28,7 @@ namespace Tensorflow
//types match, no need to perform cast
var ret = new T[size];
- var len = (long)(size * itemsize);
+ var len = (long)(size * dtypesize);
var src = (T*)buffer;
fixed (T* dst = ret)
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs
index 5166cf81..bf8089de 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs
@@ -72,17 +72,17 @@ namespace Tensorflow
///
public TF_DataType dtype => _handle == IntPtr.Zero ? _override_dtype : c_api.TF_TensorType(_handle);
public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle);
- public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype);
- public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize;
+ public ulong dtypesize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype);
+ public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / dtypesize;
public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle);
public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out);
- public int NDims => rank;
+ public int ndim => rank;
///
/// The name of the device on which this tensor will be produced, or null.
///
public virtual string Device => op.Device;
- public int[] dims => shape;
+ public long[] dims => shape.dims;
///
/// Used for keep other pointer when do implicit operating
@@ -107,7 +107,7 @@ namespace Tensorflow
/// Returns the shape of a tensor.
///
/// https://www.tensorflow.org/api_docs/python/tf/shape
- public int[] shape
+ public Shape shape
{
get
{
@@ -123,7 +123,7 @@ namespace Tensorflow
dims[i] = c_api.TF_Dim(_handle, i);
}
- return dims.Select(x => ((IConvertible)x).ToInt32(CultureInfo.InvariantCulture)).ToArray();
+ return dims;
}
set
@@ -131,7 +131,7 @@ namespace Tensorflow
if (value == null)
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.Status.Handle);
else
- c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, tf.Status.Handle);
+ c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.dims, value.ndim, tf.Status.Handle);
tf.Status.Check(true);
}
@@ -139,10 +139,10 @@ namespace Tensorflow
public int[] _shape_tuple()
{
- return rank < 0 ? null : shape;
+ return rank < 0 ? null : shape.dims.Select(x => (int)x).ToArray();
}
- public TensorShape TensorShape => rank < 0 ? new TensorShape() : tensor_util.to_shape(shape);
+ public TensorShape TensorShape => rank < 0 ? new TensorShape() : shape;
///
/// Keras History: (Layer, (node_index, tensor_index))
diff --git a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs
index 66b5fd3b..5917439e 100644
--- a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs
@@ -109,7 +109,8 @@ namespace Tensorflow
var length = shape.size * (ulong)dtype.get_datatype_size();
var handle = TF_AllocateTensor(dtype, shape.dims, shape.ndim, length);
var tensor = TF_TensorData(handle);
- System.Buffer.MemoryCopy(data, tensor.ToPointer(), length, length);
+ if (tensor != IntPtr.Zero)
+ System.Buffer.MemoryCopy(data, tensor.ToPointer(), length, length);
return handle;
}
diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs
index cf6c76a2..b69c4477 100644
--- a/src/TensorFlowNET.Core/Tensors/constant_op.cs
+++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs
@@ -124,6 +124,8 @@ namespace Tensorflow
return new EagerTensor(new[] { val }, Shape.Scalar);
case long val:
return new EagerTensor(new[] { val }, Shape.Scalar);
+ case ulong val:
+ return new EagerTensor(new[] { val }, Shape.Scalar);
case float val:
return new EagerTensor(new[] { val }, Shape.Scalar);
case double val:
@@ -146,7 +148,7 @@ namespace Tensorflow
if (shape == null)
return t;
- if (t.shape.Select(x => Convert.ToInt64(x)).SequenceEqual(shape.dims))
+ if (t.shape.dims.SequenceEqual(shape.dims))
return t;
if (verify_shape)
diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
index 5ad8bc9b..5a007695 100644
--- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs
+++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
@@ -127,7 +127,7 @@ namespace Tensorflow
}
else if (values is Tensor tensor && tensor.IsReferencedByNDArray)
{
- var len = tensor.itemsize * tensor.size;
+ var len = tensor.dtypesize * tensor.size;
byte[] bytes = tensor.BufferToArray();
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes);
}
diff --git a/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs
index 960bf656..60de456f 100644
--- a/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs
+++ b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs
@@ -45,7 +45,7 @@ namespace Tensorflow
var restored_tensor = restored_tensors[0];
return gen_state_ops.assign(op,
restored_tensor,
- validate_shape: restored_shapes == null && tensor_util.to_shape(op.shape).is_fully_defined());
+ validate_shape: restored_shapes == null && op.shape.is_fully_defined());
}
}
}
diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs
index 3bf4f784..36fdfed2 100644
--- a/src/TensorFlowNET.Core/Variables/RefVariable.cs
+++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs
@@ -50,7 +50,7 @@ namespace Tensorflow
public Operation Op => _variable.op;
public TF_DataType dtype => _variable.dtype;
- public TensorShape shape => tensor_util.to_shape(_variable.shape);
+ public TensorShape shape => _variable.shape;
public string Device => "";
public string Name => _variable.name;
diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs
index bedca2c1..e1563055 100644
--- a/src/TensorFlowNET.Keras/BackendImpl.cs
+++ b/src/TensorFlowNET.Keras/BackendImpl.cs
@@ -297,8 +297,8 @@ namespace Tensorflow.Keras
// x = permute_dimensions(x, [0, 3, 1, 2]);
throw new NotImplementedException("");
- int new_height = original_shape[rows] < 0 ? -1 : original_shape[rows] * height_factor;
- int new_width = original_shape[cols] < 0 ? -1 : original_shape[cols] * width_factor;
+ int new_height = original_shape[rows] < 0 ? -1 : (int)original_shape[rows] * height_factor;
+ int new_width = original_shape[cols] < 0 ? -1 : (int)original_shape[cols] * width_factor;
TensorShape output_shape = data_format == "channels_first" ?
(-1, -1, new_height, new_width) : (-1, new_height, new_width, -1);
@@ -316,7 +316,7 @@ namespace Tensorflow.Keras
{
if(axis < 0)
{
- var rank = tensors[0].NDims;
+ var rank = tensors[0].ndim;
if (rank > -1)
axis += rank;
else
diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
index d73dc8b1..fc61aa71 100644
--- a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
+++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
@@ -20,7 +20,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
{
this.args = args;
_process_tensorlike();
- num_samples = args.X.shape[0];
+ num_samples = (int)args.X.shape[0];
var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize;
_batch_size = batch_size;
_size = Convert.ToInt32(Math.Ceiling(num_samples / (batch_size + 0.0f)));
diff --git a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs
index 6fed2bf3..037703c8 100644
--- a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs
+++ b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs
@@ -63,8 +63,8 @@ namespace Tensorflow.Keras.Engine
{
var y_t_rank = y_t.rank;
var y_p_rank = y_p.rank;
- var y_t_last_dim = y_t.shape[y_t.shape.Length - 1];
- var y_p_last_dim = y_p.shape[y_p.shape.Length - 1];
+ var y_t_last_dim = y_t.shape[y_t.shape.ndim - 1];
+ var y_p_last_dim = y_p.shape[y_p.shape.ndim - 1];
bool is_binary = y_p_last_dim == 1;
bool is_sparse_categorical = (y_t_rank < y_p_rank || y_t_last_dim == 1) && y_p_last_dim > 1;
diff --git a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
index 64723a22..592d2568 100644
--- a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
+++ b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
@@ -29,7 +29,7 @@ namespace Tensorflow.Keras.Metrics
var y_true_rank = y_true.TensorShape.ndim;
// If the shape of y_true is (num_samples, 1), squeeze to (num_samples,)
if (y_true_rank != -1 && y_pred_rank != -1
- && y_true.shape.Length == y_pred.shape.Length)
+ && y_true.shape.ndim == y_pred.shape.ndim)
y_true = array_ops.squeeze(y_true, axis: new[] { -1 });
y_pred = math_ops.argmax(y_pred, -1);
diff --git a/src/TensorFlowNET.Keras/tf.layers.cs b/src/TensorFlowNET.Keras/tf.layers.cs
index b69bbe95..3f5ed01c 100644
--- a/src/TensorFlowNET.Keras/tf.layers.cs
+++ b/src/TensorFlowNET.Keras/tf.layers.cs
@@ -212,13 +212,13 @@ namespace Tensorflow.Keras
string data_format = "channels_last")
{
var input_shape = inputs.shape;
- if (inputs.shape.Length == 0)
+ if (inputs.shape.ndim == 0)
throw new ValueError($"Input 0 of layer flatten is incompatible with the layer: : expected min_ndim={1}, found ndim={0}. Full shape received: ()");
var premutation = new List() { 0 };
- if (data_format == "channels_first" && inputs.NDims > 1)
+ if (data_format == "channels_first" && inputs.ndim > 1)
{
- premutation.AddRange(Binding.range(2, inputs.NDims));
+ premutation.AddRange(Binding.range(2, inputs.ndim));
premutation.Add(1);
inputs = array_ops.transpose(inputs, premutation.ToArray());
}
diff --git a/test/TensorFlowNET.Native.UnitTest/Sessions/SessionTest.cs b/test/TensorFlowNET.Native.UnitTest/Sessions/SessionTest.cs
index d9e4e872..b1fe18b4 100644
--- a/test/TensorFlowNET.Native.UnitTest/Sessions/SessionTest.cs
+++ b/test/TensorFlowNET.Native.UnitTest/Sessions/SessionTest.cs
@@ -40,7 +40,7 @@ namespace Tensorflow.Native.UnitTest.Sessions
csession.Run(s);
Tensor outTensor = csession.output_tensor(0);
EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
- EXPECT_EQ(0, outTensor.NDims);
+ EXPECT_EQ(0, outTensor.ndim);
ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize);
var output_contents = outTensor.ToArray();
EXPECT_EQ(3 + 2, output_contents[0]);
@@ -61,7 +61,7 @@ namespace Tensorflow.Native.UnitTest.Sessions
outTensor = csession.output_tensor(0);
ASSERT_TRUE(outTensor != IntPtr.Zero);
EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
- EXPECT_EQ(0, outTensor.NDims); // scalar
+ EXPECT_EQ(0, outTensor.ndim); // scalar
ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize);
output_contents = outTensor.ToArray();
EXPECT_EQ(-(7 + 2), output_contents[0]);
diff --git a/test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs b/test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs
index dc588a1a..76ebf209 100644
--- a/test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs
+++ b/test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs
@@ -66,7 +66,7 @@ namespace Tensorflow.Native.UnitTest.Tensors
long[] dims = { 2, 3 };
Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes);
EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype);
- EXPECT_EQ(2, t.NDims);
+ EXPECT_EQ(2, t.ndim);
EXPECT_EQ((int)dims[0], t.shape[0]);
EXPECT_EQ(num_bytes, t.bytesize);
t.Dispose();
diff --git a/test/TensorFlowNET.UnitTest/Basics/VariableTest.cs b/test/TensorFlowNET.UnitTest/Basics/VariableTest.cs
index 74b1bb03..1b55508b 100644
--- a/test/TensorFlowNET.UnitTest/Basics/VariableTest.cs
+++ b/test/TensorFlowNET.UnitTest/Basics/VariableTest.cs
@@ -126,7 +126,7 @@ namespace TensorFlowNET.UnitTest.Basics
{
var x = tf.constant(new[,] { { 1, 2 } });
var neg_x = tf.negative(x);
- Assert.IsTrue(Enumerable.SequenceEqual(new[] { 1, 2 }, neg_x.shape));
+ Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 1, 2 }, neg_x.shape.dims));
Assert.IsTrue(Enumerable.SequenceEqual(new[] { -1, -2 }, neg_x.numpy().ToArray()));
}
diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/ConstantTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/ConstantTest.cs
index 8b2260a3..2062dbc3 100644
--- a/test/TensorFlowNET.UnitTest/ManagedAPI/ConstantTest.cs
+++ b/test/TensorFlowNET.UnitTest/ManagedAPI/ConstantTest.cs
@@ -145,7 +145,7 @@ namespace TensorFlowNET.UnitTest.Basics
var tensor = tf.constant(nd);
var data = tensor.numpy().ToArray();
- Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 3 }, tensor.shape));
+ Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 2, 3 }, tensor.shape.dims));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 1, 1, 2, 1, 3 }, data));
}
diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs
index 0bf506da..902bcdbf 100644
--- a/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs
+++ b/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs
@@ -57,7 +57,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
var b = tf.Variable(-0.73f, name: "bias");
using var g = tf.GradientTape();
var pred = W * X + b;
- var test = tf.slice(pred, new[] { 0 }, pred.shape);
+ var test = tf.slice(pred, new[] { 0 }, (int[])pred.shape);
var gradients = g.gradient(test, (W, b));
Assert.AreEqual((float)gradients.Item1, 0f);
Assert.AreEqual((float)gradients.Item2, 10f);
diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs
index cdc8b51c..8f38f45c 100644
--- a/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs
+++ b/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs
@@ -85,14 +85,14 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
{ { 1 }, { 2 }, { 3 } },
{ { 4 }, { 5 }, { 6 } }
}));
- Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, a.shape));
+ Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 2, 3, 1 }, a.shape.dims));
var b = tf.constant(new[, ,]
{
{ { 1 }, { 2 }, { 3 } },
{ { 4 }, { 5 }, { 6 } }
});
- Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, b.shape));
+ Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 2, 3, 1 }, b.shape.dims));
}
[TestMethod]
@@ -103,7 +103,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
var c = tf.constant(new[,] { { 9, 10 }, { 11, 12 } });
var concatValue = tf.concat(new[] { a, b, c }, axis: 0);
- Assert.IsTrue(Enumerable.SequenceEqual(new[] { 6, 2 }, concatValue.shape));
+ Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 6, 2 }, concatValue.shape.dims));
}
[TestMethod]
@@ -114,7 +114,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
var c = tf.constant(new[,] { { 9.0, 10.0 }, { 11.0, 12.0 } });
var concatValue = tf.concat(new[] { a, b, c }, axis: 0);
- Assert.IsTrue(Enumerable.SequenceEqual(new[] { 6, 2 }, concatValue.shape));
+ Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 6, 2 }, concatValue.shape.dims));
}
[TestMethod]
@@ -128,7 +128,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
var splitValue = tf.split(value, 3, axis: 0);
Assert.AreEqual(3, splitValue.Length);
- Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 2 }, splitValue[0].shape));
+ Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 2, 2 }, splitValue[0].shape.dims));
}
#region ones/zeros like