diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs
index 163192ee..8170ea6f 100644
--- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs
+++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs
@@ -55,6 +55,9 @@ namespace Tensorflow
              * is more than one.
              **/
             var grads = new Dictionary<string, List<List<Tensor>>>();
+            Operation[] reachable_to_ops = null;
+            ControlFlowState loop_state = null;
+            Dictionary<string, int> pending_count = null;
 
             tf_with(ops.name_scope(name, "gradients",
                 values: ys.Concat(xs).Concat(stop_gradients).Concat(grad_ys)), scope =>
@@ -81,7 +84,7 @@ namespace Tensorflow
                 var to_ops = ys.Select(x => x.op).ToList();
                 var from_ops = xs.Select(x => x.op).ToList();
                 var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList();
-                var (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs);
+                (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs);
 
                 // Add the initial gradients for the ys.
                 foreach (var (y, grad_y) in zip(ys, grad_ys))
@@ -120,126 +123,135 @@ namespace Tensorflow
                 {
                     // generate gradient subgraph for op.
                     var op = queue.Dequeue();
-                    if(op.name == "rnn/while/basic_rnn_cell/Tanh")
+                    if(op.name == "rnn/while/Exit")
                     {
                     }
                     _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
-                    //if (loop_state != null)
-                    //loop_state.EnterGradWhileContext(op, before: true);
-                    var out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, aggregation_method);
-
-                    Tensor[] in_grads = null;
-                    var is_partitioned_call = _IsPartitionedCall(op);
-                    var is_func_call = false;
-                    var has_out_grads = out_grads.Exists(x => x != null);
-                    if (has_out_grads && !stop_ops.Contains(op))
                     {
-                        // A grad_fn must be defined, either as a function or as None
-                        // for ops that do not have gradients.
+                        if (loop_state != null)
+                            loop_state.EnterGradWhileContext(op, before: true);
+                        var out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, aggregation_method);
+                        if (loop_state != null)
+                            loop_state.ExitGradWhileContext(op, before: true);
 
-                        Func<Operation, Tensor[], Tensor[]> grad_fn = null;
-                        try
-                        {
-                            grad_fn = ops.get_gradient_function(op);
-                        }
-                        catch (LookupError)
+                        Tensor[] in_grads = null;
+                        var is_partitioned_call = _IsPartitionedCall(op);
+                        var is_func_call = false;
+                        var has_out_grads = out_grads.Exists(x => x != null);
+                        if (has_out_grads && !stop_ops.Contains(op))
                         {
-                            if (is_func_call)
+                            // A grad_fn must be defined, either as a function or as None
+                            // for ops that do not have gradients.
+
+                            Func<Operation, Tensor[], Tensor[]> grad_fn = null;
+                            try
                             {
-                                if (is_partitioned_call)
+                                grad_fn = ops.get_gradient_function(op);
+                            }
+                            catch (LookupError)
+                            {
+                                if (is_func_call)
                                 {
+                                    if (is_partitioned_call)
+                                    {
+
+                                    }
+                                    else
+                                    {
+                                    }
                                 }
                                 else
                                 {
-
+                                    throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})");
                                 }
                             }
-                            else
-                            {
-                                throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})");
-                            }
-                        }
-                        if (loop_state != null)
-                            loop_state.EnterGradWhileContext(op, before: false);
+                            if (loop_state != null)
+                                loop_state.EnterGradWhileContext(op, before: false);
 
-                        if ((is_func_call || grad_fn != null) && has_out_grads)
-                        {
-                            // NOTE: If _AggregatedGrads didn't compute a value for the i'th
-                            // output, it means that the cost does not depend on output[i],
-                            // therefore dC/doutput[i] is 0.
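The restructuring above is easier to follow once the scheduling invariant is explicit: an op is dequeued and its gradient function run only after every one of its consumers has contributed a gradient, which is what pending_count tracks. A minimal standalone sketch of that pattern, with illustrative names only (Node, PendingCount and BackpropOrder are not TF.NET APIs):

    using System.Collections.Generic;

    class Node
    {
        public string Name;
        public List<Node> Inputs = new List<Node>();
        // Number of consumers whose gradient has not arrived yet; assumed to be
        // pre-initialized from the graph, as _PendingCount does in the hunk above.
        public int PendingCount;
    }

    static class BackpropDemo
    {
        // Yields ops in a valid backprop order, mirroring the queue +
        // pending_count bookkeeping of the hunk above.
        public static IEnumerable<Node> BackpropOrder(Node y)
        {
            var queue = new Queue<Node>();
            queue.Enqueue(y); // seeded from the ys, like the initial grad_ys
            while (queue.Count > 0)
            {
                var op = queue.Dequeue();
                yield return op; // "generate gradient subgraph for op"
                foreach (var input in op.Inputs)
                    if (--input.PendingCount == 0) // last consumer just finished
                        queue.Enqueue(input);
            }
        }
    }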
-                            foreach (var (i, out_grad) in enumerate(out_grads))
+                            if ((is_func_call || grad_fn != null) && has_out_grads)
                             {
-                                if (out_grad == null &&
-                                    (grad_fn == null || _IsTrainable(op.outputs[i])))
+                                // NOTE: If _AggregatedGrads didn't compute a value for the i'th
+                                // output, it means that the cost does not depend on output[i],
+                                // therefore dC/doutput[i] is 0.
+                                foreach (var (i, out_grad) in enumerate(out_grads))
                                 {
-                                    // Only trainable outputs or outputs for a function call that
-                                    // will use SymbolicGradient get a zero gradient. Gradient
-                                    // functions should ignore the gradient for other outputs.
-                                    if (loop_state != null)
-                                        out_grads[i] = new List<Tensor> { loop_state.ZerosLike(op, i) };
-                                    else
-                                        out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) };
+                                    if (out_grad == null &&
+                                        (grad_fn == null || _IsTrainable(op.outputs[i])))
+                                    {
+                                        // Only trainable outputs or outputs for a function call that
+                                        // will use SymbolicGradient get a zero gradient. Gradient
+                                        // functions should ignore the gradient for other outputs.
+                                        if (loop_state != null)
+                                            out_grads[i] = new List<Tensor> { loop_state.ZerosLike(op, i) };
+                                        else
+                                            out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) };
+                                    }
                                 }
-                            }
-                            tf_with(ops.name_scope(op.name + "_grad"), scope1 =>
-                            {
-                                if (grad_fn != null)
+                                tf_with(ops.name_scope(op.name + "_grad"), scope1 =>
                                 {
-                                    in_grads = _MaybeCompile(grad_scope,
-                                        op,
-                                        out_grads.Where(x => x != null).Select(x => x[0]).ToArray(),
-                                        null,
-                                        grad_fn);
-                                }
-                                else
-                                {
-                                    throw new NotImplementedException("lambda: _SymGrad(op, out_grads)");
-                                }
-                                _VerifyGeneratedGradients(in_grads, op);
-                                if (gate_gradients && in_grads.Count(x => x != null) > 1)
-                                {
-                                    ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true);
-                                    in_grads = control_flow_ops.tuple(in_grads);
-                                }
-                            });
+                                    if (grad_fn != null)
+                                    {
+                                        in_grads = _MaybeCompile(grad_scope,
+                                            op,
+                                            out_grads.Where(x => x != null).Select(x => x[0]).ToArray(),
+                                            null,
+                                            grad_fn);
+                                    }
+                                    else
+                                    {
+                                        throw new NotImplementedException("lambda: _SymGrad(op, out_grads)");
+                                    }
+                                    _VerifyGeneratedGradients(in_grads, op);
+                                    if (gate_gradients && in_grads.Count(x => x != null) > 1)
+                                    {
+                                        ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true);
+                                        in_grads = control_flow_ops.tuple(in_grads);
+                                    }
+                                });
+                            }
+                            else
+                            {
+                                // If no grad_fn is defined or none of out_grads is available,
+                                // just propagate a list of None backwards.
+                                in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
+                            }
                         }
                         else
                         {
-                            // If no grad_fn is defined or none of out_grads is available,
-                            // just propagate a list of None backwards.
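With __call__ now returning Tensor[] instead of a (Tensor, Tensor) tuple, call sites unpack positionally and apply() keeps the old pair-shaped surface. A hypothetical call site, where rnnLayer, inputs and isTraining are placeholders rather than APIs from this diff:

    var results = rnnLayer.__call__(inputs, training: isTraining);
    Tensor output = results[0];     // primary layer output
    Tensor finalState = results[1]; // e.g. the final state of an RNN cell
    // apply() performs exactly this unpacking to preserve the (Tensor, Tensor) API.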
                            in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
                        }
-                }
-                else
-                {
-                    in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
-                }
-                var inputs = _NonEagerInputs(op, xs).ToList();
-                foreach (var (t_in, in_grad) in zip(inputs, in_grads))
-                {
-                    if (in_grad != null)
+                        var inputs = _NonEagerInputs(op, xs).ToList();
+                        foreach (var (t_in, in_grad) in zip(inputs, in_grads))
                         {
-                        if (!(in_grad is null) &&
-                            in_grad.Tag == null && // maybe a IndexedSlice
-                            t_in.dtype != TF_DataType.TF_RESOURCE)
+                            if (in_grad != null)
                             {
-                            in_grad.set_shape(t_in.TensorShape);
-                        }
+                                if (!(in_grad is null) &&
+                                    in_grad.Tag == null && // maybe a IndexedSlice
+                                    t_in.dtype != TF_DataType.TF_RESOURCE)
+                                {
+                                    in_grad.set_shape(t_in.TensorShape);
+                                }
 
-                        _SetGrad(grads, t_in, in_grad);
+                                _SetGrad(grads, t_in, in_grad);
+                            }
                         }
-                }
+                        if (loop_state != null)
+                            loop_state.ExitGradWhileContext(op, before: false);
+                    }
+                    // Update pending count for the inputs of op and enqueue ready ops.
                     _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state, xs);
                 }
             });
+            if (loop_state != null)
+                loop_state.PostProcessing();
             return xs.Select(x => _GetGrad(grads, x)).ToArray();
         }
diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs
index d7cda786..39561990 100644
--- a/src/TensorFlowNET.Core/Layers/Layer.cs
+++ b/src/TensorFlowNET.Core/Layers/Layer.cs
@@ -50,10 +50,11 @@ namespace Tensorflow.Layers
         public virtual (Tensor, Tensor) apply(Tensor inputs, Tensor training = null)
         {
-            return __call__(inputs, training: training);
+            var results = __call__(inputs, training: training);
+            return (results[0], results[1]);
         }
 
-        public (Tensor, Tensor) __call__(Tensor inputs,
+        public Tensor[] __call__(Tensor inputs,
             Tensor training = null,
             Tensor state = null,
             VariableScope scope = null)
@@ -73,7 +74,7 @@ namespace Tensorflow.Layers
                     auxiliary_name_scope: false);
             }
 
-            (Tensor, Tensor) outputs = (null, null);
+            Tensor[] outputs = null;
             tf_with(scope_context_manager, scope2 =>
             {
                 _current_scope = scope2;
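Since while_loop in the next file is now generic over the loop variables and returns TItem, a single-tensor counter loop reads as below. A usage sketch assuming TItem = Tensor and the gen_math_ops wrappers; exact overloads may differ:

    // Builds graph ops equivalent to: i = 0; while (i < 10) i = i + 1;
    var i0 = tf.constant(0);
    var final_i = control_flow_ops.while_loop(
        cond: (Tensor i) => gen_math_ops.less(i, tf.constant(10)),
        body: (Tensor i) => gen_math_ops.add(i, tf.constant(1)),
        loop_vars: i0);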
diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs
index 13182dfd..b8360939 100644
--- a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs
@@ -151,27 +151,50 @@ namespace Tensorflow
         ///
         public static ControlFlowState MaybeCreateControlFlowState(List<Operation> between_op_list, List<Operation> between_ops, bool colocate_gradients_with_ops)
         {
+            var flag = new List<Operation>();
             ControlFlowState loop_state = null;
-            foreach (var op in between_op_list)
+            int pos = 0;
+            while (pos < between_op_list.Count)
             {
+                var op = between_op_list[pos];
                 if (IsLoopExit(op))
                 {
-                    if(loop_state == null)
+                    if (loop_state == null)
                     {
                         loop_state = new ControlFlowState();
                     }
+                    if (colocate_gradients_with_ops)
+                        ops.colocate_with(op);
+                    loop_state.AddWhileContext(op, between_op_list, between_ops);
                 }
+                pos++;
             }
             return loop_state;
         }
 
         public static bool IsLoopExit(Operation op)
+            => op.OpType == "Exit" || op.OpType == "RefExit";
+
+        public static bool IsLoopSwitch(Operation op)
+        {
+            if(IsSwitch(op))
+            {
+                var ctxt = op._get_control_flow_context();
+                return ctxt != null && ctxt.IsWhileContext() && !IsCondSwitch(op);
+            }
+            return false;
+        }
+
+        public static bool IsCondSwitch(Operation op)
         {
-            return op.OpType == "Exit" || op.OpType == "RefExit";
+            throw new NotImplementedException("IsCondSwitch");
         }
 
+        public static bool IsSwitch(Operation op)
+            => op.type == "Switch" || op.type == "RefSwitch";
+
         public static Tensor[] tuple(Tensor[] tensors,
             string name = null,
             Operation[] control_inputs = null)
         {
             return tf_with(ops.name_scope(name, "tuple", tensors), scope =>
@@ -224,15 +247,10 @@ namespace Tensorflow
             //TODO: missing original code
             //if context.executing_eagerly():
             //    return output_tensor
-            var values = new List<object>();
-            values.AddRange(dependencies);
-            values.Add(output_tensor);
-
-            return tf_with(ops.name_scope(name, "control_dependency", values), scope =>
+            return tf_with(ops.name_scope(name, "control_dependency", new { dependencies, output_tensor }), scope =>
             {
                 name = scope;
-                // TODO: missing original code
-                //with ops.colocate_with(output_tensor):
+                ops.colocate_with(output_tensor);
                 {
                     return tf_with(ops.control_dependencies(dependencies), ctl =>
                     {
@@ -431,6 +449,7 @@ namespace Tensorflow
                 var merges = zip(res_f_flat, res_t_flat)
                     .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 }))
+                    .Select(m => (Tensor)m)
                     .ToArray();
 
                 var merges2 = _convert_flows_to_tensorarrays(new ITensorOrTensorArray[] { (Tensor)orig_res_t }, merges);
@@ -479,6 +498,7 @@ namespace Tensorflow
                 var merges = zip(res_f_flat, res_t_flat)
                     .Select(pair => merge(new [] { pair.Item1, pair.Item2 }))
+                    .Select(m => (Tensor)m)
                     .ToArray();
 
                 var merges2 = _convert_flows_to_tensorarrays(orig_res_t.Select(x => (ITensorOrTensorArray)x).ToArray(), merges);
@@ -519,7 +539,7 @@ namespace Tensorflow
         /// <param name="inputs">The input tensors, at most one of which is available.</param>
         /// <param name="name">A name for this operation (optional).</param>
         /// <returns></returns>
-        public static Tensor merge(Tensor[] inputs, string name = null)
+        public static MergeOutput merge(Tensor[] inputs, string name = null)
         {
             if (inputs.Any(x => x == null))
                 throw new ValueError($"At least one of the merge inputs is null: {inputs}");
@@ -529,7 +549,7 @@ namespace Tensorflow
                 inputs = inputs.Select(inp => ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref: true))
                     .ToArray();
 
-                return gen_control_flow_ops.merge(inputs, name)[0];
+                return gen_control_flow_ops.merge(inputs, name);
             });
         }
 
@@ -602,7 +622,7 @@ namespace Tensorflow
         ///
         ///
         ///
-        public static Tensor while_loop<TItem>(Func<TItem, Tensor> cond, Func<TItem, TItem> body, TItem loop_vars,
+        public static TItem while_loop<TItem>(Func<TItem, Tensor> cond, Func<TItem, TItem> body, TItem loop_vars,
             TensorShape[] shape_invariants = null,
             int parallel_iterations = 10,
             bool back_prop = true,
@@ -611,7 +631,7 @@ namespace Tensorflow
             Tensor maximum_iterations = null,
             bool return_same_structure = false)
         {
-            tf_with(ops.name_scope(name, "while", loop_vars), scope =>
+            return tf_with(ops.name_scope(name, "while", loop_vars), scope =>
             {
                 if (loop_vars == null)
                     throw new ValueError("No loop variables provided");
@@ -666,13 +686,11 @@ namespace Tensorflow
                 var results = loop_context.BuildLoop(cond_buildloop, body_buildloop, loop_vars_1, shape_invariants, return_same_structure);
 
-                if (maximum_iterations != null)
-                    return results[1];
-                else
-                    return results[0];
+                //if (maximum_iterations != null)
+                return results.Item;
+                //else
+                //return results;
             });
-
-            throw new NotImplementedException("while_loop");
         }
 
diff --git a/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs
index 9dcfb2e1..5f3bc15c 100644
--- a/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs
+++ b/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs
@@ -15,6 +15,7 @@
 ******************************************************************************/
 
 using System;
+using System.Linq;
 using Tensorflow.Operations;
 using static Tensorflow.Binding;
 
@@ -60,6 +61,45 @@ namespace Tensorflow
         public static bool IsSwitch(Operation op)
         {
             return
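IsContainingContext in the file below has one subtle edge case worth noting: a null maybe_containing_ctxt (meaning "not in any while loop") contains every context, because the outward walk ends at null. A toy model of the same walk, where Ctx stands in for WhileContext:

    class Ctx { public Ctx Outer; }

    static bool IsContaining(Ctx ctxt, Ctx maybeContaining)
    {
        // Walk outward from ctxt; reaching maybeContaining means containment.
        // When maybeContaining is null the loop exits once ctxt becomes null,
        // so "no containing loop" contains everything, as in the real code.
        while (ctxt != maybeContaining)
        {
            if (ctxt == null)
                return false;
            ctxt = ctxt.Outer;
        }
        return true;
    }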
op.type == "Switch" || op.type == "RefSwitch";
+        }
+
+        public static WhileContext GetWhileContext(Operation op)
+            => op.GetWhileContext();
+
+        public static bool IsCondSwitch(Operation op)
+        {
+            if (!IsSwitch(op))
+                return false;
+            if (op.outputs == null || op.outputs.Length == 0)
+                return false;
+
+            // Switch nodes are not part of the cond control flow context that they
+            // represent, so consider the consumers of its outputs to determine if it is
+            // cond switch or not. A switch is a cond switch iff all its consumers are in
+            // cond contexts.
+            var is_cond_switch = true;
+            foreach (var o in op.outputs)
+            {
+                foreach (var c in o.consumers())
+                {
+                    var ctxt = c._get_control_flow_context();
+                    if (IsLoopEnter(c))
+                        ctxt = ctxt.outer_context;
+                    is_cond_switch = is_cond_switch && (ctxt != null && ctxt.IsCondContext());
+                }
+            }
+
+            return is_cond_switch;
+        }
+
+        public static bool IsLoopSwitch(Operation op)
+        {
+            if (IsSwitch(op))
+            {
+                var ctxt = op._get_control_flow_context();
+                return ctxt != null && ctxt.IsWhileContext() && !IsCondSwitch(op);
+            }
+            return false;
         }
 
         ///
@@ -87,13 +127,64 @@ namespace Tensorflow
                 valid = true;
             else
             {
-                throw new NotImplementedException("");
+                var while_ctxt = GetContainingWhileContext(op_ctxt);
+                var input_while_ctxt = GetContainingWhileContext(input_ctxt);
+
+                if (while_ctxt == null)
+                {
+                    throw new NotImplementedException("CheckInputFromValidContext");
+                }
+                else if (IsContainingContext(while_ctxt, input_while_ctxt))
+                {
+                    // input_op is in a while loop which contains op's while loop (or not in a
+                    // while loop at all).
+                    valid = true;
+                }
+                else if (while_ctxt.grad_state != null &&
+                    IsContainingContext(while_ctxt.grad_state.forward_context,
+                        input_while_ctxt))
+                {
+                    valid = true;
+                }
+                else
+                    throw new NotImplementedException("CheckInputFromValidContext");
             }
 
             if (!valid)
             {
-                throw new NotImplementedException("");
+                throw new NotImplementedException("CheckInputFromValidContext");
+            }
+        }
+
+        public static Operation GetLoopConstantEnter(Tensor value)
+        {
+            var id_ops = new string[] { "Switch", "RefSwitch", "Identity", "RefIdentity" };
+            var op = value.op;
+            while (id_ops.Contains(op.type))
+                op = op.inputs[0].op;
+            return IsLoopConstantEnter(op) ? op : null;
+        }
+
+        public static bool IsContainingContext(WhileContext ctxt, WhileContext maybe_containing_ctxt)
+        {
+            while (ctxt != maybe_containing_ctxt)
+            {
+                if (ctxt == null)
+                    return false;
+                ctxt = ctxt.outer_context as WhileContext;
+            }
+            return true;
+        }
+
+        public static WhileContext GetContainingWhileContext(ControlFlowContext ctxt, ControlFlowContext stop_ctxt = null)
+        {
+            while (ctxt != null)
+            {
+                if (ctxt.IsWhileContext() || ctxt == stop_ctxt)
+                    return ctxt as WhileContext;
+                ctxt = ctxt.outer_context;
             }
+            return null;
         }
     }
 }
diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs
index 17cd8a99..f158ffb1 100644
--- a/src/TensorFlowNET.Core/Operations/math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/math_ops.cs
@@ -159,6 +159,8 @@ namespace Tensorflow
             });
         }
 
+        public static Tensor greater_equal<Tx, Ty>(Tx x, Ty y, string name = null)
+            => gen_math_ops.greater_equal(x, y, name: name);
         public static Tensor equal<Tx, Ty>(Tx x, Ty y, string name = null)
             => gen_math_ops.equal(x, y, name: name);
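The generic Tx/Ty parameters on greater_equal let a Tensor be compared elementwise against a plain numeric constant. A usage sketch; the tf.constant overload for arrays and the commented result are illustrative assumptions:

    var x = tf.constant(new[] { 1, 2, 3 });
    var mask = math_ops.greater_equal(x, 2); // elementwise: [false, true, true]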