| @@ -238,9 +238,67 @@ namespace Tensorflow.Operations | |||||
| return real_val; | return real_val; | ||||
| } | } | ||||
| public override void AddInnerOp(Operation resultOp) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| protected override void _AddOpInternal(Operation op) | |||||
| { | |||||
| if (op.inputs.Length == 0) | |||||
| { | |||||
| //If we're in a while loop, remove any control inputs from outside the | |||||
| // loop. | |||||
| _RemoveExternalControlEdges(op); | |||||
| if (!op.control_inputs.Any(input_op => OpInContext(input_op))) | |||||
| op._add_control_input(_pivot.op); | |||||
| } | |||||
| else | |||||
| { | |||||
| // Make each input to 'op' available in this CondContext. If an input is | |||||
| // already part of this context there's nothing to do, but if it's | |||||
| // external, AddValue() will handle adding the appropriate Switch node and | |||||
| // other bookkeeping. | |||||
| for (int index = 0; index < op.inputs.Length; index++) | |||||
| { | |||||
| var x = op.inputs[index]; | |||||
| Tensor real_x = null; | |||||
| if (op.type == "Merge" && x.op.type == "NextIteration") | |||||
| { | |||||
| //# Edge case: if we're importing a while loop inside this CondContext, | |||||
| //# AddValue() will not correctly handle the NextIteration inputs to | |||||
| //# Merge node. The problem is that the NextIteration should also be | |||||
| //# part of this context, but if we're importing it won't have been | |||||
| //# processed and added to the context yet, so AddValue() will try to | |||||
| //# add a Switch which results in an invalid graph. Instead, we use the | |||||
| //# NextIteration input as-is here, and it will eventually be added to | |||||
| //# the context via AddOp(). | |||||
| real_x = x; | |||||
| } | |||||
| else | |||||
| { | |||||
| real_x = AddValue(x); | |||||
| } | |||||
| if (real_x != x) | |||||
| op._update_input(index, real_x); | |||||
| } | |||||
| // Remove any external control dependency on this op. | |||||
| _RemoveExternalControlEdges(op); | |||||
| // TODO: implement below code dependencies | |||||
| //if (op.graph._is_function(op.type) || op.type == "SymbolicGradient") | |||||
| // op._add_control_input(_pivot.op); | |||||
| } | |||||
| // Mark op's outputs as seen by this context and any outer contexts. | |||||
| var output_names = op.outputs.Select(x => x.name).ToArray(); | |||||
| IControlFlowContext ctxt = this; | |||||
| while (ctxt != null) | |||||
| { | |||||
| foreach (var name in output_names) | |||||
| ctxt.values.Add(name); | |||||
| ctxt = ctxt.outer_context; | |||||
| } | |||||
| if (_outer_context != null || !control_flow_ops.IsLoopExit(op)) | |||||
| op.graph.prevent_fetching(op); | |||||
| if (_outer_context != null) | |||||
| _outer_context.AddInnerOp(op); | |||||
| } | } | ||||
| public CondContextDef to_proto(string export_scope) | public CondContextDef to_proto(string export_scope) | ||||
| @@ -119,9 +119,14 @@ namespace Tensorflow.Operations | |||||
| return null; | return null; | ||||
| } | } | ||||
| public virtual void AddInnerOp(Operation resultOp) | |||||
| /// <summary> | |||||
| /// Notifies a scope about an operator added to an inner scope. | |||||
| /// </summary> | |||||
| /// <param name="op"></param> | |||||
| public virtual void AddInnerOp(Operation op) | |||||
| { | { | ||||
| // to be overridden | |||||
| if (_outer_context != null) | |||||
| _outer_context.AddInnerOp(op); | |||||
| } | } | ||||
| protected HashSet<string> _values = new HashSet<string>(); | protected HashSet<string> _values = new HashSet<string>(); | ||||
| @@ -131,68 +136,10 @@ namespace Tensorflow.Operations | |||||
| /// </summary> | /// </summary> | ||||
| protected virtual void _AddOpInternal(Operation op) | protected virtual void _AddOpInternal(Operation op) | ||||
| { | { | ||||
| if (op.inputs.Length == 0) | |||||
| { | |||||
| //If we're in a while loop, remove any control inputs from outside the | |||||
| // loop. | |||||
| _RemoveExternalControlEdges(op); | |||||
| if (!op.control_inputs.Any(input_op => OpInContext(input_op))) | |||||
| op._add_control_input(_pivot.op); | |||||
| } | |||||
| else | |||||
| { | |||||
| // Make each input to 'op' available in this CondContext. If an input is | |||||
| // already part of this context there's nothing to do, but if it's | |||||
| // external, AddValue() will handle adding the appropriate Switch node and | |||||
| // other bookkeeping. | |||||
| for (int index = 0; index < op.inputs.Length; index++) | |||||
| { | |||||
| var x = op.inputs[index]; | |||||
| Tensor real_x = null; | |||||
| if (op.type == "Merge" && x.op.type == "NextIteration") | |||||
| { | |||||
| //# Edge case: if we're importing a while loop inside this CondContext, | |||||
| //# AddValue() will not correctly handle the NextIteration inputs to | |||||
| //# Merge node. The problem is that the NextIteration should also be | |||||
| //# part of this context, but if we're importing it won't have been | |||||
| //# processed and added to the context yet, so AddValue() will try to | |||||
| //# add a Switch which results in an invalid graph. Instead, we use the | |||||
| //# NextIteration input as-is here, and it will eventually be added to | |||||
| //# the context via AddOp(). | |||||
| real_x = x; | |||||
| } | |||||
| else | |||||
| { | |||||
| real_x = AddValue(x); | |||||
| } | |||||
| if (real_x != x) | |||||
| op._update_input(index, real_x); | |||||
| } | |||||
| // Remove any external control dependency on this op. | |||||
| _RemoveExternalControlEdges(op); | |||||
| // TODO: implement below code dependencies | |||||
| //if (op.graph._is_function(op.type) || op.type == "SymbolicGradient") | |||||
| // op._add_control_input(_pivot.op); | |||||
| } | |||||
| // Mark op's outputs as seen by this context and any outer contexts. | |||||
| var output_names = op.outputs.Select(x => x.name).ToArray(); | |||||
| IControlFlowContext ctxt = this; | |||||
| while (ctxt != null) | |||||
| { | |||||
| foreach(var name in output_names) | |||||
| ctxt.values.Add(name); | |||||
| ctxt = ctxt.outer_context; | |||||
| } | |||||
| if (_outer_context != null || !control_flow_ops.IsLoopExit(op)) | |||||
| op.graph.prevent_fetching(op); | |||||
| if (_outer_context != null) | |||||
| _outer_context.AddInnerOp(op); | |||||
| } | } | ||||
| private bool OpInContext(Operation op) | |||||
| protected bool OpInContext(Operation op) | |||||
| { | { | ||||
| return IsContainingContext(op._get_control_flow_context(), this); | return IsContainingContext(op._get_control_flow_context(), this); | ||||
| } | } | ||||
| @@ -23,7 +23,8 @@ namespace Tensorflow | |||||
| return with(ops.name_scope(name, "l2_normalize", new { x }), scope => | return with(ops.name_scope(name, "l2_normalize", new { x }), scope => | ||||
| { | { | ||||
| x = ops.convert_to_tensor(x, name: "x"); | x = ops.convert_to_tensor(x, name: "x"); | ||||
| var square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims: true); | |||||
| var sq = math_ops.square(x); | |||||
| var square_sum = math_ops.reduce_sum(sq, axis, keepdims: true); | |||||
| var x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon)); | var x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon)); | ||||
| return math_ops.multiply(x, x_inv_norm, name: name); | return math_ops.multiply(x, x_inv_norm, name: name); | ||||
| }); | }); | ||||
| @@ -360,6 +360,8 @@ namespace Tensorflow | |||||
| /// <returns>The default `Session` being used in the current thread.</returns> | /// <returns>The default `Session` being used in the current thread.</returns> | ||||
| public static Session get_default_session() | public static Session get_default_session() | ||||
| { | { | ||||
| if (tf.defaultSession == null) | |||||
| tf.defaultSession = tf.Session(); | |||||
| return tf.defaultSession; | return tf.defaultSession; | ||||
| } | } | ||||
| @@ -143,10 +143,7 @@ namespace TensorFlowNET.UnitTest | |||||
| // return self._eval_helper(tensors) | // return self._eval_helper(tensors) | ||||
| // else: | // else: | ||||
| { | { | ||||
| var sess = ops.get_default_session(); | |||||
| if (sess == null) | |||||
| sess = self.session(); | |||||
| with<Session>(sess, s => | |||||
| with(ops.get_default_session(), s => | |||||
| { | { | ||||
| var ndarray=tensor.eval(); | var ndarray=tensor.eval(); | ||||
| if (typeof(T) == typeof(double)) | if (typeof(T) == typeof(double)) | ||||