Browse Source

added tf.concat interface.

tags/v0.9
Oceania2018 6 years ago
parent
commit
112247f05f
2 changed files with 24 additions and 1 deletions
  1. +16
    -0
      src/TensorFlowNET.Core/APIs/tf.array.cs
  2. +8
    -1
      test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs

+ 16
- 0
src/TensorFlowNET.Core/APIs/tf.array.cs View File

@@ -1,11 +1,27 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Tensorflow
{
public static partial class tf
{
/// <summary>
/// Concatenates tensors along one dimension.
/// </summary>
/// <param name="values">A list of `Tensor` objects or a single `Tensor`.</param>
/// <param name="axis"></param>
/// <param name="name"></param>
/// <returns>A `Tensor` resulting from concatenation of the input tensors.</returns>
public static Tensor concat(IList<Tensor> values, int axis, string name = "concat")
{
if (values.Count == 1)
throw new NotImplementedException("tf.concat length is 1");

return gen_array_ops.concat_v2(values.ToArray(), axis, name: name);
}

/// <summary>
/// Inserts a dimension of 1 into a tensor's shape.
/// </summary>


+ 8
- 1
test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs View File

@@ -195,7 +195,14 @@ namespace TensorFlowNET.Examples
pooled_outputs.Add(pool);
}

// var h_pool = tf.concat(pooled_outputs, 3);
var h_pool = tf.concat(pooled_outputs, 3);
var h_pool_flat = tf.reshape(h_pool, new TensorShape(-1, num_filters * filter_sizes.Rank));

with(tf.name_scope("dropout"), delegate
{
// var h_drop = tf.nn.dropout(h_pool_flat, self.keep_prob);
});

return graph;
}



Loading…
Cancel
Save