| @@ -1,4 +1,4 @@ | |||||
| /***************************************************************************** | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| @@ -135,6 +135,23 @@ namespace Tensorflow | |||||
| TF_DataType.TF_QINT32 | TF_DataType.TF_QINT32 | ||||
| }; | }; | ||||
| private static TOut[,] ConvertArray2D<TIn, TOut>(TIn[,] inputArray, Func<TIn, TOut> converter) | |||||
| { | |||||
| var rows = inputArray.GetLength(0); | |||||
| var cols = inputArray.GetLength(1); | |||||
| var outputArray = new TOut[rows, cols]; | |||||
| for (var i = 0; i < rows; i++) | |||||
| { | |||||
| for (var j = 0; j < cols; j++) | |||||
| { | |||||
| outputArray[i, j] = converter(inputArray[i, j]); | |||||
| } | |||||
| } | |||||
| return outputArray; | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Create a TensorProto, invoked in graph mode | /// Create a TensorProto, invoked in graph mode | ||||
| /// </summary> | /// </summary> | ||||
| @@ -157,19 +174,16 @@ namespace Tensorflow | |||||
| else if(origin_dtype != dtype) | else if(origin_dtype != dtype) | ||||
| { | { | ||||
| var new_system_dtype = dtype.as_system_dtype(); | var new_system_dtype = dtype.as_system_dtype(); | ||||
| if (values is long[] long_values) | |||||
| { | |||||
| if (dtype == TF_DataType.TF_INT32) | |||||
| values = long_values.Select(x => (int)Convert.ChangeType(x, new_system_dtype)).ToArray(); | |||||
| } | |||||
| else if (values is double[] double_values) | |||||
| values = values switch | |||||
| { | { | ||||
| if (dtype == TF_DataType.TF_FLOAT) | |||||
| values = double_values.Select(x => (float)Convert.ChangeType(x, new_system_dtype)).ToArray(); | |||||
| } | |||||
| else | |||||
| values = Convert.ChangeType(values, new_system_dtype); | |||||
| long[] longValues when dtype == TF_DataType.TF_INT32 => longValues.Select(x => (int)x).ToArray(), | |||||
| float[] floatValues when dtype == TF_DataType.TF_DOUBLE => floatValues.Select(x => (double)x).ToArray(), | |||||
| float[,] float2DValues when dtype == TF_DataType.TF_DOUBLE => ConvertArray2D(float2DValues, Convert.ToDouble), | |||||
| double[] doubleValues when dtype == TF_DataType.TF_FLOAT => doubleValues.Select(x => (float)x).ToArray(), | |||||
| double[,] double2DValues when dtype == TF_DataType.TF_DOUBLE => ConvertArray2D(double2DValues, Convert.ToSingle), | |||||
| _ => Convert.ChangeType(values, new_system_dtype), | |||||
| }; | |||||
| dtype = values.GetDataType(); | dtype = values.GetDataType(); | ||||
| } | } | ||||