| @@ -67,7 +67,7 @@ namespace Tensorflow | |||||
| T[] ExpandArrayToSize<T>(IList<T> src) | T[] ExpandArrayToSize<T>(IList<T> src) | ||||
| { | { | ||||
| if(src.Count == 0) | |||||
| if (src.Count == 0) | |||||
| { | { | ||||
| return new T[0]; | return new T[0]; | ||||
| } | } | ||||
| @@ -77,7 +77,7 @@ namespace Tensorflow | |||||
| var first_elem = src[0]; | var first_elem = src[0]; | ||||
| var last_elem = src[src.Count - 1]; | var last_elem = src[src.Count - 1]; | ||||
| T[] res = new T[num_elements]; | T[] res = new T[num_elements]; | ||||
| for(long i = 0; i < num_elements; i++) | |||||
| for (long i = 0; i < num_elements; i++) | |||||
| { | { | ||||
| if (i < pre) res[i] = first_elem; | if (i < pre) res[i] = first_elem; | ||||
| else if (i >= num_elements - after) res[i] = last_elem; | else if (i >= num_elements - after) res[i] = last_elem; | ||||
| @@ -121,7 +121,7 @@ namespace Tensorflow | |||||
| $"https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes."); | $"https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes."); | ||||
| } | } | ||||
| if(values.size == 0) | |||||
| if (values.size == 0) | |||||
| { | { | ||||
| return np.zeros(shape, tensor_dtype); | return np.zeros(shape, tensor_dtype); | ||||
| } | } | ||||
| @@ -135,23 +135,47 @@ namespace Tensorflow | |||||
| TF_DataType.TF_QINT32 | TF_DataType.TF_QINT32 | ||||
| }; | }; | ||||
| private static TOut[,] ConvertArray2D<TIn, TOut>(TIn[,] inputArray, Func<TIn, TOut> converter) | |||||
| private static Array ConvertArray<TOut>(Array inputArray, Func<object, TOut> converter) | |||||
| { | { | ||||
| var rows = inputArray.GetLength(0); | |||||
| var cols = inputArray.GetLength(1); | |||||
| var outputArray = new TOut[rows, cols]; | |||||
| if (inputArray == null) | |||||
| throw new ArgumentNullException(nameof(inputArray)); | |||||
| for (var i = 0; i < rows; i++) | |||||
| var elementType = typeof(TOut); | |||||
| var lengths = new int[inputArray.Rank]; | |||||
| for (var i = 0; i < inputArray.Rank; i++) | |||||
| { | { | ||||
| for (var j = 0; j < cols; j++) | |||||
| { | |||||
| outputArray[i, j] = converter(inputArray[i, j]); | |||||
| } | |||||
| lengths[i] = inputArray.GetLength(i); | |||||
| } | } | ||||
| var outputArray = Array.CreateInstance(elementType, lengths); | |||||
| FillArray(inputArray, outputArray, converter, new int[inputArray.Rank], 0); | |||||
| return outputArray; | return outputArray; | ||||
| } | } | ||||
| private static void FillArray<TIn, TOut>(Array inputArray, Array outputArray, Func<TIn, TOut> converter, int[] indices, int dimension) | |||||
| { | |||||
| if (dimension == inputArray.Rank - 1) | |||||
| { | |||||
| for (int i = 0; i < inputArray.GetLength(dimension); i++) | |||||
| { | |||||
| indices[dimension] = i; | |||||
| var inputValue = (TIn)inputArray.GetValue(indices); | |||||
| var convertedValue = converter(inputValue); | |||||
| outputArray.SetValue(convertedValue, indices); | |||||
| } | |||||
| } | |||||
| else | |||||
| { | |||||
| for (int i = 0; i < inputArray.GetLength(dimension); i++) | |||||
| { | |||||
| indices[dimension] = i; | |||||
| FillArray(inputArray, outputArray, converter, indices, dimension + 1); | |||||
| } | |||||
| } | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Create a TensorProto, invoked in graph mode | /// Create a TensorProto, invoked in graph mode | ||||
| /// </summary> | /// </summary> | ||||
| @@ -171,24 +195,30 @@ namespace Tensorflow | |||||
| var origin_dtype = values.GetDataType(); | var origin_dtype = values.GetDataType(); | ||||
| if (dtype == TF_DataType.DtInvalid) | if (dtype == TF_DataType.DtInvalid) | ||||
| dtype = origin_dtype; | dtype = origin_dtype; | ||||
| else if(origin_dtype != dtype) | |||||
| else if (origin_dtype != dtype) | |||||
| { | { | ||||
| var new_system_dtype = dtype.as_system_dtype(); | var new_system_dtype = dtype.as_system_dtype(); | ||||
| values = values switch | |||||
| if (dtype != TF_DataType.TF_STRING && dtype != TF_DataType.TF_VARIANT && dtype != TF_DataType.TF_RESOURCE) | |||||
| { | |||||
| if (values is Array arrayValues) | |||||
| { | |||||
| values = dtype switch | |||||
| { | |||||
| TF_DataType.TF_INT32 => ConvertArray(arrayValues, Convert.ToInt32), | |||||
| TF_DataType.TF_FLOAT => ConvertArray(arrayValues, Convert.ToSingle), | |||||
| TF_DataType.TF_DOUBLE => ConvertArray(arrayValues, Convert.ToDouble), | |||||
| _ => values, | |||||
| }; | |||||
| } else | |||||
| { | |||||
| values = Convert.ChangeType(values, new_system_dtype); | |||||
| } | |||||
| } else | |||||
| { | { | ||||
| long[] longValues when dtype == TF_DataType.TF_INT32 => longValues.Select(x => (int)x).ToArray(), | |||||
| long[] longValues => values, | |||||
| float[] floatValues when dtype == TF_DataType.TF_DOUBLE => floatValues.Select(x => (double)x).ToArray(), | |||||
| float[] floatValues => values, | |||||
| float[,] float2DValues when dtype == TF_DataType.TF_DOUBLE => ConvertArray2D(float2DValues, Convert.ToDouble), | |||||
| float[,] float2DValues => values, | |||||
| double[] doubleValues when dtype == TF_DataType.TF_FLOAT => doubleValues.Select(x => (float)x).ToArray(), | |||||
| double[] doubleValues => values, | |||||
| double[,] double2DValues when dtype == TF_DataType.TF_FLOAT => ConvertArray2D(double2DValues, Convert.ToSingle), | |||||
| double[,] double2DValues => values, | |||||
| _ => Convert.ChangeType(values, new_system_dtype), | |||||
| }; | |||||
| } | |||||
| dtype = values.GetDataType(); | dtype = values.GetDataType(); | ||||
| } | } | ||||
| @@ -306,7 +336,7 @@ namespace Tensorflow | |||||
| if (tensor is EagerTensor eagerTensor) | if (tensor is EagerTensor eagerTensor) | ||||
| { | { | ||||
| if(tensor.dtype == tf.int64) | |||||
| if (tensor.dtype == tf.int64) | |||||
| return new Shape(tensor.ToArray<long>()); | return new Shape(tensor.ToArray<long>()); | ||||
| else | else | ||||
| return new Shape(tensor.ToArray<int>()); | return new Shape(tensor.ToArray<int>()); | ||||
| @@ -481,7 +511,7 @@ would not be rank 1.", tensor.op.get_attr("axis"))); | |||||
| var d_ = new int[value.size]; | var d_ = new int[value.size]; | ||||
| foreach (var (index, d) in enumerate(value.ToArray<int>())) | foreach (var (index, d) in enumerate(value.ToArray<int>())) | ||||
| d_[index] = d >= 0 ? d : -1; | d_[index] = d >= 0 ? d : -1; | ||||
| ret = ret.merge_with(new Shape(d_)); | ret = ret.merge_with(new Shape(d_)); | ||||
| } | } | ||||
| return ret; | return ret; | ||||