Browse Source

tf.sparse_to_dense #396

tags/v0.12
Oceania2018 6 years ago
parent
commit
1e5018334e
6 changed files with 217 additions and 8 deletions
  1. +50
    -0
      src/TensorFlowNET.Core/APIs/tf.sparse.cs
  2. +43
    -8
      src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs
  3. +54
    -0
      src/TensorFlowNET.Core/Operations/gen_sparse_ops.cs
  4. +33
    -0
      src/TensorFlowNET.Core/Operations/sparse_ops.cs
  5. +18
    -0
      src/TensorFlowNET.Core/Tensors/TensorShape.cs
  6. +19
    -0
      test/TensorFlowNET.UnitTest/TensorTest.cs

+ 50
- 0
src/TensorFlowNET.Core/APIs/tf.sparse.cs View File

@@ -0,0 +1,50 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using Tensorflow.Framework;

namespace Tensorflow
{
public partial class tensorflow
{
public SparseTensor<T> SparseTensor<T>(long[,] indices, T[] values, int[] dense_shape)
=> new SparseTensor<T>(indices, values, dense_shape);

/// <summary>
/// Converts a sparse representation into a dense tensor.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="sparse_indices"></param>
/// <param name="output_shape"></param>
/// <param name="sparse_values"></param>
/// <param name="default_value"></param>
/// <param name="validate_indices"></param>
/// <param name="name"></param>
/// <returns>Dense `Tensor` of shape `output_shape`. Has the same type as `sparse_values`.</returns>
public Tensor sparse_to_dense<T>(Tensor sparse_indices,
TensorShape output_shape,
T sparse_values,
T default_value = default,
bool validate_indices = true,
string name = null)
=> gen_sparse_ops.sparse_to_dense(sparse_indices,
output_shape,
sparse_values,
default_value: default_value,
validate_indices: validate_indices,
name: name);
}
}

+ 43
- 8
src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs View File

@@ -1,19 +1,54 @@
namespace Tensorflow.Framework
{
public interface _TensorLike
{ }
using static Tensorflow.Binding;

public class SparseTensor : CompositeTensor, _TensorLike
namespace Tensorflow.Framework
{
/// <summary>
/// Represents a sparse tensor.
/// </summary>
public class SparseTensor<T> : CompositeTensor, _TensorLike
{
private static Tensor _dense_shape { get; set; }
long[,] _indices;
Tensor indices;

T[] _values;
Tensor values;

int[] _dense_shape;
Tensor dense_shape;

public SparseTensor(long[,] indices_, T[] values_, int[] dense_shape_)
{
tf_with(ops.name_scope(null, "SparseTensor", new { }), delegate
{
indices = ops.convert_to_tensor(
indices_, name: "indices", dtype: dtypes.int64);
values = ops.internal_convert_to_tensor(values_, name: "values");
dense_shape = ops.convert_to_tensor(
dense_shape_, name: "dense_shape", dtype: dtypes.int64);
});

_indices = indices_;
_values = values_;
_dense_shape = dense_shape_;

var indices_shape = indices.TensorShape.with_rank(2);
var values_shape = values.TensorShape.with_rank(1);
var dense_shape_shape = dense_shape.TensorShape.with_rank(1);

indices_shape[0].merge_with(values_shape.dims[0]);
indices_shape[1].merge_with(dense_shape_shape.dims[0]);
}
}

public interface _TensorLike
{
}

public static class sparse_tensor
public static class sparse_tensor_extension
{
public static bool is_sparse(this _TensorLike x)
{
return x is SparseTensor;
return x.GetType().Name.Contains("SparseTensor");
}
}
}

+ 54
- 0
src/TensorFlowNET.Core/Operations/gen_sparse_ops.cs View File

@@ -0,0 +1,54 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System.Collections.Generic;

namespace Tensorflow
{
public class gen_sparse_ops
{
public static OpDefLibrary _op_def_lib = new OpDefLibrary();

/// <summary>
/// Converts a sparse representation into a dense tensor.
/// </summary>
/// <param name="sparse_indices"></param>
/// <param name="output_shape"></param>
/// <param name="sparse_values"></param>
/// <param name="default_value"></param>
/// <param name="validate_indices"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor sparse_to_dense<T>(Tensor sparse_indices,
int[] output_shape,
T sparse_values,
T default_value,
bool validate_indices = true,
string name = null)
{
var _op = _op_def_lib._apply_op_helper("SparseToDense", name, args: new
{
sparse_indices,
output_shape,
sparse_values,
default_value,
validate_indices
});

return _op.output;
}
}
}

+ 33
- 0
src/TensorFlowNET.Core/Operations/sparse_ops.cs View File

@@ -0,0 +1,33 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class sparse_ops
{
/// <summary>
/// Converts a sparse representation into a dense tensor.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="sparse_indices"></param>
/// <param name="output_shape"></param>
/// <param name="sparse_values"></param>
/// <param name="default_value"></param>
/// <param name="validate_indices"></param>
/// <param name="name"></param>
/// <returns>Dense `Tensor` of shape `output_shape`. Has the same type as `sparse_values`.</returns>
public Tensor sparse_to_dense<T>(Tensor sparse_indices,
int[] output_shape,
T sparse_values,
T default_value = default,
bool validate_indices = true,
string name = null)
=> gen_sparse_ops.sparse_to_dense(sparse_indices,
output_shape,
sparse_values,
default_value: default_value,
validate_indices: validate_indices,
name: name);
}
}

+ 18
- 0
src/TensorFlowNET.Core/Tensors/TensorShape.cs View File

@@ -143,6 +143,24 @@ namespace Tensorflow
return this;
}

public TensorShape with_rank(int rank)
{
return merge_with(unknown_shape(rank: rank));
}

/// <summary>
/// Returns an unknown TensorShape, optionally with a known rank.
/// </summary>
/// <param name="rank"></param>
/// <returns></returns>
public TensorShape unknown_shape(int rank = -1)
{
if (rank == -1)
return new TensorShape(-1);
else
return new TensorShape(Enumerable.Repeat(-1, rank).ToArray());
}

/// <summary>
/// Returns the concatenation of the dimension in `self` and `other`.
/// </summary>


+ 19
- 0
test/TensorFlowNET.UnitTest/TensorTest.cs View File

@@ -7,6 +7,7 @@ using System.Threading;
using FluentAssertions;
using Tensorflow;
using static Tensorflow.Binding;
using Tensorflow.Framework;

namespace TensorFlowNET.UnitTest
{
@@ -202,5 +203,23 @@ namespace TensorFlowNET.UnitTest
// graph.Dispose();
s.Dispose();
}

[TestMethod]
public void sparse_to_dense()
{
var indices = tf.reshape(tf.range(0, 5), new int[] { 5, 1 });
var labels = tf.expand_dims(tf.constant(new[] { 0, 1, 2, 3, 4 }),1);
var st = tf.concat(values: new[] { indices, labels }, axis: 1);
var onehot = tf.sparse_to_dense(st, (5, 5), 1);
using (var sess = tf.Session())
{
var result = sess.run(onehot);
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0, 0, 0 }, result[0].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 1, 0, 0, 0 }, result[1].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 1, 0, 0 }, result[2].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 1, 0 }, result[3].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 1 }, result[4].ToArray<int>()));
};
}
}
}

Loading…
Cancel
Save