From 053cab61a9c29d05127aea31f55f198819c6e586 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Tue, 9 Apr 2019 19:11:00 -0500 Subject: [PATCH] condcontext --- .../Operations/ControlFlows/CondContext.cs | 2 ++ .../Operations/control_flow_ops.py.cs | 16 ++++++++++------ .../Operations/embedding_ops.cs | 11 ++++++++--- src/TensorFlowNET.Core/Operations/nn_impl.py.cs | 2 +- 4 files changed, 21 insertions(+), 10 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs index 8ed46036..0385341a 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs @@ -75,6 +75,8 @@ namespace Tensorflow.Operations { case Operation[] results: return (original_result, _BuildCondTensor(results)); + case Tensor tensor: + return (original_result, tensor); case float[] fv: var result = ops.convert_to_tensor(fv[0]); return (original_result, result ); diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs index db015443..0b7afced 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs @@ -217,20 +217,24 @@ namespace Tensorflow var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); context_f.Exit(); - var res_t_flat = res_t; - var res_f_flat = res_f; + var res_t_flat = new Tensor[] { res_t }; + var res_f_flat = new Tensor[] { res_f }; + + foreach(var (val_x, val_y) in zip(res_t_flat, res_f_flat)) + { - return new Tensor(IntPtr.Zero); - /*var merges = zip(res_f_flat, res_t_flat) + } + + var merges = zip(res_f_flat, res_t_flat) .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })) .ToArray(); - merges = _convert_flows_to_tensorarrays(orig_res_t, merges); + merges = _convert_flows_to_tensorarrays(new Tensor[] { (Tensor)orig_res_t }, merges); ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t); ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f); - return merges;*/ + return merges[0]; }); } diff --git a/src/TensorFlowNET.Core/Operations/embedding_ops.cs b/src/TensorFlowNET.Core/Operations/embedding_ops.cs index 0983ec77..664cdd51 100644 --- a/src/TensorFlowNET.Core/Operations/embedding_ops.cs +++ b/src/TensorFlowNET.Core/Operations/embedding_ops.cs @@ -51,10 +51,15 @@ namespace Tensorflow ids = ops.convert_to_tensor(ids, name: "ids"); if (np == 1) { - + ops.colocate_with(@params[0]); + var result = _clip(array_ops.gather(@params[0], ids, name: name), ids, max_norm); + return array_ops.identity(result); + } + else + { + // Flatten the ids. There are two cases where we need to do this. + throw new NotImplementedException("_embedding_lookup_and_transform"); } - return array_ops.identity(null); - throw new NotImplementedException("_embedding_lookup_and_transform"); }); } diff --git a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs index 7050890e..c5a86d31 100644 --- a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs +++ b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs @@ -20,7 +20,7 @@ namespace Tensorflow float epsilon = 1e-12f, string name = null) { - return with(ops.name_scope(name, "", new { x }), scope => + return with(ops.name_scope(name, "l2_normalize", new { x }), scope => { x = ops.convert_to_tensor(x, name: "x"); var square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims: true);