From 112247f05f14a707d15111a17b6e72b22627167b Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Wed, 5 Jun 2019 06:24:22 -0500 Subject: [PATCH] added tf.concat interface. --- src/TensorFlowNET.Core/APIs/tf.array.cs | 16 ++++++++++++++++ .../TextProcess/CnnTextClassification.cs | 9 ++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs index dd385b4c..cbcfed28 100644 --- a/src/TensorFlowNET.Core/APIs/tf.array.cs +++ b/src/TensorFlowNET.Core/APIs/tf.array.cs @@ -1,11 +1,27 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; namespace Tensorflow { public static partial class tf { + /// + /// Concatenates tensors along one dimension. + /// + /// A list of `Tensor` objects or a single `Tensor`. + /// + /// + /// A `Tensor` resulting from concatenation of the input tensors. + public static Tensor concat(IList values, int axis, string name = "concat") + { + if (values.Count == 1) + throw new NotImplementedException("tf.concat length is 1"); + + return gen_array_ops.concat_v2(values.ToArray(), axis, name: name); + } + /// /// Inserts a dimension of 1 into a tensor's shape. /// diff --git a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs index 503e74f6..a6d17a1e 100644 --- a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs +++ b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs @@ -195,7 +195,14 @@ namespace TensorFlowNET.Examples pooled_outputs.Add(pool); } - // var h_pool = tf.concat(pooled_outputs, 3); + var h_pool = tf.concat(pooled_outputs, 3); + var h_pool_flat = tf.reshape(h_pool, new TensorShape(-1, num_filters * filter_sizes.Rank)); + + with(tf.name_scope("dropout"), delegate + { + // var h_drop = tf.nn.dropout(h_pool_flat, self.keep_prob); + }); + return graph; }