| @@ -0,0 +1,31 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| namespace Tensorflow.NumPy | |||||
| { | |||||
| public partial class NumPyImpl | |||||
| { | |||||
| public NDArray average(NDArray a, int axis = -1, NDArray? weights = null, bool returned = false) | |||||
| { | |||||
| var dtype = NumPyUtils.GetResultType(a.dtype, np.float64); | |||||
| if(weights is null) | |||||
| { | |||||
| var tensorA = math_ops.cast(a, dtype); | |||||
| var nd = math_ops.reduce_mean(tensorA, axis); | |||||
| return new NDArray(nd); | |||||
| } | |||||
| else | |||||
| { | |||||
| var tensorW = math_ops.cast(weights, dtype); | |||||
| if(a.rank != weights.rank) | |||||
| { | |||||
| var weights_sum = math_ops.reduce_sum(tensorW); | |||||
| var axes = ops.convert_to_tensor(new[,] { { axis }, { 0 } }); | |||||
| var avg = math_ops.tensordot(a, weights, axes) / weights_sum; | |||||
| } | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -14,5 +14,9 @@ namespace Tensorflow.NumPy | |||||
| [AutoNumPy] | [AutoNumPy] | ||||
| public static NDArray amax(NDArray x, int axis = 0) => new NDArray(tf.math.argmax(x, axis)); | public static NDArray amax(NDArray x, int axis = 0) => new NDArray(tf.math.argmax(x, axis)); | ||||
| [AutoNumPy] | |||||
| public static NDArray average(NDArray a, int axis = -1, NDArray? weights = null, bool returned = false) | |||||
| => tf.numpy.average(a, axis: axis, weights: weights, returned: returned); | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,19 @@ | |||||
| using System; | |||||
| using System.Text; | |||||
| namespace Tensorflow.NumPy | |||||
| { | |||||
| internal class NumPyUtils | |||||
| { | |||||
| public static TF_DataType GetResultType(params TF_DataType[] dtypes) | |||||
| { | |||||
| var resultDType = dtypes[0]; | |||||
| for(int i = 1; i < dtypes.Length; i++) | |||||
| { | |||||
| if (dtypes[i].get_datatype_size() > resultDType.get_datatype_size()) | |||||
| resultDType = dtypes[i]; | |||||
| } | |||||
| return resultDType; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -929,6 +929,72 @@ namespace Tensorflow | |||||
| throw new NotImplementedException("tensordot"); | throw new NotImplementedException("tensordot"); | ||||
| } | } | ||||
| public static Tensor tensordot(Tensor x, Tensor y, Tensor axes, string name = null) | |||||
| { | |||||
| Tensor _tensordot_reshape(Tensor a, int[] axes, bool flipped = false) | |||||
| { | |||||
| if (a.shape.IsFullyDefined && isinstance(axes, (typeof(List<object>), typeof(Tuple)))) | |||||
| { | |||||
| var shape_a = a.shape.dims; | |||||
| // axes | |||||
| int iter = 0; | |||||
| foreach (int i in axes) | |||||
| { | |||||
| if (i >= 0) | |||||
| axes[0 + iter] = i; | |||||
| else | |||||
| axes[0 + iter] = i + len(shape_a); | |||||
| iter++; | |||||
| } | |||||
| // free | |||||
| int[] free = { }; | |||||
| iter = 0; | |||||
| foreach (int i in Enumerable.Range(0, len(axes))) | |||||
| if (!Array.Exists(axes, i => i == i)) | |||||
| free[free.Length] = i; | |||||
| // free_dims | |||||
| int[] free_dims = { }; | |||||
| foreach (int i in free) | |||||
| free_dims[free_dims.Length] = (int)shape_a[i]; | |||||
| int prod_free = (int)np.prod(free_dims); | |||||
| // prod_axes | |||||
| int[] prod_axes_pre = { }; | |||||
| foreach (int i in axes) | |||||
| prod_axes_pre[prod_axes_pre.Length] = (int)shape_a[i]; | |||||
| int prod_axes = (int)np.prod(prod_axes_pre); | |||||
| // perm | |||||
| Tensor perm; | |||||
| if (flipped) | |||||
| perm = ops.convert_to_tensor(list(free)) + ops.convert_to_tensor(free); | |||||
| else | |||||
| perm = ops.convert_to_tensor(list(free)) + ops.convert_to_tensor(free) | |||||
| + ops.convert_to_tensor(list(axes)); | |||||
| // new_shape | |||||
| Shape new_shape; | |||||
| if (flipped) | |||||
| new_shape = new Shape(new int[] { prod_axes, prod_free }); | |||||
| else | |||||
| new_shape = new Shape(new int[] { prod_free, prod_axes }); | |||||
| } | |||||
| throw new NotImplementedException("_tensordot_reshape"); | |||||
| } | |||||
| return tf_with(ops.name_scope(name, "Tensordot", new { x, y, axes }), scope => | |||||
| { | |||||
| name = scope; | |||||
| var (a_axes, b_axes) = (axes[0], axes[1]); | |||||
| return x; | |||||
| }); | |||||
| } | |||||
| public static Tensor truediv(Tensor x, Tensor y, string name = null) | public static Tensor truediv(Tensor x, Tensor y, string name = null) | ||||
| => _truediv_python3(x, y, name); | => _truediv_python3(x, y, name); | ||||
| @@ -78,7 +78,7 @@ namespace Tensorflow | |||||
| /// <summary> | /// <summary> | ||||
| /// The name of the device on which this tensor will be produced, or null. | /// The name of the device on which this tensor will be produced, or null. | ||||
| /// </summary> | /// </summary> | ||||
| public virtual string Device => op.Device; | |||||
| public virtual string Device => op?.Device; | |||||
| public long[] dims => shape.dims; | public long[] dims => shape.dims; | ||||
| /// <summary> | /// <summary> | ||||
| @@ -0,0 +1,32 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using Tensorflow; | |||||
| using Tensorflow.NumPy; | |||||
| namespace TensorFlowNET.UnitTest.NumPy | |||||
| { | |||||
| /// <summary> | |||||
| /// https://numpy.org/doc/stable/reference/routines.statistics.html | |||||
| /// </summary> | |||||
| [TestClass] | |||||
| public class StatisticsTest : EagerModeTestBase | |||||
| { | |||||
| [TestMethod] | |||||
| public void average() | |||||
| { | |||||
| var data = np.arange(1, 5); | |||||
| var avg = np.average(data); | |||||
| Assert.AreEqual(avg, 2.5); | |||||
| data = np.arange(6).reshape((3, 2)); | |||||
| avg = np.average(data, axis: 1); | |||||
| assertAllEqual(avg.ToArray<double>(), new[] { 0.5, 2.5, 4.5 }); | |||||
| // avg = np.average(data, axis: 1, weights: new[] { 1.0 / 4, 3.0 / 4 }); | |||||
| // assertAllEqual(avg.ToArray<double>(), new[] { 0.75, 2.75, 4.75 }); | |||||
| } | |||||
| } | |||||
| } | |||||