diff --git a/README.md b/README.md
index a80191a7..8744ba72 100644
--- a/README.md
+++ b/README.md
@@ -9,7 +9,7 @@
 [![Documentation Status](https://readthedocs.org/projects/tensorflownet/badge/?version=latest)](https://tensorflownet.readthedocs.io/en/latest/?badge=latest)
 [![Badge](https://img.shields.io/badge/link-996.icu-red.svg)](https://996.icu/#/en_US)

-TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp). 
+TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp).

 ![tensors_flowing](docs/assets/tensors_flowing.gif)

diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs
index c70af1fd..34f227fc 100644
--- a/src/TensorFlowNET.Core/Binding.Util.cs
+++ b/src/TensorFlowNET.Core/Binding.Util.cs
@@ -30,6 +30,20 @@ namespace Tensorflow
     /// </summary>
     public static partial class Binding
     {
+        public static T2 get<T1, T2>(this Dictionary<T1, T2> dict, T1 key)
+            => key == null ?
+            default(T2) :
+            (dict.ContainsKey(key) ? dict[key] : default(T2));
+
+        public static void add<T>(this IList<T> list, T element)
+            => list.Add(element);
+
+        public static void append<T>(this IList<T> list, T element)
+            => list.Add(element);
+
+        public static void extend<T>(this List<T> list, IEnumerable<T> elements)
+            => list.AddRange(elements);
+
         private static string _tostring(object obj)
         {
             switch (obj)
@@ -81,6 +95,9 @@ namespace Tensorflow
             throw new NotImplementedException("len() not implemented for type: " + a.GetType());
         }

+        public static T[] list<T>(IEnumerable<T> list)
+            => list.ToArray();
+
         public static IEnumerable<int> range(int end)
         {
             return Enumerable.Range(0, end);
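These helpers exist so code ported from the Python sources keeps its original shape line by line. A standalone illustration of how they compose (not part of the diff; `int` stands in for `Tensor` so the snippet is self-contained):

```csharp
using System.Collections.Generic;
using static Tensorflow.Binding;

static class BindingHelpersDemo
{
    static void Demo()
    {
        var counts = new Dictionary<string, int>();
        var n = counts.get("missing");     // null-safe lookup: default(int) == 0, no KeyNotFoundException

        var steps = new List<int>();
        steps.append(1);                   // Python list.append -> List<T>.Add
        steps.extend(new[] { 2, 3 });      // Python list.extend -> AddRange
        var snapshot = list(range(4));     // Python list(range(4)) -> new[] { 0, 1, 2, 3 }
    }
}
```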
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs
index d1be6f31..1d296774 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs
@@ -109,11 +109,12 @@ namespace Tensorflow.Operations.ControlFlows
             grad_state.grad_context.Enter();
         }

-        // def ExitGradWhileContext(self, op, before):
-        //   """Exit the WhileContext for gradient computation."""
-        //   grad_state = self.GetGradState(op, before)
-        //   if grad_state:
-        //     grad_state.grad_context.Exit()
+        public void ExitGradWhileContext(Operation op, bool before)
+        {
+            var grad_state = GetGradState(op, before);
+            if (grad_state != null)
+                grad_state.grad_context.Exit();
+        }

         // def AddWhileContext(self, op, between_op_list, between_ops):
         //   """Add the grad state for the while loop that op belongs to.
@@ -287,51 +288,9 @@ namespace Tensorflow.Operations.ControlFlows
             return result;
         }

-        // def PostProcessing(self):
-        //   """Perform postprocessing at the end of gradients().
-
-        //   We have created the gradient graph at this point. So this function
-        //   can be used to perform any postprocessing on the gradient graph.
-        //   We currently perform the following postprocessing:
-        //   1. Patch the gradient graph if the output of a loop variable
-        //      doesn't depend on its input.
-        //   """
-        //   for _, grad_state in self._map.items():
-        //     for _, b_merge in grad_state.switch_map.items():
-        //       if b_merge.op.inputs[0] == b_merge.op.inputs[1]:
-        //         # The value of this loop variable at iteration i+1 doesn't
-        //         # depend on its value at iteration i. So use zeros as the
-        //         # gradients for all iterations > 0.
-        //         dtype = b_merge.op.inputs[0].dtype
-        //         shape = b_merge.op.inputs[0].get_shape()
-        //         # pylint: disable=protected-access
-        //         if shape.is_fully_defined():
-        //           grad_state.grad_context.Enter()
-        //           # Create a zeros and use it for iterations > 0.
-        //           grad_val = constant_op.constant(0, dtype=dtype, shape=shape)
-        //           next_grad_val = _NextIteration(grad_val)
-        //           grad_state.grad_context.Exit()
-        //         else:
-        //           # Create a zeros in the outer grad context.
-        //           outer_grad_ctxt = grad_state.grad_context.outer_context
-        //           if outer_grad_ctxt:
-        //             outer_grad_ctxt.Enter()
-        //           enter_grad_op = b_merge.op.inputs[0].op
-        //           enter_grad = enter_grad_op.inputs[0]
-        //           grad_shape = array_ops.shape_internal(enter_grad, optimize=False)
-        //           grad_val = array_ops.zeros(grad_shape)
-        //           if outer_grad_ctxt:
-        //             outer_grad_ctxt.Exit()
-        //           # Use the zeros for iterations > 0.
-        //           grad_state.grad_context.Enter()
-        //           next_grad_val = _NextIteration(grad_val)
-        //           grad_state.grad_context.Exit()
-        //         b_merge.op._update_input(1, next_grad_val)
-        //       # pylint: enable=protected-access
-
+        public void PostProcessing()
+        {
+            throw new NotImplementedException("PostProcessing");
+        }
     }
-
-
-
 }
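`PostProcessing` is stubbed out for now; the Python block deleted above is the reference for what it must eventually do. A fairly direct C# port could look like the sketch below — untested, and it assumes `_map`, `switch_map`, `Operation._update_input`, `Tensor.GetShape()`, and a ported `control_flow_ops._NextIteration` helper with the shapes the rest of this PR already uses:

```csharp
// Sketch only: patch the gradient graph when a loop variable's output
// does not depend on its input (use zeros as gradients for iterations > 0).
public void PostProcessing()
{
    foreach (var grad_state in _map.Values)            // assumed: context -> GradLoopState map
    {
        foreach (var b_merge in grad_state.switch_map.Values)
        {
            if (b_merge.op.inputs[0] != b_merge.op.inputs[1])
                continue;

            var inp = b_merge.op.inputs[0];
            Tensor next_grad_val;
            if (inp.GetShape().is_fully_defined())
            {
                // Shape known: build zeros inside the grad context.
                grad_state.grad_context.Enter();
                var grad_val = constant_op.constant(0, dtype: inp.dtype, shape: inp.GetShape());
                next_grad_val = control_flow_ops._NextIteration(grad_val);   // assumed helper
                grad_state.grad_context.Exit();
            }
            else
            {
                // Shape unknown: derive zeros from the Enter input in the outer grad context.
                var outer = grad_state.grad_context.outer_context;
                outer?.Enter();
                var enter_grad = b_merge.op.inputs[0].op.inputs[0];
                var grad_val = array_ops.zeros(array_ops.shape(enter_grad));
                outer?.Exit();
                grad_state.grad_context.Enter();
                next_grad_val = control_flow_ops._NextIteration(grad_val);
                grad_state.grad_context.Exit();
            }
            b_merge.op._update_input(1, next_grad_val);
        }
    }
}
```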
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs
index e17ab8ba..143aacb1 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs
@@ -17,7 +17,9 @@
 using System;
 using System.Collections;
 using System.Collections.Generic;
+using System.Linq;
 using static Tensorflow.Binding;
+using util = Tensorflow.control_flow_util;

 namespace Tensorflow.Operations.ControlFlows
 {
@@ -56,6 +58,7 @@ namespace Tensorflow.Operations.ControlFlows
         public GradLoopState outer_grad_state => _outer_grad_state;

         Tensor _forward_index;
+        public Tensor forward_index => _forward_index;
         Tensor _grad_index;

         Tensor[] _forward_loop_exits;
@@ -152,63 +155,52 @@ namespace Tensorflow.Operations.ControlFlows
         /// <returns>The stack that contains the accumulated history of the tensor.</returns>
         public Tensor AddForwardAccumulator(Tensor value, bool dead_branch = false)
         {
-            throw new NotImplementedException("AddForwardAccumulator");
-            //      # curr_ctxt is the context that tf.gradients was called in.
-            //      with self._forward_index.graph.as_default():
-            //        curr_ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
-            //        with ops.control_dependencies(None):
-            //          if curr_ctxt:
-            //            curr_ctxt.Enter()
-            //          with ops.colocate_with(value):
-            //            # We only need to pass maximum_iterations to the stack if
-            //            # we're inside an XLA context.
-            //            if not util.IsInXLAContext(value.op):
-            //              max_size = constant_op.constant(-1, dtypes.int32)
-            //            else:
-            //              max_size = GetMaxSizeFromNestedMaximumIterations(
-            //                  value, self.forward_context)
-            //            acc = gen_data_flow_ops.stack_v2(
-            //                max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc")
-            //          if curr_ctxt:
-            //            curr_ctxt.Exit()
-
-            //          # Make acc available in the forward context.
-            //          enter_acc = self.forward_context.AddValue(acc)
-
-            //          # Add the stack_push op in the context of value.op.
-            //          swap_enabled = self.forward_context.swap_memory
-            //          value_ctxt = util.GetOutputContext(value.op)
-            //          if value_ctxt == self.forward_context:
-            //            # value is not nested in the forward context.
-            //            self.forward_context.Enter()
-            //            push = gen_data_flow_ops.stack_push_v2(
-            //                enter_acc, value, swap_memory=swap_enabled)
-            //            self.forward_context.Exit()
-            //            # Protect stack push and order it before forward_index.
-            //            self.forward_index.op._add_control_input(push.op)
-            //          else:
-            //            # value is in a cond context within the forward context.
-            //            if not isinstance(value_ctxt, CondContext):
-            //              raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt)
-            //            if dead_branch:
-            //              # The special case for creating a zero tensor for a dead
-            //              # branch of a switch. See ControlFlowState.ZerosLike().
-            //              value_ctxt.outer_context.Enter()
-            //              push = gen_data_flow_ops.stack_push_v2(
-            //                  enter_acc, value, swap_memory=swap_enabled)
-            //              value_ctxt.outer_context.Exit()
-            //              push.op._set_control_flow_context(value_ctxt)
-            //            else:
-            //              value_ctxt.Enter()
-            //              push = gen_data_flow_ops.stack_push_v2(
-            //                  enter_acc, value, swap_memory=swap_enabled)
-            //              value_ctxt.Exit()
-            //            # Protect stack push and order it before forward_sync.
-            //            self.forward_sync._add_control_input(push.op)
-            //          # Order stack push after the successor of forward_index
-            //          add_op = self.forward_index.op.inputs[0].op
-            //          push.op._add_control_input(add_op)
-            //          return acc
+            using (_forward_index.graph.as_default())
+            {
+                var curr_ctxt = ops.get_default_graph()._get_control_flow_context();
+                return tf_with(ops.control_dependencies(null), delegate
+                {
+                    Tensor acc = null;
+                    Tensor push = null;
+                    if (curr_ctxt != null)
+                        curr_ctxt.Enter();
+                    ops.colocate_with(value);
+                    {
+                        // We only need to pass maximum_iterations to the stack if
+                        // we're inside an XLA context.
+                        var max_size = constant_op.constant(-1, dtypes.int32);
+                        acc = gen_data_flow_ops.stack_v2(
+                            max_size: max_size, elem_type: value.dtype.as_base_dtype(), name: "f_acc");
+                    }
+                    if (curr_ctxt != null)
+                        curr_ctxt.Exit();
+
+                    // Make acc available in the forward context.
+                    var enter_acc = forward_context.AddValue(acc);
+
+                    // Add the stack_push op in the context of value.op.
+                    var swap_enabled = forward_context.swap_memory;
+                    var value_ctxt = util.GetOutputContext(value.op);
+                    if (value_ctxt == forward_context)
+                    {
+                        // value is not nested in the forward context.
+                        forward_context.Enter();
+                        push = gen_data_flow_ops.stack_push_v2(enter_acc, value, swap_memory: swap_enabled);
+                        forward_context.Exit();
+                        // Protect stack push and order it before forward_index.
+                        forward_index.op._add_control_input(push.op);
+                    }
+                    else
+                    {
+                        throw new NotImplementedException("AddForwardAccumulator");
+                    }
+
+                    // Order stack push after the successor of forward_index.
+                    var add_op = forward_index.op.inputs[0].op;
+                    push.op._add_control_input(add_op);
+                    return acc;
+                });
+            }
         }

         // """Add the getter for an accumulated value in the grad context.
         //
         //  Args:
         //    history_value: The history (a stack) of a value.
         //    value: The value that is pushed onto the stack.
         //    dead_branch: True iff the tensor is on a dead branch of a cond.
         //
         //  Returns:
         //    The current value (the top of the stack).
         //  """
+
         public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bool dead_branch = false)
         {
             throw new NotImplementedException();

             // return pop
         }

-        // def GetRealValue(self, value):
-        //   """Get the real value of `value`.
-
-        //   If backprop "uses" a value produced by forward inference, an accumulator
-        //   is added in the forward loop to accumulate its values. We use the
-        //   accumulated value. This method must be called in the grad loop context.
-        //   `value` must be in forward and needed for backprop.
-
-        //   Args:
-        //     value: A tensor to be captured.
-
-        //   Returns:
-        //     The same tensor obtained from the saved history.
-        //   """
-        //   assert value.op.type not in ["Variable", "VariableV2"]
-        //   real_value = self._history_map.get(value.name)
-        //   if real_value is None:
-        //     cur_value = value
-        //     cur_grad_state = self
-        //     while True:
-        //       enter_op = util.GetLoopConstantEnter(cur_value)
-        //       if enter_op:
-        //         # Special case: cur_value comes from a constant Enter node.
-        //         cur_value = enter_op.inputs[0]
-        //         cur_grad_state = cur_grad_state.outer_grad_state
-        //         if cur_grad_state is None:
-        //           # We are now outside all nested loops for this gradient(),
-        //           # so `value` is a loop invariant and there is no need to
-        //           # save the history of value. Just make cur_value to enter
-        //           # the right control flow context.
-        //           real_value = self._grad_context.AddValue(cur_value)
-        //           break
-        //       elif constant_op.is_constant(cur_value):
-        //         # If the value to be forwarded is a constant, clone the constant in
-        //         # the gradient loop rather than using a stack.
-        //         real_value = constant_op.constant(
-        //             tensor_util.constant_value(cur_value), dtype=cur_value.dtype)
-        //         break
-        //       else:
-        //         # Record the history of this value in forward_ctxt.
-        //         self._grad_context.Exit()
-        //         history_value = cur_grad_state.AddForwardAccumulator(cur_value)
-        //         self._grad_context.Enter()
-        //         break
-
-        //     if real_value is None:
-        //       # Add the stack pop op in the grad context.
-        //       real_value = cur_grad_state.AddBackpropAccumulatedValue(
-        //           history_value, cur_value)
-        //       if cur_grad_state != self:
-        //         real_value = self._grad_context.AddValue(real_value)
-        //     self._history_map[value.name] = real_value
-        //   return real_value
-
+        /// <summary>
+        /// Get the real value of `value`.
+        /// </summary>
+        /// <param name="value">A tensor to be captured.</param>
+        /// <returns>The same tensor obtained from the saved history.</returns>
+        public Tensor GetRealValue(Tensor value)
+        {
+            Tensor real_value = _history_map.get(value.name);
+            if (real_value == null)
+            {
+                var cur_value = value;
+                var cur_grad_state = this;
+                Tensor history_value = null;
+                while (true)
+                {
+                    var enter_op = util.GetLoopConstantEnter(cur_value);
+                    if (enter_op != null)
+                    {
+                        throw new NotImplementedException("GetRealValue");
+                    }
+                    else if (constant_op.is_constant(cur_value))
+                    {
+                        throw new NotImplementedException("GetRealValue");
+                    }
+                    else
+                    {
+                        // Record the history of this value in forward_ctxt.
+                        _grad_context.Exit();
+                        history_value = cur_grad_state.AddForwardAccumulator(cur_value);
+                        _grad_context.Enter();
+                        break;
+                    }
+                }
+
+                if (real_value == null)
+                {
+                    // Add the stack pop op in the grad context.
+                    real_value = cur_grad_state.AddBackpropAccumulatedValue(history_value, cur_value);
+                    if (cur_grad_state != this)
+                        real_value = _grad_context.AddValue(real_value);
+                }
+                _history_map[value.name] = real_value;
+            }
+            return real_value;
+        }
     }
}
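`AddBackpropAccumulatedValue` is the backward half of the accumulator pair and still throws. The non-cond case is a small port of the Python reference (a sketch, untested; `gen_data_flow_ops.stack_pop_v2` and `Tensor.GetShape()` are assumptions — the former is not added in this diff but would follow the same `_apply_op_helper` pattern as the `stack_push_v2` wrapper added below):

```csharp
// Sketch (non-cond case only): pop the accumulated value inside the grad context.
public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bool dead_branch = false)
{
    return tf_with(ops.control_dependencies(null), delegate
    {
        _grad_context.Enter();
        // Assumed wrapper for the native "StackPopV2" op: (handle, elem_type) -> elem.
        var pop = gen_data_flow_ops.stack_pop_v2(history_value, value.dtype.as_base_dtype());
        pop.set_shape(value.GetShape());  // restore the static shape lost by the stack round-trip
        _grad_context.Exit();
        return pop;
    });
}
```

The cond-guarded branch of the Python original (switching the stack pop on the controlling predicate) is deliberately left out here, matching the incremental style of the rest of this PR.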
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
index 56bcf897..02a5a573 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
@@ -530,10 +530,9 @@ namespace Tensorflow.Operations
             }
             if (forward_ctxt == grad_ctxt.grad_state.forward_context)
             {
-                throw new NotImplementedException("forward_ctxt == grad_ctxt.grad_state.forward_context");
-                /*real_val = grad_ctxt.grad_state.GetRealValue(val);
+                var real_val = grad_ctxt.grad_state.GetRealValue(val);
                 _external_values[val.name] = real_val;
-                return real_val;*/
+                return real_val;
             }
         }
     }
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
index a8a0e0b9..48af7d58 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
@@ -30,7 +30,7 @@ namespace Tensorflow.Operations
             TF_DataType dtype = TF_DataType.DtInvalid,
             int? parallel_iterations = null, bool swap_memory = false, bool time_major = false)
         {
-            tf_with(tf.variable_scope("rnn"), scope =>
+            return tf_with(tf.variable_scope("rnn"), scope =>
             {
                 VariableScope varscope = scope;
                 var flat_input = nest.flatten(inputs_tensor);
@@ -64,9 +64,12 @@ namespace Tensorflow.Operations
                     swap_memory: swap_memory,
                     sequence_length: sequence_length,
                     dtype: dtype);
-            });

-            throw new NotImplementedException("");
+                if (!time_major)
+                    outputs = nest.map_structure(_transpose_batch_time, outputs);
+
+                return (outputs, final_state);
+            });
         }

@@ -210,16 +213,28 @@ namespace Tensorflow.Operations
                 var input_t_t = nest.pack_sequence_as2(structure: inputs, flat_sequence: input_t);
                 // Keras RNN cells only accept state as list, even if it's a single tensor.
                 // var is_keras_rnn_cell = _is_keras_rnn_cell(cell);
-                (Tensor, Tensor) a = (null, null);
+                Tensor[] outputs = null;
                 if (sequence_length != null)
                     throw new NotImplementedException("sequence_length != null");
                 else
-                    a = cell.__call__(input_t_t, state: state1);
+                    outputs = cell.__call__(input_t_t, state: state1);
+
+                var (output, new_state) = (outputs[0], outputs[1]);

+                // Keras cells always wrap state as list, even if it's a single tensor.
+                // if(is_keras_rnn_cell && len(new_state)) == 1
+                // Pack state if using state tuples
+                outputs = nest.flatten2(output).Select(x => x as Tensor).ToArray();

-                return item;
+                output_ta_t = zip(output_ta_t, outputs).Select(x =>
+                {
+                    var (ta, @out) = (x.Item1, x.Item2);
+                    return ta.write(item.time, @out);
+                }).ToArray();
+
+                return new BodyItemInRnnWhileLoop(item.time + 1, output_ta_t, new_state);
             };

-            control_flow_ops.while_loop(
+            var while_loop_result = control_flow_ops.while_loop(
                 cond: cond,
                 body: _time_step,
                 loop_vars: new BodyItemInRnnWhileLoop(time, output_ta.ToArray(), state),
@@ -227,7 +242,18 @@ namespace Tensorflow.Operations
                 maximum_iterations: time_steps,
                 swap_memory: swap_memory);

-            throw new NotImplementedException("");
+            (_, TensorArray[] output_final_ta, Tensor final_state) = (while_loop_result.time, while_loop_result.output_ta_t, while_loop_result.state);
+
+            // Unpack final output if not using output tuples.
+            var final_outputs = output_final_ta.Select(ta => ta.stack()).ToArray();
+
+            // Restore some shape information.
+            foreach (var (output, output_size) in zip(final_outputs, flat_output_size))
+            {
+                var shape = rnn_cell_impl._concat(new[] { const_time_steps, const_batch_size }, output_size, @static: true);
+                output.set_shape(shape);
+            }
+
+            return (final_outputs[0], final_state);
         }

         private static TensorShape _maybe_tensor_shape_from_tensor(Tensor shape)
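With `_dynamic_rnn` now returning `(outputs, final_state)` and `GetRealValue` wired into `WhileContext`, the scenario this PR targets is backprop through a dynamic RNN. Roughly like this — illustrative only; the cell type and `tf.nn.dynamic_rnn` surface are assumed from the TF.NET API of this vintage, not taken from the diff:

```csharp
// Illustrative end-to-end flow this PR enables.
int n_steps = 28, n_inputs = 28;
var X = tf.placeholder(tf.float32, shape: new TensorShape(-1, n_steps, n_inputs));
var cell = new BasicRnnCell(num_units: 128);                            // name assumed
var (outputs, state) = tf.nn.dynamic_rnn(cell, X, dtype: tf.float32);   // shape assumed
// Taking gradients forces TanhGrad etc. to be built inside the while-loop
// gradient context, exercising GetRealValue / AddForwardAccumulator above.
var grads = tf.gradients(outputs, X);
```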
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs
index 3164ba14..cf5f1ce0 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs
@@ -53,5 +53,34 @@ namespace Tensorflow.Operations
                 return array_ops.concat(new[] { p, s }, 0);
             }
         }
+
+        public static TensorShape _concat(int[] prefix, int suffix, bool @static = false)
+        {
+            var p = new TensorShape(prefix);
+            var p_static = prefix;
+            var p_tensor = p.is_fully_defined() ? constant_op.constant(p.as_list(), dtype: dtypes.int32) : null;
+
+            var s_tensor_shape = new TensorShape(suffix);
+            var s_static = s_tensor_shape.ndim > -1 ?
+                s_tensor_shape.dims :
+                null;
+            var s_tensor = s_tensor_shape.is_fully_defined() ?
+                constant_op.constant(s_tensor_shape.dims, dtype: dtypes.int32) :
+                null;
+
+            if (@static)
+            {
+                if (p_static is null) return null;
+                var shape = new TensorShape(p_static).concatenate(s_static);
+                return shape;
+            }
+            else
+            {
+                if (p is null || s_tensor is null)
+                    throw new ValueError($"Provided a prefix or suffix of None: {prefix} and {suffix}");
+                // return array_ops.concat(new[] { p_tensor, s_tensor }, 0);
+                throw new NotImplementedException("");
+            }
+        }
     }
}
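The `@static: true` branch is the only one `_dynamic_rnn_loop` needs: it splices a known `[time, batch]` prefix with the cell's output size so shape information can be restored on the stacked outputs. For example (illustrative values):

```csharp
// [time_steps, batch_size] + num_units -> [time_steps, batch_size, num_units]
var shape = rnn_cell_impl._concat(new[] { 20, 32 }, 128, @static: true);
// shape.dims is { 20, 32, 128 }; unknown dims (-1) should pass through unchanged.
// The dynamic branch (concatenating p_tensor and s_tensor) is still unimplemented above.
```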
diff --git a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs
index 9f0cb9a5..c9ae7071 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.Control.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs
@@ -59,5 +59,10 @@ namespace Tensorflow
         {
             return _control_flow_context;
         }
+
+        public WhileContext GetWhileContext()
+        {
+            return _control_flow_context as WhileContext;
+        }
     }
 }
diff --git a/src/TensorFlowNET.Core/Operations/Operation.Instance.cs b/src/TensorFlowNET.Core/Operations/Operation.Instance.cs
index 6f6c8226..e39a34a3 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.Instance.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.Instance.cs
@@ -15,17 +15,14 @@
  ******************************************************************************/

 using System;
+using System.Linq;
 using System.Collections.Generic;
+using static Tensorflow.Binding;

 namespace Tensorflow
 {
     public partial class Operation
     {
-        // cache the mapping between managed and unmanaged op
-        // some data is stored in managed instance, so when
-        // creating an Operation by IntPtr, it would lose some data.
-        private static Dictionary<IntPtr, Operation> OpInstances = new Dictionary<IntPtr, Operation>();
-
         /// <summary>
         /// Get operation by handle
         /// </summary>
         /// <param name="handle"></param>
         /// <returns></returns>
         public Operation GetOperation(IntPtr handle)
         {
-            return OpInstances.ContainsKey(handle) ?
-                OpInstances[handle] :
-                new Operation(handle);
+            var nodes = tf.get_default_graph()._nodes_by_name;
+            foreach (var node in nodes.Values)
+            {
+                if (node is Operation op)
+                {
+                    if (op == handle)
+                        return op;
+                }
+            }
+
+            return null;
         }
     }
}
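Dropping the static `OpInstances` cache trades O(1) lookup for a linear scan of `_nodes_by_name`, but it removes a process-wide map that could hand back a stale managed wrapper for a reused native handle. If the scan ever shows up in profiles, a per-graph cache would restore O(1) without reintroducing the global state — a hypothetical sketch (names are not from this diff):

```csharp
// Hypothetical per-graph cache: one managed wrapper per native handle,
// scoped to the owning graph instead of a static dictionary.
public partial class Graph
{
    private readonly Dictionary<IntPtr, Operation> _opsByHandle
        = new Dictionary<IntPtr, Operation>();

    public Operation FindOperation(IntPtr handle)
    {
        return _opsByHandle.TryGetValue(handle, out var op)
            ? op
            : null; // caller falls back to scanning _nodes_by_name
    }
}
```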
diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs
index d5068f2e..e8eb216f 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.cs
@@ -106,7 +106,6 @@ namespace Tensorflow
             _control_flow_context = _graph._get_control_flow_context();

             // Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor.
-            OpInstances[_handle] = this;
         }

         /*public Operation(Graph g, string opType, string oper_name)
@@ -183,10 +182,9 @@ namespace Tensorflow
             // This will be set by self.inputs.
             if (op_def == null)
                 op_def = g.GetOpDef(node_def.Op);

             var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
             _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());
-
             _is_stateful = op_def.IsStateful;

             // Initialize self._outputs.
@@ -202,8 +200,6 @@ namespace Tensorflow

             if (_handle != IntPtr.Zero)
                 _control_flow_post_processing();
-
-            OpInstances[_handle] = this;
         }

         public void run(FeedItem[] feed_dict = null, Session session = null)
diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
index 01231035..cea3e440 100644
--- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
@@ -183,7 +183,7 @@ namespace Tensorflow
         {
             var _op = _op_def_lib._apply_op_helper("Identity", name, new { input });

-            return _op.outputs[0];
+            return _op.output;
         }

         public static Tensor invert_permutation(Tensor x, string name = null)
diff --git a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs
index 5f0ceb48..8f9c8120 100644
--- a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs
@@ -14,6 +14,8 @@
    limitations under the License.
 ******************************************************************************/

+using Tensorflow.Operations;
+
 namespace Tensorflow
 {
     public class gen_control_flow_ops
@@ -148,18 +150,18 @@ namespace Tensorflow
             return new[] { _op.outputs[0], _op.outputs[1] };
         }

-        public static Tensor[] ref_merge(Tensor[] inputs, string name = null)
+        public static MergeOutput ref_merge(Tensor[] inputs, string name = null)
         {
             var _op = _op_def_lib._apply_op_helper("RefMerge", name, new { inputs });

-            return _op.outputs;
+            return new MergeOutput(_op.outputs);
         }

-        public static Tensor[] merge(Tensor[] inputs, string name = null)
+        public static MergeOutput merge(Tensor[] inputs, string name = null)
         {
             var _op = _op_def_lib._apply_op_helper("Merge", name, new { inputs });

-            return _op.outputs;
+            return new MergeOutput(_op.outputs);
         }
     }
}
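Returning `MergeOutput` instead of a raw `Tensor[]` makes the two results of the native `Merge` op explicit: the forwarded tensor, and the int32 index of the branch that produced it. The type lives in `Tensorflow.Operations` (hence the new using); its shape is presumably along these lines (a sketch of the pair, not the actual declaration):

```csharp
// Sketch: what a Merge result carries, mirroring TF's (output, value_index) pair.
public class MergeOutput
{
    public Tensor output { get; }       // the tensor forwarded by Merge
    public Tensor value_index { get; }  // int32 scalar: which input was taken

    public MergeOutput(Tensor[] values)
    {
        output = values[0];
        value_index = values[1];
    }
}
```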
diff --git a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs
index 52b0a372..fcb1000f 100644
--- a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs
@@ -259,5 +259,31 @@ namespace Tensorflow

             return _op.output;
         }
+
+        public static Tensor stack_v2(Tensor max_size, TF_DataType elem_type, string stack_name = "",
+            string name = null)
+        {
+            var _op = _op_def_lib._apply_op_helper("StackV2", name, new
+            {
+                max_size,
+                elem_type,
+                stack_name
+            });
+
+            return _op.output;
+        }
+
+        public static Tensor stack_push_v2(Tensor handle, Tensor elem, bool swap_memory = false,
+            string name = null)
+        {
+            var _op = _op_def_lib._apply_op_helper("StackPushV2", name, new
+            {
+                handle,
+                elem,
+                swap_memory
+            });
+
+            return _op.output;
+        }
     }
}
diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
index 08431089..7e54349f 100644
--- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
@@ -282,7 +282,7 @@ namespace Tensorflow
         /// <param name="y"></param>
         /// <param name="dy"></param>
         /// <param name="name"></param>
-        public static Tensor tanh_grad(Tensor y, Tensor dy, string name = "TanhGrad")
+        public static Tensor tanh_grad(Tensor y, Tensor dy, string name = null)
             => _op_def_lib._apply_op_helper("TanhGrad", name: name, args: new { y, dy }).output;

         public static Tensor floor(Tensor x, string name = null)
diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs
index 54149fe1..7dbacea0 100644
--- a/src/TensorFlowNET.Core/Util/nest.py.cs
+++ b/src/TensorFlowNET.Core/Util/nest.py.cs
@@ -526,6 +526,14 @@ namespace Tensorflow.Util
             return pack_sequence_as(structure, mapped_flat_structure) as Tensor;
         }

+        public static Tensor map_structure2<T>(Func<Tensor, Tensor> func, T structure)
+        {
+            var flat_structure = flatten(structure);
+            var mapped_flat_structure = flat_structure.Select(func).ToList();
+
+            return pack_sequence_as(structure, mapped_flat_structure) as Tensor;
+        }
+
         /// <summary>
         /// Same as map_structure, but with only one structure (no combining of multiple structures)
         /// </summary>
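Taken together, `stack_v2` and `stack_push_v2` are the forward half of the history mechanism used by `GradLoopState.AddForwardAccumulator`: one `StackV2` per captured tensor, one push per loop iteration, popped in reverse order during backprop. In outline (illustrative; `v` is any tensor captured from the forward loop, and a `stack_pop_v2` wrapper is not part of this diff):

```csharp
// Forward pass, per captured tensor `v`:
var max_size = constant_op.constant(-1, dtypes.int32);                   // -1 => unbounded stack
var acc = gen_data_flow_ops.stack_v2(max_size, v.dtype.as_base_dtype(), name: "f_acc");
var push = gen_data_flow_ops.stack_push_v2(acc, v, swap_memory: false);  // one push per iteration

// Backward pass (future work): each gradient iteration pops the matching value,
// e.g. gen_data_flow_ops.stack_pop_v2(acc, v.dtype.as_base_dtype());
```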