diff --git a/README.md b/README.md
index a80191a7..8744ba72 100644
--- a/README.md
+++ b/README.md
@@ -9,7 +9,7 @@
[](https://tensorflownet.readthedocs.io/en/latest/?badge=latest)
[](https://996.icu/#/en_US)
-TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp).
+TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp).

diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs
index c70af1fd..34f227fc 100644
--- a/src/TensorFlowNET.Core/Binding.Util.cs
+++ b/src/TensorFlowNET.Core/Binding.Util.cs
@@ -30,6 +30,20 @@ namespace Tensorflow
///
public static partial class Binding
{
+ public static T2 get(this Dictionary dict, T1 key)
+ => key == null ?
+ default(T2) :
+ (dict.ContainsKey(key) ? dict[key] : default(T2));
+
+ public static void add(this IList list, T element)
+ => list.Add(element);
+
+ public static void append(this IList list, T element)
+ => list.Add(element);
+
+ public static void extend(this List list, IEnumerable elements)
+ => list.AddRange(elements);
+
private static string _tostring(object obj)
{
switch (obj)
@@ -81,6 +95,9 @@ namespace Tensorflow
throw new NotImplementedException("len() not implemented for type: " + a.GetType());
}
+ public static T[] list(IEnumerable list)
+ => list.ToArray();
+
public static IEnumerable range(int end)
{
return Enumerable.Range(0, end);
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs
index d1be6f31..1d296774 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs
@@ -109,11 +109,12 @@ namespace Tensorflow.Operations.ControlFlows
grad_state.grad_context.Enter();
}
- // def ExitGradWhileContext(self, op, before):
- // """Exit the WhileContext for gradient computation."""
- // grad_state = self.GetGradState(op, before)
- // if grad_state:
- // grad_state.grad_context.Exit()
+ public void ExitGradWhileContext(Operation op, bool before)
+ {
+ var grad_state = GetGradState(op, before);
+ if (grad_state != null)
+ grad_state.grad_context.Exit();
+ }
// def AddWhileContext(self, op, between_op_list, between_ops):
// """Add the grad state for the while loop that op belongs to.
@@ -287,51 +288,9 @@ namespace Tensorflow.Operations.ControlFlows
return result;
}
- // def PostProcessing(self):
- // """Perform postprocessing at the end of gradients().
-
- // We have created the gradient graph at this point. So this function
- // can be used to perform any postprocessing on the gradient graph.
- // We currently perform the following postprocessing:
- // 1. Patch the gradient graph if the output of a loop variable
- // doesn't depend on its input.
- // """
- // for _, grad_state in self._map.items():
- // for _, b_merge in grad_state.switch_map.items():
- // if b_merge.op.inputs[0] == b_merge.op.inputs[1]:
- // # The value of this loop variable at iteration i+1 doesn't
- // # depend on its value at iteration i. So use zeros as the
- // # gradients for all iterations > 0.
- // dtype = b_merge.op.inputs[0].dtype
- // shape = b_merge.op.inputs[0].get_shape()
- // # pylint: disable=protected-access
- // if shape.is_fully_defined():
- // grad_state.grad_context.Enter()
- // # Create a zeros and use it for iterations > 0.
- // grad_val = constant_op.constant(0, dtype=dtype, shape=shape)
- // next_grad_val = _NextIteration(grad_val)
- // grad_state.grad_context.Exit()
- // else:
- // # Create a zeros in the outer grad context.
- // outer_grad_ctxt = grad_state.grad_context.outer_context
- // if outer_grad_ctxt:
- // outer_grad_ctxt.Enter()
- // enter_grad_op = b_merge.op.inputs[0].op
- // enter_grad = enter_grad_op.inputs[0]
- // grad_shape = array_ops.shape_internal(enter_grad, optimize=False)
- // grad_val = array_ops.zeros(grad_shape)
- // if outer_grad_ctxt:
- // outer_grad_ctxt.Exit()
- // # Use the zeros for iterations > 0.
- // grad_state.grad_context.Enter()
- // next_grad_val = _NextIteration(grad_val)
- // grad_state.grad_context.Exit()
- // b_merge.op._update_input(1, next_grad_val)
- // # pylint: enable=protected-access
-
+ public void PostProcessing()
+ {
+ throw new NotImplementedException("PostProcessing");
+ }
}
-
-
-
-
}
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs
index e17ab8ba..143aacb1 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs
@@ -17,7 +17,9 @@
using System;
using System.Collections;
using System.Collections.Generic;
+using System.Linq;
using static Tensorflow.Binding;
+using util = Tensorflow.control_flow_util;
namespace Tensorflow.Operations.ControlFlows
{
@@ -56,6 +58,7 @@ namespace Tensorflow.Operations.ControlFlows
public GradLoopState outer_grad_state => _outer_grad_state;
Tensor _forward_index;
+ public Tensor forward_index => _forward_index;
Tensor _grad_index;
Tensor[] _forward_loop_exits;
@@ -152,63 +155,52 @@ namespace Tensorflow.Operations.ControlFlows
/// The stack that contains the accumulated history of the tensor.
public Tensor AddForwardAccumulator(Tensor value, bool dead_branch = false)
{
- throw new NotImplementedException("AddForwardAccumulator");
- // # curr_ctxt is the context that tf.gradients was called in.
- // with self._forward_index.graph.as_default():
- // curr_ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access
- // with ops.control_dependencies(None):
- // if curr_ctxt:
- // curr_ctxt.Enter()
- // with ops.colocate_with(value):
- // # We only need to pass maximum_iterations to the stack if
- // # we're inside an XLA context.
- // if not util.IsInXLAContext(value.op):
- // max_size = constant_op.constant(-1, dtypes.int32)
- // else:
- // max_size = GetMaxSizeFromNestedMaximumIterations(
- // value, self.forward_context)
- // acc = gen_data_flow_ops.stack_v2(
- // max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc")
- // if curr_ctxt:
- // curr_ctxt.Exit()
-
- // # Make acc available in the forward context.
- // enter_acc = self.forward_context.AddValue(acc)
-
- // # Add the stack_push op in the context of value.op.
- // swap_enabled = self.forward_context.swap_memory
- // value_ctxt = util.GetOutputContext(value.op)
- // if value_ctxt == self.forward_context:
- // # value is not nested in the forward context.
- // self.forward_context.Enter()
- // push = gen_data_flow_ops.stack_push_v2(
- // enter_acc, value, swap_memory=swap_enabled)
- // self.forward_context.Exit()
- // # Protect stack push and order it before forward_index.
- // self.forward_index.op._add_control_input(push.op)
- // else:
- // # value is in a cond context within the forward context.
- // if not isinstance(value_ctxt, CondContext):
- // raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt)
- // if dead_branch:
- // # The special case for creating a zero tensor for a dead
- // # branch of a switch. See ControlFlowState.ZerosLike().
- // value_ctxt.outer_context.Enter()
- // push = gen_data_flow_ops.stack_push_v2(
- // enter_acc, value, swap_memory=swap_enabled)
- // value_ctxt.outer_context.Exit()
- // push.op._set_control_flow_context(value_ctxt)
- // else:
- // value_ctxt.Enter()
- // push = gen_data_flow_ops.stack_push_v2(
- // enter_acc, value, swap_memory=swap_enabled)
- // value_ctxt.Exit()
- // # Protect stack push and order it before forward_sync.
- // self.forward_sync._add_control_input(push.op)
- // # Order stack push after the successor of forward_index
- // add_op = self.forward_index.op.inputs[0].op
- // push.op._add_control_input(add_op)
- // return acc
+ using (_forward_index.graph.as_default())
+ {
+ var curr_ctxt = ops.get_default_graph()._get_control_flow_context();
+ return tf_with(ops.control_dependencies(null), delegate
+ {
+ Tensor acc = null;
+ Tensor push = null;
+ if (curr_ctxt != null)
+ curr_ctxt.Enter();
+ ops.colocate_with(value);
+ {
+ // We only need to pass maximum_iterations to the stack if
+ // we're inside an XLA context.
+ var max_size = constant_op.constant(-1, dtypes.int32);
+ acc = gen_data_flow_ops.stack_v2(
+ max_size: max_size, elem_type: value.dtype.as_base_dtype(), name: "f_acc");
+ }
+ if (curr_ctxt != null)
+ curr_ctxt.Exit();
+
+ // Make acc available in the forward context.
+ var enter_acc = forward_context.AddValue(acc);
+
+ // Add the stack_push op in the context of value.op.
+ var swap_enabled = forward_context.swap_memory;
+ var value_ctxt = util.GetOutputContext(value.op);
+ if(value_ctxt == forward_context)
+ {
+ // value is not nested in the forward context.
+ forward_context.Enter();
+ push = gen_data_flow_ops.stack_push_v2(enter_acc, value, swap_memory: swap_enabled);
+ forward_context.Exit();
+ // Protect stack push and order it before forward_index.
+ forward_index.op._add_control_input(push.op);
+ }
+ else
+ {
+ throw new NotImplementedException("AddForwardAccumulator");
+ }
+
+ // Order stack push after the successor of forward_index
+ var add_op = forward_index.op.inputs[0].op;
+ push.op._add_control_input(add_op);
+ return acc;
+ });
+ }
}
// """Add the getter for an accumulated value in the grad context.
@@ -225,6 +217,7 @@ namespace Tensorflow.Operations.ControlFlows
// Returns:
// The current value (the top of the stack).
// """
+
public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bool dead_branch= false)
{
throw new NotImplementedException();
@@ -261,62 +254,50 @@ namespace Tensorflow.Operations.ControlFlows
// return pop
}
- // def GetRealValue(self, value):
- // """Get the real value of `value`.
-
- // If backprop "uses" a value produced by forward inference, an accumulator
- // is added in the forward loop to accumulate its values. We use the
- // accumulated value. This method must be called in the grad loop context.
- // `value` must be in forward and needed for backprop.
-
- // Args:
- // value: A tensor to be captured.
-
- // Returns:
- // The same tensor obtained from the saved history.
- // """
- // assert value.op.type not in ["Variable", "VariableV2"]
- // real_value = self._history_map.get(value.name)
- // if real_value is None:
- // cur_value = value
- // cur_grad_state = self
- // while True:
- // enter_op = util.GetLoopConstantEnter(cur_value)
- // if enter_op:
- // # Special case: cur_value comes from a constant Enter node.
- // cur_value = enter_op.inputs[0]
- // cur_grad_state = cur_grad_state.outer_grad_state
- // if cur_grad_state is None:
- // # We are now outside all nested loops for this gradient(),
- // # so `value` is a loop invariant and there is no need to
- // # save the history of value. Just make cur_value to enter
- // # the right control flow context.
- // real_value = self._grad_context.AddValue(cur_value)
- // break
- // elif constant_op.is_constant(cur_value):
- // # If the value to be forwarded is a constant, clone the constant in
- // # the gradient loop rather than using a stack.
- // # TODO(phawkins): consider hoisting the constant out of the loop
- // # instead.
- // real_value = constant_op.constant(
- // tensor_util.constant_value(cur_value), dtype=cur_value.dtype)
- // break
- // else:
- // # Record the history of this value in forward_ctxt.
- // self._grad_context.Exit()
- // history_value = cur_grad_state.AddForwardAccumulator(cur_value)
- // self._grad_context.Enter()
- // break
-
- // if real_value is None:
- // # Add the stack pop op in the grad context.
- // real_value = cur_grad_state.AddBackpropAccumulatedValue(
- // history_value, cur_value)
- // if cur_grad_state != self:
- // real_value = self._grad_context.AddValue(real_value)
- // self._history_map[value.name] = real_value
- // return real_value
-
-
+ ///
+ /// Get the real value of `value`.
+ ///
+ /// A tensor to be captured.
+ /// The same tensor obtained from the saved history.
+ public Tensor GetRealValue(Tensor value)
+ {
+ Tensor real_value = null;
+ if(real_value == null)
+ {
+ var cur_value = value;
+ var cur_grad_state = this;
+ Tensor history_value = null;
+ while (true)
+ {
+ var enter_op = util.GetLoopConstantEnter(cur_value);
+ if(enter_op != null)
+ {
+ throw new NotImplementedException("GetRealValue");
+ }
+ else if (constant_op.is_constant(cur_value))
+ {
+ throw new NotImplementedException("GetRealValue");
+ }
+ else
+ {
+ // Record the history of this value in forward_ctxt.
+ _grad_context.Exit();
+ history_value = cur_grad_state.AddForwardAccumulator(cur_value);
+ _grad_context.Enter();
+ break;
+ }
+ }
+
+ if(real_value == null)
+ {
+ // Add the stack pop op in the grad context.
+ real_value = cur_grad_state.AddBackpropAccumulatedValue(history_value, cur_value);
+ if (cur_grad_state != this)
+ real_value = _grad_context.AddValue(real_value);
+ }
+ _history_map[value.name] = real_value;
+ }
+ return real_value;
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
index 56bcf897..02a5a573 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
@@ -530,10 +530,9 @@ namespace Tensorflow.Operations
}
if(forward_ctxt == grad_ctxt.grad_state.forward_context)
{
- throw new NotImplementedException("forward_ctxt == grad_ctxt.grad_state.forward_context");
- /*real_val = grad_ctxt.grad_state.GetRealValue(val);
+ var real_val = grad_ctxt.grad_state.GetRealValue(val);
_external_values[val.name] = real_val;
- return real_val;*/
+ return real_val;
}
}
}
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
index a8a0e0b9..48af7d58 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
@@ -30,7 +30,7 @@ namespace Tensorflow.Operations
TF_DataType dtype = TF_DataType.DtInvalid,
int? parallel_iterations = null, bool swap_memory = false, bool time_major = false)
{
- tf_with(tf.variable_scope("rnn"), scope =>
+ return tf_with(tf.variable_scope("rnn"), scope =>
{
VariableScope varscope = scope;
var flat_input = nest.flatten(inputs_tensor);
@@ -64,9 +64,12 @@ namespace Tensorflow.Operations
swap_memory: swap_memory,
sequence_length: sequence_length,
dtype: dtype);
- });
- throw new NotImplementedException("");
+ if (!time_major)
+ outputs = nest.map_structure(_transpose_batch_time, outputs);
+
+ return (outputs, final_state);
+ });
}
///
@@ -210,16 +213,28 @@ namespace Tensorflow.Operations
var input_t_t = nest.pack_sequence_as2(structure: inputs, flat_sequence: input_t);
// Keras RNN cells only accept state as list, even if it's a single tensor.
// var is_keras_rnn_cell = _is_keras_rnn_cell(cell);
- (Tensor, Tensor) a = (null, null);
+ Tensor[] outputs = null;
if (sequence_length != null)
throw new NotImplementedException("sequence_length != null");
else
- a = cell.__call__(input_t_t, state: state1);
+ outputs = cell.__call__(input_t_t, state: state1);
+
+ var (output, new_state) = (outputs[0], outputs[1]);
+ // Keras cells always wrap state as list, even if it's a single tensor.
+ // if(is_keras_rnn_cell && len(new_state)) == 1
+ // Pack state if using state tuples
+ outputs = nest.flatten2(output).Select(x => x as Tensor).ToArray();
- return item;
+ output_ta_t = zip(output_ta_t, outputs).Select(x =>
+ {
+ var(ta, @out) = (x.Item1, x.Item2);
+ return ta.write(item.time, @out);
+ }).ToArray();
+
+ return new BodyItemInRnnWhileLoop(item.time + 1, output_ta_t, new_state);
};
- control_flow_ops.while_loop(
+ var while_loop_result = control_flow_ops.while_loop(
cond: cond,
body: _time_step,
loop_vars: new BodyItemInRnnWhileLoop(time, output_ta.ToArray(), state),
@@ -227,7 +242,18 @@ namespace Tensorflow.Operations
maximum_iterations: time_steps,
swap_memory: swap_memory);
- throw new NotImplementedException("");
+ (_, TensorArray[] output_final_ta, Tensor final_state) = (while_loop_result.time, while_loop_result.output_ta_t, while_loop_result.state);
+
+ // Unpack final output if not using output tuples.
+ var final_outputs = output_final_ta.Select(ta => ta.stack()).ToArray();
+ // Restore some shape information
+ foreach (var (output, output_size) in zip(final_outputs, flat_output_size))
+ {
+ var shape = rnn_cell_impl._concat(new[] { const_time_steps, const_batch_size }, output_size, @static: true);
+ output.set_shape(shape);
+ }
+
+ return (final_outputs[0], final_state);
}
private static TensorShape _maybe_tensor_shape_from_tensor(Tensor shape)
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs
index 3164ba14..cf5f1ce0 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs
@@ -53,5 +53,34 @@ namespace Tensorflow.Operations
return array_ops.concat(new[] { p, s }, 0);
}
}
+
+ public static TensorShape _concat(int[] prefix, int suffix, bool @static = false)
+ {
+ var p = new TensorShape(prefix);
+ var p_static = prefix;
+ var p_tensor = p.is_fully_defined() ? constant_op.constant(p.as_list(), dtype: dtypes.int32) : null;
+
+ var s_tensor_shape = new TensorShape(suffix);
+ var s_static = s_tensor_shape.ndim > -1 ?
+ s_tensor_shape.dims :
+ null;
+ var s_tensor = s_tensor_shape.is_fully_defined() ?
+ constant_op.constant(s_tensor_shape.dims, dtype: dtypes.int32) :
+ null;
+
+ if (@static)
+ {
+ if (p_static is null) return null;
+ var shape = new TensorShape(p_static).concatenate(s_static);
+ return shape;
+ }
+ else
+ {
+ if (p is null || s_tensor is null)
+ throw new ValueError($"Provided a prefix or suffix of None: {prefix} and {suffix}");
+ // return array_ops.concat(new[] { p_tensor, s_tensor }, 0);
+ throw new NotImplementedException("");
+ }
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs
index 9f0cb9a5..c9ae7071 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.Control.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs
@@ -52,6 +52,10 @@ namespace Tensorflow
public void _set_control_flow_context(ControlFlowContext ctx)
{
+ if (name.Contains("gradients/rnn/while/basic_rnn_cell/Tanh_grad/TanhGrad/f_acc"))
+ {
+
+ }
_control_flow_context = ctx;
}
@@ -59,5 +63,10 @@ namespace Tensorflow
{
return _control_flow_context;
}
+
+ public WhileContext GetWhileContext()
+ {
+ return _control_flow_context as WhileContext;
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Operations/Operation.Instance.cs b/src/TensorFlowNET.Core/Operations/Operation.Instance.cs
index 6f6c8226..e39a34a3 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.Instance.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.Instance.cs
@@ -15,17 +15,14 @@
******************************************************************************/
using System;
+using System.Linq;
using System.Collections.Generic;
+using static Tensorflow.Binding;
namespace Tensorflow
{
public partial class Operation
{
- // cache the mapping between managed and unmanaged op
- // some data is stored in managed instance, so when
- // create Operation by IntPtr, it will lost some data.
- private static Dictionary OpInstances = new Dictionary();
-
///
/// Get operation by handle
///
@@ -33,9 +30,17 @@ namespace Tensorflow
///
public Operation GetOperation(IntPtr handle)
{
- return OpInstances.ContainsKey(handle) ?
- OpInstances[handle] :
- new Operation(handle);
+ var nodes = tf.get_default_graph()._nodes_by_name;
+ foreach(var node in nodes.Values)
+ {
+ if (node is Operation op)
+ {
+ if (op == handle)
+ return op;
+ }
+ }
+
+ return null;
}
}
}
diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs
index d5068f2e..e8eb216f 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.cs
@@ -106,7 +106,6 @@ namespace Tensorflow
_control_flow_context = _graph._get_control_flow_context();
// Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor.
- OpInstances[_handle] = this;
}
/*public Operation(Graph g, string opType, string oper_name)
@@ -183,10 +182,12 @@ namespace Tensorflow
// This will be set by self.inputs.
if (op_def == null)
op_def = g.GetOpDef(node_def.Op);
-
+ if(node_def.Name == "gradients/rnn/while/basic_rnn_cell/Tanh_grad/TanhGrad/f_acc")
+ {
+
+ }
var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
_handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());
-
_is_stateful = op_def.IsStateful;
// Initialize self._outputs.
@@ -202,8 +203,6 @@ namespace Tensorflow
if (_handle != IntPtr.Zero)
_control_flow_post_processing();
-
- OpInstances[_handle] = this;
}
public void run(FeedItem[] feed_dict = null, Session session = null)
diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
index 01231035..cea3e440 100644
--- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
@@ -183,7 +183,7 @@ namespace Tensorflow
{
var _op = _op_def_lib._apply_op_helper("Identity", name, new { input });
- return _op.outputs[0];
+ return _op.output;
}
public static Tensor invert_permutation(Tensor x, string name = null)
diff --git a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs
index 5f0ceb48..8f9c8120 100644
--- a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs
@@ -14,6 +14,8 @@
limitations under the License.
******************************************************************************/
+using Tensorflow.Operations;
+
namespace Tensorflow
{
public class gen_control_flow_ops
@@ -148,18 +150,18 @@ namespace Tensorflow
return new []{_op.outputs[0], _op.outputs[1]};
}
- public static Tensor[] ref_merge(Tensor[] inputs, string name = null)
+ public static MergeOutput ref_merge(Tensor[] inputs, string name = null)
{
var _op = _op_def_lib._apply_op_helper("RefMerge", name, new { inputs });
- return _op.outputs;
+ return new MergeOutput(_op.outputs);
}
- public static Tensor[] merge(Tensor[] inputs, string name = null)
+ public static MergeOutput merge(Tensor[] inputs, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Merge", name, new { inputs });
- return _op.outputs;
+ return new MergeOutput(_op.outputs);
}
}
}
diff --git a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs
index 52b0a372..fcb1000f 100644
--- a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs
@@ -259,5 +259,31 @@ namespace Tensorflow
return _op.output;
}
+
+ public static Tensor stack_v2(Tensor max_size, TF_DataType elem_type, string stack_name = "",
+ string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("StackV2", name, new
+ {
+ max_size,
+ elem_type,
+ stack_name
+ });
+
+ return _op.output;
+ }
+
+ public static Tensor stack_push_v2(Tensor handle, Tensor elem, bool swap_memory = false,
+ string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("StackPushV2", name, new
+ {
+ handle,
+ elem,
+ swap_memory
+ });
+
+ return _op.output;
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
index 08431089..7e54349f 100644
--- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
@@ -282,7 +282,7 @@ namespace Tensorflow
///
///
///
- public static Tensor tanh_grad(Tensor y, Tensor dy, string name = "TanhGrad")
+ public static Tensor tanh_grad(Tensor y, Tensor dy, string name = null)
=> _op_def_lib._apply_op_helper("TanhGrad", name: name, args: new { y, dy }).output;
public static Tensor floor(Tensor x, string name = null)
diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs
index 54149fe1..7dbacea0 100644
--- a/src/TensorFlowNET.Core/Util/nest.py.cs
+++ b/src/TensorFlowNET.Core/Util/nest.py.cs
@@ -526,6 +526,14 @@ namespace Tensorflow.Util
return pack_sequence_as(structure, mapped_flat_structure) as Tensor;
}
+ public static Tensor map_structure2(Func func, T structure)
+ {
+ var flat_structure = flatten(structure);
+ var mapped_flat_structure = flat_structure.Select(func).ToList();
+
+ return pack_sequence_as(structure, mapped_flat_structure) as Tensor;
+ }
+
///
/// Same as map_structure, but with only one structure (no combining of multiple structures)
///