From 060cc37dd47f39746db5962b20277ee6a20f9d44 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Tue, 17 Sep 2019 23:03:21 -0500 Subject: [PATCH] tf.sparse_tensor_to_dense, TensorShape.merge_with #396 --- src/TensorFlowNET.Core/APIs/tf.sparse.cs | 13 +- .../Framework/sparse_tensor.py.cs | 21 ++- .../Operations/gen_sparse_ops.cs | 20 +++ src/TensorFlowNET.Core/Tensors/Dimension.cs | 27 +++ src/TensorFlowNET.Core/Tensors/TensorShape.cs | 19 +- src/TensorFlowNET.Core/Tensors/tensor_util.cs | 166 ++++++------------ test/TensorFlowNET.UnitTest/TensorTest.cs | 22 ++- 7 files changed, 157 insertions(+), 131 deletions(-) create mode 100644 src/TensorFlowNET.Core/Tensors/Dimension.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.sparse.cs b/src/TensorFlowNET.Core/APIs/tf.sparse.cs index bb5bc96d..c615a614 100644 --- a/src/TensorFlowNET.Core/APIs/tf.sparse.cs +++ b/src/TensorFlowNET.Core/APIs/tf.sparse.cs @@ -20,9 +20,20 @@ namespace Tensorflow { public partial class tensorflow { - public SparseTensor SparseTensor(long[,] indices, T[] values, int[] dense_shape) + public SparseTensor SparseTensor(long[,] indices, T[] values, long[] dense_shape) => new SparseTensor(indices, values, dense_shape); + public Tensor sparse_tensor_to_dense(SparseTensor sp_input, + T default_value = default, + bool validate_indices = true, + string name = null) + => gen_sparse_ops.sparse_to_dense(sp_input.indices, + sp_input.dense_shape, + sp_input.values, + default_value: default_value, + validate_indices: validate_indices, + name: name); + /// /// Converts a sparse representation into a dense tensor. /// diff --git a/src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs b/src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs index cd0a2893..b03ce2de 100644 --- a/src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs +++ b/src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs @@ -1,4 +1,6 @@ -using static Tensorflow.Binding; +using System; +using System.Linq; +using static Tensorflow.Binding; namespace Tensorflow.Framework { @@ -8,15 +10,20 @@ namespace Tensorflow.Framework public class SparseTensor : CompositeTensor, _TensorLike { long[,] _indices; - Tensor indices; + public Tensor indices; T[] _values; - Tensor values; + public Tensor values; - int[] _dense_shape; - Tensor dense_shape; + long[] _dense_shape; + public Tensor dense_shape; - public SparseTensor(long[,] indices_, T[] values_, int[] dense_shape_) + TensorShape _shape; + public TensorShape shape => _shape; + + public TF_DataType dtype => dtypes.as_dtype(typeof(T)); + + public SparseTensor(long[,] indices_, T[] values_, long[] dense_shape_) { tf_with(ops.name_scope(null, "SparseTensor", new { }), delegate { @@ -37,6 +44,8 @@ namespace Tensorflow.Framework indices_shape[0].merge_with(values_shape.dims[0]); indices_shape[1].merge_with(dense_shape_shape.dims[0]); + + _shape = new TensorShape(_dense_shape.Select(x => Convert.ToInt32(x)).ToArray()); } } diff --git a/src/TensorFlowNET.Core/Operations/gen_sparse_ops.cs b/src/TensorFlowNET.Core/Operations/gen_sparse_ops.cs index 57a5f860..d59afc88 100644 --- a/src/TensorFlowNET.Core/Operations/gen_sparse_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_sparse_ops.cs @@ -15,6 +15,7 @@ ******************************************************************************/ using System.Collections.Generic; +using Tensorflow.Framework; namespace Tensorflow { @@ -50,5 +51,24 @@ namespace Tensorflow return _op.output; } + + public static Tensor sparse_to_dense(Tensor sparse_indices, + Tensor output_shape, + Tensor sparse_values, + T default_value = default, + bool validate_indices = true, + string name = null) + { + var _op = _op_def_lib._apply_op_helper("SparseToDense", name, args: new + { + sparse_indices, + output_shape, + sparse_values, + default_value, + validate_indices + }); + + return _op.output; + } } } diff --git a/src/TensorFlowNET.Core/Tensors/Dimension.cs b/src/TensorFlowNET.Core/Tensors/Dimension.cs new file mode 100644 index 00000000..58520270 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Dimension.cs @@ -0,0 +1,27 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class Dimension + { + int _value; + public int value => _value; + + public Dimension(int value) + { + _value = value; + } + + public Dimension merge_with(Dimension other) + { + if (_value == -1) + return new Dimension(other.value); + else + return new Dimension(_value); + } + + public override string ToString() => $"Dimension({_value})"; + } +} diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs index 3e4deac2..f8417924 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs @@ -1,9 +1,10 @@ using NumSharp; using System; +using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Runtime.CompilerServices; -using NumSharp.Utilities; +using static Tensorflow.Binding; namespace Tensorflow { @@ -196,12 +197,26 @@ namespace Tensorflow } } + /// + /// Returns a `TensorShape` combining the information in `self` and `other`. + /// + /// + /// public TensorShape merge_with(TensorShape other) { if (dims.Length == 0) return other; - throw new NotImplementedException("merge_with"); + var new_dims = new List(); + + foreach (var i in range(ndim)) + { + var dim = new Dimension(dims[i]); + var merged = dim.merge_with(new Dimension(other.dims[i])); + new_dims.Add(merged.value); + } + + return new TensorShape(new_dims.ToArray()); } /// diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 59c107fc..142afe06 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -118,110 +118,10 @@ namespace Tensorflow if (values == null) throw new ValueError("None values not supported."); - if(np_dt == null) - { - switch (values) - { - case bool boolVal: - nparray = boolVal; - break; - case int intVal: - nparray = intVal; - break; - case int[] intVals: - nparray = np.array(intVals); - break; - case int[,] intVals: - nparray = np.array(intVals); - break; - case long intVal: - nparray = intVal; - break; - case long[] intVals: - nparray = np.array(intVals); - break; - case long[,] intVals: - nparray = np.array(intVals); - break; - case float floatVal: - nparray = floatVal; - break; - case float[] floatVals: - nparray = floatVals; - break; - case float[,] floatVals: - nparray = np.array(floatVals); - break; - case double doubleVal: - nparray = doubleVal; - break; - case double[] doubleVals: - nparray = np.array(doubleVals); - break; - case double[,] doubleVals: - nparray = np.array(doubleVals); - break; - case string strVal: - nparray = strVal; - break; - case string[] strVals: - nparray = strVals; - break; - case byte[] byteValues: - nparray = byteValues; - break; - case byte[,] byteValues: - nparray = np.array(byteValues); - break; - default: - throw new NotImplementedException($"make_tensor_proto: Support for type {values.GetType()} Not Implemented"); - } - } - else - { - // convert data type - switch (np_dt.Name) - { - case "Int32": - if (values.GetType().IsArray) - nparray = np.array((int[])values, np_dt); - else - nparray = Converts.ToInt32(values); - break; - case "Int64": - if (values.GetType().IsArray) - nparray = np.array((int[])values, np_dt); - else - nparray = Converts.ToInt64(values); - break; - case "Single": - if (values.GetType().IsArray) - nparray = np.array((float[])values, np_dt); - else - nparray = Converts.ToSingle(values); - break; - case "Double": - if (values.GetType().IsArray) - nparray = np.array((double[])values, np_dt); - else - nparray = Converts.ToDouble(values); - break; - case "String": - if (values.GetType().IsArray) - nparray = np.array((string[])values, np_dt); - else - nparray = NDArray.FromString(Converts.ToString(values)); - break; - case "Boolean": - if (values.GetType().IsArray) - nparray = np.array((bool[])values, np_dt); - else - nparray = Converts.ToBoolean(values); - break; - default: - throw new NotImplementedException($"make_tensor_proto: Support for type {np_dt.Name} Not Implemented"); - } - } + nparray = convert_to_numpy_ndarray(values); + + if (np_dt != null && np_dt != typeof(string)) + nparray = nparray.astype(np_dt); } var numpy_dtype = nparray.dtype.as_dtype(dtype: dtype); @@ -316,23 +216,59 @@ namespace Tensorflow case NDArray val: nd = val; break; - case int val: - nd = np.asarray(val); + case bool boolVal: + nd = boolVal; + break; + case int intVal: + nd = intVal; + break; + case int[] intVals: + nd = np.array(intVals); + break; + case int[,] intVals: + nd = np.array(intVals); + break; + case long intVal: + nd = intVal; + break; + case long[] intVals: + nd = np.array(intVals); + break; + case long[,] intVals: + nd = np.array(intVals); + break; + case float floatVal: + nd = floatVal; + break; + case float[] floatVals: + nd = floatVals; + break; + case float[,] floatVals: + nd = np.array(floatVals); + break; + case double doubleVal: + nd = doubleVal; + break; + case double[] doubleVals: + nd = np.array(doubleVals); + break; + case double[,] doubleVals: + nd = np.array(doubleVals); break; - case int[] val: - nd = np.array(val); + case string strVal: + nd = NDArray.FromString(strVal); break; - case float val: - nd = np.asarray(val); + case string[] strVals: + nd = strVals; break; - case double val: - nd = np.asarray(val); + case byte[] byteValues: + nd = byteValues; break; - case string val: - nd = np.asarray(val); + case byte[,] byteValues: + nd = np.array(byteValues); break; default: - throw new Exception("Not Implemented"); + throw new NotImplementedException($"convert_to_numpy_ndarray: Support for type {values.GetType()} Not Implemented"); } return nd; diff --git a/test/TensorFlowNET.UnitTest/TensorTest.cs b/test/TensorFlowNET.UnitTest/TensorTest.cs index 5f6a4b40..6b5b5dec 100644 --- a/test/TensorFlowNET.UnitTest/TensorTest.cs +++ b/test/TensorFlowNET.UnitTest/TensorTest.cs @@ -225,14 +225,22 @@ namespace TensorFlowNET.UnitTest [TestMethod] public void sparse_tensor_to_dense() { - /*int[,] dense_array = + var decoded_list = tf.SparseTensor(new[,] { - { 1, 0, 0, 0, 0 }, - { 0, 1, 0, 0, 0 }, - { 0, 0, 1, 0, 0 }, - { 0, 0, 0, 1, 0 } - }; - var sparseTensor = new SparseTensor(indices, values, dense_shape);*/ + { 0L, 0L }, + { 1L, 2L } + }, + new int[] { 1, 2 }, + new[] { 3L, 4L }); + + var onehot = tf.sparse_tensor_to_dense(decoded_list); + using (var sess = tf.Session()) + { + var result = sess.run(onehot); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0, 0 }, result[0].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 2, 0 }, result[1].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0 }, result[2].ToArray())); + } } } } \ No newline at end of file