From 5a73e698b0d02bf62480a539a25c1f35e2721aa0 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Wed, 29 May 2019 09:34:04 -0500 Subject: [PATCH] add Tensor[] pattern match for ops.name_scope. --- src/TensorFlowNET.Core/Layers/Layer.cs | 2 +- src/TensorFlowNET.Core/ops.name_scope.cs | 8 +++-- src/TensorFlowNET.Core/ops.py.cs | 4 ++- .../TextProcess/TextClassificationTrain.cs | 31 ------------------- 4 files changed, 10 insertions(+), 35 deletions(-) diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs index eea80e04..aad687e6 100644 --- a/src/TensorFlowNET.Core/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Layers/Layer.cs @@ -37,7 +37,7 @@ namespace Tensorflow.Layers VariableScope scope = null) { _set_scope(scope); - _graph = ops._get_graph_from_inputs(new List { inputs }, graph: _graph); + _graph = ops._get_graph_from_inputs(new Tensor[] { inputs }, graph: _graph); variable_scope scope_context_manager = null; if (built) diff --git a/src/TensorFlowNET.Core/ops.name_scope.cs b/src/TensorFlowNET.Core/ops.name_scope.cs index 03b9349e..abc8bd80 100644 --- a/src/TensorFlowNET.Core/ops.name_scope.cs +++ b/src/TensorFlowNET.Core/ops.name_scope.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; using Tensorflow.Eager; @@ -37,8 +38,11 @@ namespace Tensorflow _name = _name == null ? _default_name : _name; Graph g = null; - if (_values is List values) - g = _get_graph_from_inputs(values); + + if (_values is List vList) + g = _get_graph_from_inputs(vList.ToArray()); + else if (_values is Tensor[] vArray) + g = _get_graph_from_inputs(vArray); if (g == null) g = get_default_graph(); diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index 53f03375..4babffb0 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -102,8 +102,10 @@ namespace Tensorflow default_graph = tf.Graph(); } + public static Graph _get_graph_from_inputs(params Tensor[] op_input_list) + => _get_graph_from_inputs(op_input_list: op_input_list); - public static Graph _get_graph_from_inputs(List op_input_list, Graph graph = null) + public static Graph _get_graph_from_inputs(Tensor[] op_input_list, Graph graph = null) { foreach(var op_input in op_input_list) { diff --git a/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs b/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs index c7652268..38a519d1 100644 --- a/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs +++ b/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs @@ -203,37 +203,6 @@ namespace TensorFlowNET.Examples.CnnTextClassification return (train_x, valid_x, train_y, valid_y); } - //private (int[][], int[][], int[], int[]) train_test_split(int[][] x, int[] y, float test_size = 0.3f) - //{ - // Console.WriteLine("Splitting in Training and Testing data..."); - // var stopwatch = Stopwatch.StartNew(); - // int len = x.Length; - // int train_size = int.Parse((len * (1 - test_size)).ToString()); - // var random = new Random(17); - - // // we collect indices of labels - // var labels = new Dictionary>(); - // var shuffled_indices = random.Shuffle(range(len).ToArray()); - // foreach (var i in shuffled_indices) - // { - // var label = y[i]; - // if (!labels.ContainsKey(i)) - // labels[label] = new HashSet(); - // labels[label].Add(i); - // } - - // var train_x = new int[train_size][]; - // var valid_x = new int[len - train_size][]; - // var train_y = new int[train_size]; - // var valid_y = new int[len - train_size]; - - // FillWithShuffledLabels(x, y, train_x, train_y, random, labels); - // FillWithShuffledLabels(x, y, valid_x, valid_y, random, labels); - - // Console.WriteLine("\tDONE " + stopwatch.Elapsed); - // return (train_x, valid_x, train_y, valid_y); - //} - private static void FillWithShuffledLabels(int[][] x, int[] y, int[][] shuffled_x, int[] shuffled_y, Random random, Dictionary> labels) { int i = 0;