@@ -9,7 +9,7 @@
 [](https://tensorflownet.readthedocs.io/en/latest/?badge=latest)
 [](https://996.icu/#/en_US)
-TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp). <a href="http://scisharpstack.org"><img src="https://github.com/SciSharp/SciSharp/blob/master/art/scisharp_badge.png" width="200" height="200" align="right" /></a>
+TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp).
@@ -30,6 +30,20 @@ namespace Tensorflow
     /// </summary>
     public static partial class Binding
     {
+        public static T2 get<T1, T2>(this Dictionary<T1, T2> dict, T1 key)
+            => key == null ?
+                default(T2) :
+                (dict.ContainsKey(key) ? dict[key] : default(T2));
+
+        public static void add<T>(this IList<T> list, T element)
+            => list.Add(element);
+
+        public static void append<T>(this IList<T> list, T element)
+            => list.Add(element);
+
+        public static void extend<T>(this List<T> list, IEnumerable<T> elements)
+            => list.AddRange(elements);
+
         private static string _tostring(object obj)
         {
             switch (obj)
@@ -81,6 +95,9 @@ namespace Tensorflow
             throw new NotImplementedException("len() not implemented for type: " + a.GetType());
         }
+
+        public static T[] list<T>(IEnumerable<T> list)
+            => list.ToArray();
+
         public static IEnumerable<int> range(int end)
         {
             return Enumerable.Range(0, end);
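
The helpers above exist to keep the ported code visually close to the Python it was translated from. A minimal usage sketch, purely illustrative (it assumes `using static Tensorflow.Binding;` and `using System.Collections.Generic;` are in scope):

```csharp
var attrs = new Dictionary<string, int> { ["depth"] = 3 };
var depth = attrs.get("depth");    // 3
var missing = attrs.get("width");  // default(int) == 0 when the key is absent

var xs = new List<int>();
xs.append(1);                      // alias for xs.Add(1)
xs.extend(range(4));               // appends 0, 1, 2, 3
var snapshot = list(xs);           // materializes the sequence into an int[]
```
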
@@ -109,11 +109,12 @@ namespace Tensorflow.Operations.ControlFlows
             grad_state.grad_context.Enter();
         }
 
-        // def ExitGradWhileContext(self, op, before):
-        //   """Exit the WhileContext for gradient computation."""
-        //   grad_state = self.GetGradState(op, before)
-        //   if grad_state:
-        //     grad_state.grad_context.Exit()
+        public void ExitGradWhileContext(Operation op, bool before)
+        {
+            var grad_state = GetGradState(op, before);
+            if (grad_state != null)
+                grad_state.grad_context.Exit();
+        }
 
         // def AddWhileContext(self, op, between_op_list, between_ops):
         //   """Add the grad state for the while loop that op belongs to.
@@ -287,51 +288,9 @@ namespace Tensorflow.Operations.ControlFlows
             return result;
         }
 
-        // def PostProcessing(self):
-        //   """Perform postprocessing at the end of gradients().
-        //   We have created the gradient graph at this point. So this function
-        //   can be used to perform any postprocessing on the gradient graph.
-        //   We currently perform the following postprocessing:
-        //   1. Patch the gradient graph if the output of a loop variable
-        //      doesn't depend on its input.
-        //   """
-        //   for _, grad_state in self._map.items():
-        //     for _, b_merge in grad_state.switch_map.items():
-        //       if b_merge.op.inputs[0] == b_merge.op.inputs[1]:
-        //         # The value of this loop variable at iteration i+1 doesn't
-        //         # depend on its value at iteration i. So use zeros as the
-        //         # gradients for all iterations > 0.
-        //         dtype = b_merge.op.inputs[0].dtype
-        //         shape = b_merge.op.inputs[0].get_shape()
-        //         # pylint: disable=protected-access
-        //         if shape.is_fully_defined():
-        //           grad_state.grad_context.Enter()
-        //           # Create a zeros and use it for iterations > 0.
-        //           grad_val = constant_op.constant(0, dtype=dtype, shape=shape)
-        //           next_grad_val = _NextIteration(grad_val)
-        //           grad_state.grad_context.Exit()
-        //         else:
-        //           # Create a zeros in the outer grad context.
-        //           outer_grad_ctxt = grad_state.grad_context.outer_context
-        //           if outer_grad_ctxt:
-        //             outer_grad_ctxt.Enter()
-        //           enter_grad_op = b_merge.op.inputs[0].op
-        //           enter_grad = enter_grad_op.inputs[0]
-        //           grad_shape = array_ops.shape_internal(enter_grad, optimize=False)
-        //           grad_val = array_ops.zeros(grad_shape)
-        //           if outer_grad_ctxt:
-        //             outer_grad_ctxt.Exit()
-        //           # Use the zeros for iterations > 0.
-        //           grad_state.grad_context.Enter()
-        //           next_grad_val = _NextIteration(grad_val)
-        //           grad_state.grad_context.Exit()
-        //         b_merge.op._update_input(1, next_grad_val)
-        //   # pylint: enable=protected-access
+        public void PostProcessing()
+        {
+            throw new NotImplementedException("PostProcessing");
+        }
     }
 }
@@ -17,7 +17,9 @@
 using System;
 using System.Collections;
 using System.Collections.Generic;
+using System.Linq;
 using static Tensorflow.Binding;
+using util = Tensorflow.control_flow_util;
 
 namespace Tensorflow.Operations.ControlFlows
 {
@@ -56,6 +58,7 @@ namespace Tensorflow.Operations.ControlFlows
         public GradLoopState outer_grad_state => _outer_grad_state;
 
         Tensor _forward_index;
+        public Tensor forward_index => _forward_index;
 
         Tensor _grad_index;
 
         Tensor[] _forward_loop_exits;
@@ -152,63 +155,52 @@ namespace Tensorflow.Operations.ControlFlows
         /// <returns>The stack that contains the accumulated history of the tensor.</returns>
         public Tensor AddForwardAccumulator(Tensor value, bool dead_branch = false)
         {
-            throw new NotImplementedException("AddForwardAccumulator");
-            // # curr_ctxt is the context that tf.gradients was called in.
-            // with self._forward_index.graph.as_default():
-            //   curr_ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access
-            //   with ops.control_dependencies(None):
-            //     if curr_ctxt:
-            //       curr_ctxt.Enter()
-            //     with ops.colocate_with(value):
-            //       # We only need to pass maximum_iterations to the stack if
-            //       # we're inside an XLA context.
-            //       if not util.IsInXLAContext(value.op):
-            //         max_size = constant_op.constant(-1, dtypes.int32)
-            //       else:
-            //         max_size = GetMaxSizeFromNestedMaximumIterations(
-            //             value, self.forward_context)
-            //       acc = gen_data_flow_ops.stack_v2(
-            //           max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc")
-            //     if curr_ctxt:
-            //       curr_ctxt.Exit()
-            //     # Make acc available in the forward context.
-            //     enter_acc = self.forward_context.AddValue(acc)
-            //     # Add the stack_push op in the context of value.op.
-            //     swap_enabled = self.forward_context.swap_memory
-            //     value_ctxt = util.GetOutputContext(value.op)
-            //     if value_ctxt == self.forward_context:
-            //       # value is not nested in the forward context.
-            //       self.forward_context.Enter()
-            //       push = gen_data_flow_ops.stack_push_v2(
-            //           enter_acc, value, swap_memory=swap_enabled)
-            //       self.forward_context.Exit()
-            //       # Protect stack push and order it before forward_index.
-            //       self.forward_index.op._add_control_input(push.op)
-            //     else:
-            //       # value is in a cond context within the forward context.
-            //       if not isinstance(value_ctxt, CondContext):
-            //         raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt)
-            //       if dead_branch:
-            //         # The special case for creating a zero tensor for a dead
-            //         # branch of a switch. See ControlFlowState.ZerosLike().
-            //         value_ctxt.outer_context.Enter()
-            //         push = gen_data_flow_ops.stack_push_v2(
-            //             enter_acc, value, swap_memory=swap_enabled)
-            //         value_ctxt.outer_context.Exit()
-            //         push.op._set_control_flow_context(value_ctxt)
-            //       else:
-            //         value_ctxt.Enter()
-            //         push = gen_data_flow_ops.stack_push_v2(
-            //             enter_acc, value, swap_memory=swap_enabled)
-            //         value_ctxt.Exit()
-            //       # Protect stack push and order it before forward_sync.
-            //       self.forward_sync._add_control_input(push.op)
-            //     # Order stack push after the successor of forward_index
-            //     add_op = self.forward_index.op.inputs[0].op
-            //     push.op._add_control_input(add_op)
-            //     return acc
+            using (_forward_index.graph.as_default())
+            {
+                var curr_ctxt = ops.get_default_graph()._get_control_flow_context();
+                return tf_with(ops.control_dependencies(null), delegate
+                {
+                    Tensor acc = null;
+                    Tensor push = null;
+                    if (curr_ctxt != null)
+                        curr_ctxt.Enter();
+                    ops.colocate_with(value);
+                    {
+                        // We only need to pass maximum_iterations to the stack if
+                        // we're inside an XLA context.
+                        var max_size = constant_op.constant(-1, dtypes.int32);
+                        acc = gen_data_flow_ops.stack_v2(
+                            max_size: max_size, elem_type: value.dtype.as_base_dtype(), name: "f_acc");
+                    }
+                    if (curr_ctxt != null)
+                        curr_ctxt.Exit();
+
+                    // Make acc available in the forward context.
+                    var enter_acc = forward_context.AddValue(acc);
+
+                    // Add the stack_push op in the context of value.op.
+                    var swap_enabled = forward_context.swap_memory;
+                    var value_ctxt = util.GetOutputContext(value.op);
+                    if (value_ctxt == forward_context)
+                    {
+                        // value is not nested in the forward context.
+                        forward_context.Enter();
+                        push = gen_data_flow_ops.stack_push_v2(enter_acc, value, swap_memory: swap_enabled);
+                        forward_context.Exit();
+                        // Protect stack push and order it before forward_index.
+                        forward_index.op._add_control_input(push.op);
+                    }
+                    else
+                    {
+                        throw new NotImplementedException("AddForwardAccumulator");
+                    }
+
+                    // Order stack push after the successor of forward_index
+                    var add_op = forward_index.op.inputs[0].op;
+                    push.op._add_control_input(add_op);
+                    return acc;
+                });
+            }
         }
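
Together with `AddBackpropAccumulatedValue` (declared a little further down and still a stub in this change), the accumulator is intended to be used as a push/pop pair. A hedged sketch, with `grad_state` and `value` as illustrative placeholders:

```csharp
// While building the forward loop: push the tensor that backprop will need
// onto the stack created via stack_v2 / stack_push_v2 above.
Tensor history = grad_state.AddForwardAccumulator(value);

// While building the gradient loop: pop the saved history back off the stack
// to recover the value for the current iteration.
Tensor restored = grad_state.AddBackpropAccumulatedValue(history, value);
```
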
| // """Add the getter for an accumulated value in the grad context. | // """Add the getter for an accumulated value in the grad context. | ||||
| @@ -225,6 +217,7 @@ namespace Tensorflow.Operations.ControlFlows | |||||
| // Returns: | // Returns: | ||||
| // The current value (the top of the stack). | // The current value (the top of the stack). | ||||
| // """ | // """ | ||||
| public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bool dead_branch= false) | public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bool dead_branch= false) | ||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
@@ -261,62 +254,50 @@ namespace Tensorflow.Operations.ControlFlows
             // return pop
         }
 
-        // def GetRealValue(self, value):
-        //   """Get the real value of `value`.
-        //   If backprop "uses" a value produced by forward inference, an accumulator
-        //   is added in the forward loop to accumulate its values. We use the
-        //   accumulated value. This method must be called in the grad loop context.
-        //   `value` must be in forward and needed for backprop.
-        //   Args:
-        //     value: A tensor to be captured.
-        //   Returns:
-        //     The same tensor obtained from the saved history.
-        //   """
-        //   assert value.op.type not in ["Variable", "VariableV2"]
-        //   real_value = self._history_map.get(value.name)
-        //   if real_value is None:
-        //     cur_value = value
-        //     cur_grad_state = self
-        //     while True:
-        //       enter_op = util.GetLoopConstantEnter(cur_value)
-        //       if enter_op:
-        //         # Special case: cur_value comes from a constant Enter node.
-        //         cur_value = enter_op.inputs[0]
-        //         cur_grad_state = cur_grad_state.outer_grad_state
-        //         if cur_grad_state is None:
-        //           # We are now outside all nested loops for this gradient(),
-        //           # so `value` is a loop invariant and there is no need to
-        //           # save the history of value. Just make cur_value to enter
-        //           # the right control flow context.
-        //           real_value = self._grad_context.AddValue(cur_value)
-        //           break
-        //       elif constant_op.is_constant(cur_value):
-        //         # If the value to be forwarded is a constant, clone the constant in
-        //         # the gradient loop rather than using a stack.
-        //         # TODO(phawkins): consider hoisting the constant out of the loop
-        //         # instead.
-        //         real_value = constant_op.constant(
-        //             tensor_util.constant_value(cur_value), dtype=cur_value.dtype)
-        //         break
-        //       else:
-        //         # Record the history of this value in forward_ctxt.
-        //         self._grad_context.Exit()
-        //         history_value = cur_grad_state.AddForwardAccumulator(cur_value)
-        //         self._grad_context.Enter()
-        //         break
-        //     if real_value is None:
-        //       # Add the stack pop op in the grad context.
-        //       real_value = cur_grad_state.AddBackpropAccumulatedValue(
-        //           history_value, cur_value)
-        //       if cur_grad_state != self:
-        //         real_value = self._grad_context.AddValue(real_value)
-        //     self._history_map[value.name] = real_value
-        //   return real_value
+        /// <summary>
+        /// Get the real value of `value`.
+        /// </summary>
+        /// <param name="value">A tensor to be captured.</param>
+        /// <returns>The same tensor obtained from the saved history.</returns>
+        public Tensor GetRealValue(Tensor value)
+        {
+            // NOTE: the Python original first consults `_history_map` for a previously
+            // captured value here; this port always recomputes and only fills the cache below.
+            Tensor real_value = null;
+            if (real_value == null)
+            {
+                var cur_value = value;
+                var cur_grad_state = this;
+                Tensor history_value = null;
+                while (true)
+                {
+                    var enter_op = util.GetLoopConstantEnter(cur_value);
+                    if (enter_op != null)
+                    {
+                        throw new NotImplementedException("GetRealValue");
+                    }
+                    else if (constant_op.is_constant(cur_value))
+                    {
+                        throw new NotImplementedException("GetRealValue");
+                    }
+                    else
+                    {
+                        // Record the history of this value in forward_ctxt.
+                        _grad_context.Exit();
+                        history_value = cur_grad_state.AddForwardAccumulator(cur_value);
+                        _grad_context.Enter();
+                        break;
+                    }
+                }
+
+                if (real_value == null)
+                {
+                    // Add the stack pop op in the grad context.
+                    real_value = cur_grad_state.AddBackpropAccumulatedValue(history_value, cur_value);
+                    if (cur_grad_state != this)
+                        real_value = _grad_context.AddValue(real_value);
+                }
+
+                _history_map[value.name] = real_value;
+            }
+
+            return real_value;
+        }
     }
 }
@@ -530,10 +530,9 @@ namespace Tensorflow.Operations
             }
 
             if(forward_ctxt == grad_ctxt.grad_state.forward_context)
             {
-                throw new NotImplementedException("forward_ctxt == grad_ctxt.grad_state.forward_context");
-                /*real_val = grad_ctxt.grad_state.GetRealValue(val);
+                var real_val = grad_ctxt.grad_state.GetRealValue(val);
                 _external_values[val.name] = real_val;
-                return real_val;*/
+                return real_val;
             }
         }
     }
@@ -30,7 +30,7 @@ namespace Tensorflow.Operations
             TF_DataType dtype = TF_DataType.DtInvalid,
             int? parallel_iterations = null, bool swap_memory = false, bool time_major = false)
         {
-            tf_with(tf.variable_scope("rnn"), scope =>
+            return tf_with(tf.variable_scope("rnn"), scope =>
             {
                 VariableScope varscope = scope;
                 var flat_input = nest.flatten(inputs_tensor);
@@ -64,9 +64,12 @@ namespace Tensorflow.Operations
                     swap_memory: swap_memory,
                     sequence_length: sequence_length,
                     dtype: dtype);
-            });
-            throw new NotImplementedException("");
+
+                if (!time_major)
+                    outputs = nest.map_structure(_transpose_batch_time, outputs);
+
+                return (outputs, final_state);
+            });
         }
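
For context, a hedged call-site sketch: with `return tf_with(...)` the whole RNN body now runs inside the "rnn" variable scope and the `(outputs, final_state)` pair flows back to the caller. The cell type, entry point and `tf.float32` below are assumptions, not names taken from this diff:

```csharp
// Hypothetical call site; identifiers are illustrative.
var cell = new BasicRnnCell(4);                    // assumed cell type
var (outputs, final_state) = rnn.dynamic_rnn(      // assumed entry point
    cell, inputs_tensor, dtype: tf.float32);       // inputs_tensor: [batch, time, depth]
// With time_major left at false, outputs come back batch-major thanks to the
// _transpose_batch_time call added above.
```
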
         /// <summary>
@@ -210,16 +213,28 @@ namespace Tensorflow.Operations
                 var input_t_t = nest.pack_sequence_as2(structure: inputs, flat_sequence: input_t);
                 // Keras RNN cells only accept state as list, even if it's a single tensor.
                 // var is_keras_rnn_cell = _is_keras_rnn_cell(cell);
-                (Tensor, Tensor) a = (null, null);
+                Tensor[] outputs = null;
                 if (sequence_length != null)
                     throw new NotImplementedException("sequence_length != null");
                 else
-                    a = cell.__call__(input_t_t, state: state1);
+                    outputs = cell.__call__(input_t_t, state: state1);
+
+                var (output, new_state) = (outputs[0], outputs[1]);
+
+                // Keras cells always wrap state as list, even if it's a single tensor.
+                // if(is_keras_rnn_cell && len(new_state)) == 1
+
+                // Pack state if using state tuples
+                outputs = nest.flatten2(output).Select(x => x as Tensor).ToArray();
-                return item;
+
+                output_ta_t = zip(output_ta_t, outputs).Select(x =>
+                {
+                    var (ta, @out) = (x.Item1, x.Item2);
+                    return ta.write(item.time, @out);
+                }).ToArray();
+
+                return new BodyItemInRnnWhileLoop(item.time + 1, output_ta_t, new_state);
             };
 
-            control_flow_ops.while_loop(
+            var while_loop_result = control_flow_ops.while_loop(
                 cond: cond,
                 body: _time_step,
                 loop_vars: new BodyItemInRnnWhileLoop(time, output_ta.ToArray(), state),
@@ -227,7 +242,18 @@ namespace Tensorflow.Operations
                 maximum_iterations: time_steps,
                 swap_memory: swap_memory);
 
-            throw new NotImplementedException("");
+            (_, TensorArray[] output_final_ta, Tensor final_state) = (while_loop_result.time, while_loop_result.output_ta_t, while_loop_result.state);
+
+            // Unpack final output if not using output tuples.
+            var final_outputs = output_final_ta.Select(ta => ta.stack()).ToArray();
+
+            // Restore some shape information
+            foreach (var (output, output_size) in zip(final_outputs, flat_output_size))
+            {
+                var shape = rnn_cell_impl._concat(new[] { const_time_steps, const_batch_size }, output_size, @static: true);
+                output.set_shape(shape);
+            }
+
+            return (final_outputs[0], final_state);
         }
 
         private static TensorShape _maybe_tensor_shape_from_tensor(Tensor shape)
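
The `zip(output_ta_t, outputs).Select(...)` idiom in `_time_step` writes the current step of every output into its TensorArray and hands back the updated arrays. A hedged, TensorFlow-free illustration of the same shape of code (plain lists stand in for the TensorArrays):

```csharp
using System.Collections.Generic;
using System.Linq;

// Two accumulators and one output per accumulator for this "time step".
var accumulators = new List<List<double>> { new List<double>(), new List<double>() };
var stepOutputs = new[] { 0.5, 1.5 };

// Pair accumulator i with output i, "write" into it, and return the updated
// accumulator -- mirroring `ta.write(item.time, @out)` in the loop body above.
var updated = accumulators
    .Zip(stepOutputs, (acc, value) => { acc.Add(value); return acc; })
    .ToArray();
```
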
@@ -53,5 +53,34 @@ namespace Tensorflow.Operations
                 return array_ops.concat(new[] { p, s }, 0);
             }
         }
+
+        public static TensorShape _concat(int[] prefix, int suffix, bool @static = false)
+        {
+            var p = new TensorShape(prefix);
+            var p_static = prefix;
+            var p_tensor = p.is_fully_defined() ? constant_op.constant(p.as_list(), dtype: dtypes.int32) : null;
+
+            var s_tensor_shape = new TensorShape(suffix);
+            var s_static = s_tensor_shape.ndim > -1 ?
+                s_tensor_shape.dims :
+                null;
+            var s_tensor = s_tensor_shape.is_fully_defined() ?
+                constant_op.constant(s_tensor_shape.dims, dtype: dtypes.int32) :
+                null;
+
+            if (@static)
+            {
+                if (p_static is null) return null;
+
+                var shape = new TensorShape(p_static).concatenate(s_static);
+                return shape;
+            }
+            else
+            {
+                if (p is null || s_tensor is null)
+                    throw new ValueError($"Provided a prefix or suffix of None: {prefix} and {suffix}");
+
+                // return array_ops.concat(new[] { p_tensor, s_tensor }, 0);
+                throw new NotImplementedException("");
+            }
+        }
     }
 }
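
A worked example of the `@static: true` branch added above (the numbers are illustrative):

```csharp
// prefix = { time_steps, batch_size }, suffix = output depth.
var shape = rnn_cell_impl._concat(new[] { 16, 32 }, 8, @static: true);
// shape is the static TensorShape (16, 32, 8); the dynamic branch
// (@static: false) still ends in NotImplementedException in this change.
```
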
@@ -52,6 +52,10 @@ namespace Tensorflow
         public void _set_control_flow_context(ControlFlowContext ctx)
         {
+            if (name.Contains("gradients/rnn/while/basic_rnn_cell/Tanh_grad/TanhGrad/f_acc"))
+            {
+                // empty block: presumably kept as a debugging anchor (breakpoint target) for this specific node
+            }
             _control_flow_context = ctx;
         }
 
@@ -59,5 +63,10 @@ namespace Tensorflow
         {
             return _control_flow_context;
         }
+
+        public WhileContext GetWhileContext()
+        {
+            return _control_flow_context as WhileContext;
+        }
     }
 }
@@ -15,17 +15,14 @@
  ******************************************************************************/
 
 using System;
+using System.Linq;
 using System.Collections.Generic;
+using static Tensorflow.Binding;
 
 namespace Tensorflow
 {
     public partial class Operation
     {
-        // cache the mapping between managed and unmanaged op
-        // some data is stored in managed instance, so when
-        // create Operation by IntPtr, it will lost some data.
-        private static Dictionary<IntPtr, Operation> OpInstances = new Dictionary<IntPtr, Operation>();
-
         /// <summary>
         /// Get operation by handle
         /// </summary>
@@ -33,9 +30,17 @@ namespace Tensorflow
         /// <returns></returns>
         public Operation GetOperation(IntPtr handle)
         {
-            return OpInstances.ContainsKey(handle) ?
-                OpInstances[handle] :
-                new Operation(handle);
+            var nodes = tf.get_default_graph()._nodes_by_name;
+            foreach (var node in nodes.Values)
+            {
+                if (node is Operation op)
+                {
+                    if (op == handle)
+                        return op;
+                }
+            }
+
+            return null;
         }
     }
 }
@@ -106,7 +106,6 @@ namespace Tensorflow
             _control_flow_context = _graph._get_control_flow_context();
 
             // Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor.
-            OpInstances[_handle] = this;
         }
 
         /*public Operation(Graph g, string opType, string oper_name)
@@ -183,10 +182,12 @@ namespace Tensorflow
             // This will be set by self.inputs.
 
             if (op_def == null)
                 op_def = g.GetOpDef(node_def.Op);
+            if (node_def.Name == "gradients/rnn/while/basic_rnn_cell/Tanh_grad/TanhGrad/f_acc")
+            {
+                // empty block: presumably kept as a debugging anchor (breakpoint target) for this specific node
+            }
             var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
             _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());
             _is_stateful = op_def.IsStateful;
 
             // Initialize self._outputs.
@@ -202,8 +203,6 @@ namespace Tensorflow
             if (_handle != IntPtr.Zero)
                 _control_flow_post_processing();
-
-            OpInstances[_handle] = this;
         }
 
         public void run(FeedItem[] feed_dict = null, Session session = null)
@@ -183,7 +183,7 @@ namespace Tensorflow
         {
             var _op = _op_def_lib._apply_op_helper("Identity", name, new { input });
 
-            return _op.outputs[0];
+            return _op.output;
         }
 
         public static Tensor invert_permutation(Tensor x, string name = null)
@@ -14,6 +14,8 @@
    limitations under the License.
 ******************************************************************************/
 
+using Tensorflow.Operations;
+
 namespace Tensorflow
 {
     public class gen_control_flow_ops
@@ -148,18 +150,18 @@ namespace Tensorflow
             return new []{_op.outputs[0], _op.outputs[1]};
         }
 
-        public static Tensor[] ref_merge(Tensor[] inputs, string name = null)
+        public static MergeOutput ref_merge(Tensor[] inputs, string name = null)
         {
             var _op = _op_def_lib._apply_op_helper("RefMerge", name, new { inputs });
 
-            return _op.outputs;
+            return new MergeOutput(_op.outputs);
         }
 
-        public static Tensor[] merge(Tensor[] inputs, string name = null)
+        public static MergeOutput merge(Tensor[] inputs, string name = null)
         {
             var _op = _op_def_lib._apply_op_helper("Merge", name, new { inputs });
 
-            return _op.outputs;
+            return new MergeOutput(_op.outputs);
         }
     }
 }
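
A hedged usage sketch of the new return type. `MergeOutput` is only constructed here from `_op.outputs`; the member names below (`output`, `value_index`) are assumptions based on the two outputs of TensorFlow's `Merge` op and are not shown in this diff:

```csharp
// Merge forwards whichever input becomes available first.
var merged = gen_control_flow_ops.merge(new[] { branch_a, branch_b });  // branch_a / branch_b: illustrative tensors
var value = merged.output;        // assumed member: the forwarded tensor
var index = merged.value_index;   // assumed member: which input was taken
```
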
@@ -259,5 +259,31 @@ namespace Tensorflow
             return _op.output;
         }
 
+        public static Tensor stack_v2(Tensor max_size, TF_DataType elem_type, string stack_name = "",
+            string name = null)
+        {
+            var _op = _op_def_lib._apply_op_helper("StackV2", name, new
+            {
+                max_size,
+                elem_type,
+                stack_name
+            });
+
+            return _op.output;
+        }
+
+        public static Tensor stack_push_v2(Tensor handle, Tensor elem, bool swap_memory = false,
+            string name = null)
+        {
+            var _op = _op_def_lib._apply_op_helper("StackPushV2", name, new
+            {
+                handle,
+                elem,
+                swap_memory
+            });
+
+            return _op.output;
+        }
     }
 }
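
These two wrappers are what `AddForwardAccumulator` earlier in this diff builds on. A hedged sketch of the pairing (`value` is an illustrative tensor; the matching pop-side op is not part of this change):

```csharp
// An unbounded stack (max_size: -1) typed to the pushed tensor's dtype.
var max_size = constant_op.constant(-1, dtypes.int32);
var acc = gen_data_flow_ops.stack_v2(
    max_size: max_size, elem_type: value.dtype.as_base_dtype(), name: "f_acc");

// Push the current value; swap_memory allows the runtime to spill the stack to host memory.
var push = gen_data_flow_ops.stack_push_v2(acc, value, swap_memory: true);
```
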
@@ -282,7 +282,7 @@ namespace Tensorflow
         /// <param name="dy"></param>
         /// <param name="name"></param>
         /// <returns></returns>
-        public static Tensor tanh_grad(Tensor y, Tensor dy, string name = "TanhGrad")
+        public static Tensor tanh_grad(Tensor y, Tensor dy, string name = null)
             => _op_def_lib._apply_op_helper("TanhGrad", name: name, args: new { y, dy }).output;
 
         public static Tensor floor(Tensor x, string name = null)
@@ -526,6 +526,14 @@ namespace Tensorflow.Util
             return pack_sequence_as(structure, mapped_flat_structure) as Tensor;
         }
 
+        public static Tensor map_structure2<T>(Func<T, Tensor> func, T structure)
+        {
+            var flat_structure = flatten(structure);
+            var mapped_flat_structure = flat_structure.Select(func).ToList();
+
+            return pack_sequence_as(structure, mapped_flat_structure) as Tensor;
+        }
+
         /// <summary>
         /// Same as map_structure, but with only one structure (no combining of multiple structures)
         /// </summary>
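
A hedged usage sketch: `map_structure2` flattens the structure, applies a tensor-producing function to each element, and repacks the result. The transform below reuses the `_transpose_batch_time` helper referenced in the rnn hunk purely as an illustration (it is private to that class, and `outputs` is an illustrative structure of tensors):

```csharp
// Apply a per-element transform across a (possibly nested) structure of tensors,
// e.g. swapping the batch and time axes as dynamic_rnn does above.
var batch_major = nest.map_structure2(x => _transpose_batch_time(x), outputs);
```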