diff --git a/src/TensorFlowNET.Core/Numpy/Shape.cs b/src/TensorFlowNET.Core/Numpy/Shape.cs index 936d06fe..f4176e38 100644 --- a/src/TensorFlowNET.Core/Numpy/Shape.cs +++ b/src/TensorFlowNET.Core/Numpy/Shape.cs @@ -245,7 +245,7 @@ namespace Tensorflow -1 => "", 0 => "()", 1 => $"({dims[0]},)", - _ => $"{string.Join(", ", _dims).Replace("-1", "None")}" + _ => $"({string.Join(", ", _dims).Replace("-1", "None")})" }; } } diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 98060436..f6effcf8 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -110,7 +110,21 @@ namespace Tensorflow if (values is TensorProto tp) return tp; - dtype = values.GetDataType(); + var origin_dtype = values.GetDataType(); + if (dtype == TF_DataType.DtInvalid) + dtype = origin_dtype; + else if(origin_dtype != dtype) + { + var new_system_dtype = dtype.as_system_dtype(); + if (values is long[] long_values) + { + if (dtype == TF_DataType.TF_INT32) + values = long_values.Select(x => (int)Convert.ChangeType(x, new_system_dtype)).ToArray(); + } + else + values = Convert.ChangeType(values, new_system_dtype); + } + shape = shape ?? values.GetShape(); var tensor_proto = new TensorProto { diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index ef4c1506..56e9fffa 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -123,6 +123,9 @@ namespace Tensorflow if (dtype == TF_DataType.DtInvalid) dtype = preferred_dtype; + if (dtype == TF_DataType.DtInvalid) + dtype = value.GetDataType(); + if (value is EagerTensor eager_tensor && !eager_tensor.IsCreatedInGraphMode) { if (tf.executing_eagerly()) @@ -173,8 +176,7 @@ namespace Tensorflow if (dtype == TF_DataType.TF_STRING) return ret; - var original_dtype = value.GetDataType(); - if (dtype != TF_DataType.DtInvalid && dtype != original_dtype) + if (dtype != ret.dtype) ret = gen_math_ops.cast(ret, dtype.as_base_dtype(), name: name); return ret;