Browse Source

Update tensor_util.cs

pull/1217/head
novikov-alexander GitHub 1 year ago
parent
commit
483ac82cd2
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
1 changed files with 6 additions and 1 deletions
  1. +6
    -1
      src/TensorFlowNET.Core/Tensors/tensor_util.cs

+ 6
- 1
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -178,10 +178,15 @@ namespace Tensorflow
values = values switch values = values switch
{ {
long[] longValues when dtype == TF_DataType.TF_INT32 => longValues.Select(x => (int)x).ToArray(), long[] longValues when dtype == TF_DataType.TF_INT32 => longValues.Select(x => (int)x).ToArray(),
long[] longValues => values,
float[] floatValues when dtype == TF_DataType.TF_DOUBLE => floatValues.Select(x => (double)x).ToArray(), float[] floatValues when dtype == TF_DataType.TF_DOUBLE => floatValues.Select(x => (double)x).ToArray(),
float[] floatValues => values,
float[,] float2DValues when dtype == TF_DataType.TF_DOUBLE => ConvertArray2D(float2DValues, Convert.ToDouble), float[,] float2DValues when dtype == TF_DataType.TF_DOUBLE => ConvertArray2D(float2DValues, Convert.ToDouble),
float[,] float2DValues => values,
double[] doubleValues when dtype == TF_DataType.TF_FLOAT => doubleValues.Select(x => (float)x).ToArray(), double[] doubleValues when dtype == TF_DataType.TF_FLOAT => doubleValues.Select(x => (float)x).ToArray(),
double[,] double2DValues when dtype == TF_DataType.TF_DOUBLE => ConvertArray2D(double2DValues, Convert.ToSingle),
double[] doubleValues => values,
double[,] double2DValues when dtype == TF_DataType.TF_FLOAT => ConvertArray2D(double2DValues, Convert.ToSingle),
double[,] double2DValues => values,
_ => Convert.ChangeType(values, new_system_dtype), _ => Convert.ChangeType(values, new_system_dtype),
}; };
dtype = values.GetDataType(); dtype = values.GetDataType();


Loading…
Cancel
Save