@@ -41,11 +41,27 @@ namespace Tensorflow
{
var op_name = Marshal.PtrToStringAnsi(c_api.TF_OperationName(tf_oper));
return _get_operation_by_name_unsafe(op_name);
}
}
/// <summary>
/// Creates an `Operation` in this graph from the supplied TF_Operation.
///
/// This method is like create_op() except the new Operation is constructed
/// using `c_op`. The returned Operation will have `c_op` as its _c_op
/// field. This is used to create Operation objects around TF_Operations created
/// indirectly by the C API (e.g. by TF_ImportGraphDef, TF_FinishWhile).
///
/// This function does not call Operation._control_flow_post_processing or
/// Graph._control_dependencies_for_inputs (since the inputs may not be
/// available yet). The caller is responsible for calling these methods.
/// </summary>
/// <param name="c_op">a wrapped TF_Operation</param>
/// <param name="compute_device">(Optional.) If True, device functions will be executed
/// to compute the device property of the Operation.</param>
/// <returns>An `Operation` object.</returns>
public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true)
{
var ret = new Operation(c_op);
var ret = new Operation(c_op, this);
_add_op(ret);
var name_key = ret.name.ToLower();
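
The doc comment above puts the post-processing burden on the caller. A rough usage sketch of that calling pattern (not part of the diff; `graph` and `imported_handles` are hypothetical, the latter standing for TF_Operation pointers produced by the C API, e.g. via TF_ImportGraphDef):

    foreach (IntPtr c_op in imported_handles)
    {
        // Wrap the C-API operation in a managed Operation bound to this graph.
        var op = graph._create_op_from_tf_operation(c_op);
        // The constructor skips control flow post-processing; once the op's
        // inputs are available the caller must trigger it explicitly.
        op._control_flow_post_processing();
    }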

@@ -16,6 +16,7 @@ namespace Tensorflow.Operations
/// The boolean tensor for the cond predicate
/// </summary>
private Tensor _pred;
public Tensor pred => _pred;
/// <summary>

@@ -23,11 +24,6 @@ namespace Tensorflow.Operations
/// </summary>
private int _branch;
/// <summary>
///
/// </summary>
private List<string> _values = new List<string>();
private Dictionary<string, Tensor> _external_values = new Dictionary<string, Tensor>();
/// <summary>

@@ -66,72 +62,166 @@ namespace Tensorflow.Operations
}
/// <summary>
/// Add the subgraph defined by fn() to the graph.
/// Add `val` to the current context and its outer context recursively.
/// </summary>
public (T, Tensor) BuildCondBranch<T>(Func<T> fn)
/// <param name="val"></param>
public override Tensor AddValue(Tensor val)
{
// Add the subgraph defined by fn() to the graph.
var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
var original_result = fn();
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
Tensor result = null;
if (_values.Contains(val.name))
{
// Use the real value if it comes from outer context. This is needed in
// particular for nested conds.
if (_external_values.ContainsKey(val.name))
result = _external_values[val.name];
else
result = val;
}
else
{
result = val;
_values.Add(val.name);
// TODO: _outer_context
if (_outer_context != null)
{
result = _outer_context.AddValue(val);
_values.Add(result.name);
_external_values[result.name] = result;
}
// TODO: how to do 'with' here??
//with(ops.control_dependencies(null), ctrl =>
//{
var (r0, r1) = control_flow_ops._SwitchRefOrTensor(result, _pred);
result = new[] { r0, r1 }[_branch];
if (_outer_context != null)
_outer_context.AddInnerOp(result.op);
//});
// TODO: port this chunk of missing code:
/*
if len(post_summaries) > len(pre_summaries):
new_summaries = post_summaries[len(pre_summaries):]
summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
summary_ref[:] = pre_summaries
with ops.control_dependencies(new_summaries):
if original_result is None:
return no_op(), None
else:
original_result = nest.map_structure(array_ops.identity,
original_result)
*/
if (original_result == null)
return (original_result, null);
result.op.graph.prevent_fetching(result.op);
result.op._set_control_flow_context(this);
switch (original_result)
{
case Tensor result:
return (original_result, _BuildCondTensor(new[] { result.op }));
case Operation[] results:
return (original_result, _BuildCondTensor(results));
case float[] fv:
// Mark Switch output as seen by this context and any outer contexts,
// just like what we do for normal op outputs in _AddOpInternal() below.
IControlFlowContext ctxt = this;
while (ctxt != null)
{
var result = ops.convert_to_tensor(fv[0]);
return (original_result, result);
ctxt.values.Add(result.name);
ctxt = ctxt.outer_context;
}
default:
return (original_result, null);
_external_values[val.name] = result;
}
}
public (T[], Tensor[]) BuildCondBranch<T>(Func<T[]> fn)
return result;
}
/// <summary>
/// Add the subgraph defined by fn() to the graph.
/// </summary>
public (T, Tensor) BuildCondBranch<T>(Func<T> fn)
{
// Add the subgraph defined by fn() to the graph.
var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
var original_result = fn();
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
// TODO: port this chunk of missing code:
/*
if len(post_summaries) > len(pre_summaries):
new_summaries = post_summaries[len(pre_summaries):]
summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
summary_ref[:] = pre_summaries
with ops.control_dependencies(new_summaries):
if original_result is None:
return no_op(), None
else:
original_result = nest.map_structure(array_ops.identity,
original_result)
*/
if (original_result == null)
return (original_result, null);
switch (original_result)
{
case Tensor result:
return (original_result, _BuildCondTensor(result));
case Operation op:
return (original_result, _BuildCondTensor(op));
case float[] fv:
{
var result = ops.convert_to_tensor(fv[0]);
return (original_result, _BuildCondTensor(result));
}
default:
return (original_result, null);
}
}
public (T[], Tensor[]) BuildCondBranch<T>(Func<T[]> fn)
{
// Add the subgraph defined by fn() to the graph.
var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
var original_result = fn();
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
switch (original_result)
{
case Tensor[] results:
return (original_result, results.Select(_BuildCondTensor).ToArray());
case Operation[] results:
return (original_result, results.Select(_BuildCondTensor).ToArray());
case float[] fv:
var result = ops.convert_to_tensor(fv[0]);
return (original_result, new Tensor[] { result });
default:
return (original_result, new Tensor[0]);
}
}
private Tensor _BuildCondTensor(ITensorOrOperation v)
{
switch (v)
{
case Operation op:
// Use pivot as the proxy for this op.
return control_flow_ops.with_dependencies(new Operation[] { op }, _pivot);
case Tensor t:
return _ProcessOutputTensor(t);
default:
return _ProcessOutputTensor(ops.convert_to_tensor(v));
}
}
/// <summary>
/// Process an output tensor of a conditional branch.
/// </summary>
private Tensor _ProcessOutputTensor(Tensor val)
{
// Add the subgraph defined by fn() to the graph.
var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
var original_result = fn();
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
switch (original_result)
var real_val = val;
if (!_values.Contains(val.name))
{
case Tensor[] results:
return (original_result, new Tensor[] { _BuildCondTensor(results.Select(t => t.op).ToArray()) });
case Operation[] results:
return (original_result, new Tensor[] { _BuildCondTensor(results) });
case float[] fv:
var result = ops.convert_to_tensor(fv[0]);
return (original_result, new Tensor[] { result });
default:
return (original_result, new Tensor[0]);
// Handle the special case of lambda: x
_values.Add(val.name);
if (_outer_context != null)
{
real_val = _outer_context.AddValue(val);
_values.Add(real_val.name);
_external_values[real_val.name] = real_val;
}
}
else
{
Tensor external_val = null;
if (_external_values.ContainsKey(val.name))
external_val = _external_values[val.name];
if (external_val != null)
real_val = external_val;
}
return real_val;
}
private Tensor _BuildCondTensor(Operation[] v)
public override void AddInnerOp(Operation resultOp)
{
// Use pivot as the proxy for this op.
return control_flow_ops.with_dependencies(v, _pivot);
throw new NotImplementedException();
}
}
}
}
}
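
To see how the pieces above compose, here is a rough sketch of the way a conditional is expected to be assembled from them (illustrative only; `context_true` and `context_false` stand for two CondContext instances built elsewhere with branch 1 and branch 0 respectively, and `true_fn`/`false_fn` are user-supplied branch functions returning a Tensor):

    // Each branch function runs against its own CondContext; the branch result is
    // routed through the context's Switch handling (_BuildCondTensor / _ProcessOutputTensor).
    var (orig_true, res_true) = context_true.BuildCondBranch(true_fn);
    var (orig_false, res_false) = context_false.BuildCondBranch(false_fn);
    // Merging the two gated outputs produces the value of the conditional.
    var (merged, value_index) = gen_control_flow_ops.merge(new Tensor[] { res_false, res_true });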

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
namespace Tensorflow.Operations

@@ -29,6 +30,8 @@ namespace Tensorflow.Operations
protected Tensor _pivot;
protected Stack<IControlFlowContext> _context_stack;
protected IControlFlowContext _outer_context;
public ControlFlowContext()
{
_context_stack = new Stack<IControlFlowContext>();

@@ -69,23 +72,114 @@ namespace Tensorflow.Operations
graph._set_control_flow_context(last_context);
}
/// <summary>
/// Add `op` to the current context.
/// </summary>
public void AddOp(Operation op)
{
_AddOpInternal(op);
}
public IControlFlowContext outer_context { get { return _outer_context; } }
public HashSet<string> values => _values;
public virtual Tensor AddValue(Tensor val)
{
// to be overridden
return null;
}
public virtual void AddInnerOp(Operation resultOp)
{
// to be overridden
}
protected HashSet<string> _values = new HashSet<string>();
/// <summary>
/// Add `op` to the current context.
/// </summary>
protected virtual void _AddOpInternal(Operation op)
{
if(op.inputs.Length == 0)
if (op.inputs.Length == 0)
{
// If we're in a while loop, remove any control inputs from outside the
// loop.
_RemoveExternalControlEdges(op);
op._add_control_input(_pivot.op);
if (!op.control_inputs.Any(input_op => OpInContext(input_op)))
op._add_control_input(_pivot.op);
}
else
{
// Make each input to 'op' available in this CondContext. If an input is
// already part of this context there's nothing to do, but if it's
// external, AddValue() will handle adding the appropriate Switch node and
// other bookkeeping.
for (int index = 0; index < op.inputs.Length; index++)
{
var x = op.inputs[index];
Tensor real_x = null;
if (op.type == "Merge" && x.op.type == "NextIteration")
{
//# Edge case: if we're importing a while loop inside this CondContext,
//# AddValue() will not correctly handle the NextIteration inputs to
//# Merge node. The problem is that the NextIteration should also be
//# part of this context, but if we're importing it won't have been
//# processed and added to the context yet, so AddValue() will try to
//# add a Switch which results in an invalid graph. Instead, we use the
//# NextIteration input as-is here, and it will eventually be added to
//# the context via AddOp().
real_x = x;
}
else
{
real_x = AddValue(x);
}
if (real_x != x)
op._update_input(index, real_x);
}
// Remove any external control dependency on this op.
_RemoveExternalControlEdges(op);
// TODO: implement below code dependencies
//if (op.graph._is_function(op.type) || op.type == "SymbolicGradient")
//    op._add_control_input(_pivot.op);
}
// Mark op's outputs as seen by this context and any outer contexts.
var output_names = op.outputs.Select(x => x.name).ToArray();
IControlFlowContext ctxt = this;
while (ctxt != null)
{
foreach (var name in output_names)
ctxt.values.Add(name);
ctxt = ctxt.outer_context;
}
if (_outer_context != null || !control_flow_ops.IsLoopExit(op))
op.graph.prevent_fetching(op);
if (_outer_context != null)
_outer_context.AddInnerOp(op);
}
private bool OpInContext(Operation op)
{
return IsContainingContext(op._get_control_flow_context(), this);
}
/// <summary>
/// Returns true if `maybe_containing_ctxt` is or contains `ctxt`.
/// </summary>
public static bool IsContainingContext(IControlFlowContext ctxt, ControlFlowContext maybe_containing_ctxt)
{
while (ctxt != maybe_containing_ctxt)
{
if (ctxt == null)
return false;
ctxt = ctxt.outer_context;
}
}
return true;
}
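
For intuition, a small sketch of how this containment test behaves for nested contexts (illustrative only; `outer` and `inner` are hypothetical contexts where `inner` was created inside `outer`, i.e. inner.outer_context == outer):

    ControlFlowContext.IsContainingContext(inner, outer);  // true: walking inner's outer_context chain reaches outer
    ControlFlowContext.IsContainingContext(outer, inner);  // false: outer's chain ends at null without meeting inner
    ControlFlowContext.IsContainingContext(null, inner);   // false: an op with no context lives outside every context

OpInContext() above uses exactly this test to decide whether an existing control input already lives in (or inside) the current context, in which case the pivot control edge is not added.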
protected virtual void _RemoveExternalControlEdges(Operation op)
{

@@ -7,5 +7,9 @@ namespace Tensorflow
public interface IControlFlowContext
{
void AddOp(Operation op);
IControlFlowContext outer_context { get; }
HashSet<string> values { get; }
Tensor AddValue(Tensor val);
void AddInnerOp(Operation resultOp);
}
}

@@ -7,16 +7,20 @@ namespace Tensorflow
{
public partial class Operation
{
private IControlFlowContext _control_flow_context;
/// <summary>
/// Add this op to its control flow context.
///
/// This may add new ops and change this op's inputs. The op's inputs must be
/// available before calling this method.
/// </summary>
public void _control_flow_post_processing()
{
foreach (var input_tensor in inputs)
{
//TODO: implement below code dependency
//control_flow_util.CheckInputFromValidContext(this, input_tensor.op);
}
if (_control_flow_context != null)

@@ -62,16 +62,22 @@ namespace Tensorflow
}
}
public Operation(IntPtr handle)
public Operation(IntPtr handle, Graph g = null)
{
if (handle == IntPtr.Zero)
return;
_handle = handle;
_graph = ops.get_default_graph();
_graph = g ?? ops.get_default_graph();
_outputs = new Tensor[NumOutputs];
for (int i = 0; i < NumOutputs; i++)
_outputs[i] = new Tensor(this, i, OutputType(i));
// Record the control flow context the op is being created in.
_control_flow_context = graph._get_control_flow_context();
// Note: _control_flow_post_processing() must not be called here; the caller is responsible for calling it when using this constructor.
}
public Operation(Graph g, string opType, string oper_name)

@@ -81,6 +87,10 @@ namespace Tensorflow
_operDesc = c_api.TF_NewOperation(g, opType, oper_name);
c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32);
_handle = c_api.TF_FinishOperation(_operDesc, status);
// Record the control flow context the op is being created in.
_control_flow_context = graph._get_control_flow_context();
}
/// <summary>

@@ -258,6 +268,23 @@ namespace Tensorflow
}
return base.Equals(obj);
}
/// <summary>
/// Update the input to this operation at the given index.
///
/// NOTE: This is for TF internal use only. Please don't use it.
/// </summary>
/// <param name="index">the index of the input to update.</param>
/// <param name="tensor">the Tensor to be used as the input at the given index.</param>
public void _update_input(int index, Tensor tensor)
{
throw new NotImplementedException("_update_input");
// TODO: implement below code dependencies
//_assert_same_graph(tensor);
//// Reset cached inputs.
//_inputs_val = null;
//c_api.UpdateEdge(_graph._c_graph, tensor._as_tf_output(), _tf_input(index));
}
}
}

@@ -111,7 +111,7 @@ namespace Tensorflow
return loop_state;
}
private static bool IsLoopExit(Operation op)
public static bool IsLoopExit(Operation op)
{
return op.OpType == "Exit" || op.OpType == "RefExit";
}

@@ -193,20 +193,49 @@ namespace Tensorflow
return gen_array_ops.identity(data, name: name);
}
/// <summary>
/// Forwards `data` to an output determined by `pred`.
/// </summary>
/// <param name="data"></param>
/// <param name="pred"></param>
/// <param name="name"></param>
/// <returns></returns>
/// <summary>
/// Forwards `data` to an output determined by `pred`.
/// If `pred` is false, the `data` input is forwarded to the first output.
/// Otherwise, the data goes to the second output.
///
/// This op handles `Tensor`s and `IndexedSlices`.
/// </summary>
/// <param name="data">The tensor to be forwarded to the appropriate output.</param>
/// <param name="pred">A scalar that specifies which output port will receive data.</param>
/// <param name="name">A name for this operation (optional).</param>
/// <returns>
/// `(output_false, output_true)`: If `pred` is true, data will be forwarded to
/// `output_true`, otherwise it goes to `output_false`.
/// </returns>
public static (Tensor, Tensor) _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch")
{
data = ops.convert_to_tensor_or_indexed_slices(data, name: "data");
// NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below
// addresses the following scenario.
//
// Assume you execute Optimizer.apply_gradients() in a branch of a cond().
//
// 1. The update op is created inside a `with ops.colocate(var):` block
//
// 2. Some tensor `data` is captured and a switch is created in a
// `with ops.colocate_with(data):` block.
//
// with ops.colocate_with(var):
//   with ops.colocate_with(data):
//     op = ...
//
// var and data may be pinned to different devices, so we want the ops
// created within ops.colocate_with(data) to ignore the existing stack.
ops.colocate_with(data, ignore_existing: true);
return @switch(data, pred, name: name);
{
if (data is Tensor)
{
// TODO: ref_switch
//if (data.dtype._is_ref_dtype)
//    return control_flow_ops.ref_switch(data, pred, name = name);
}
return @switch(data, pred, name: name);
}
}
/// <summary>

@@ -483,6 +512,8 @@ namespace Tensorflow
}
throw new NotImplementedException("ZerosLikeOutsideLoop");
}
}
}
}

@@ -13,13 +13,35 @@ namespace Tensorflow
var _op = _op_def_lib._apply_op_helper("NoOp", name, null);
return _op;
}
}
/// <summary>
/// Forwards `data` to the output port determined by `pred`.
///
/// If `pred` is true, the `data` input is forwarded to `output_true`. Otherwise,
/// the data goes to `output_false`.
///
/// See also `RefSwitch` and `Merge`.
/// </summary>
/// <param name="data">A `Tensor`. The tensor to be forwarded to the appropriate output.</param>
/// <param name="pred">A `Tensor` of type `bool`.
/// A scalar that specifies which output port will receive data.
/// </param>
/// <param name="name">A name for the operation (optional).</param>
/// <returns>A tuple of `Tensor` objects (output_false, output_true).
///
/// output_false: A `Tensor`. Has the same type as `data`.
/// output_true: A `Tensor`. Has the same type as `data`.
/// </returns>
public static (Tensor, Tensor) @switch(Tensor data, Tensor pred, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Switch", name, new { data, pred });
return (_op.outputs[0], _op.outputs[1]);
var _result = (_op.outputs[0], _op.outputs[1]);
var _inputs_flat = _op.inputs;
var _attrs = ("T", _op.get_attr("T"));
// TODO: missing original code
//_execute.record_gradient("Switch", _inputs_flat, _attrs, _result, name);
return _result;
}
public static (Tensor, Tensor) merge(Tensor[] inputs, string name = null)
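
A minimal sketch of the Switch/Merge pairing these kernels implement (illustrative only; `data` and the scalar bool tensor `pred` are assumed to already exist in the current graph, and identity stands in for real branch bodies):

    var (output_false, output_true) = gen_control_flow_ops.@switch(data, pred);
    // Only the output selected by pred carries a value at run time; each branch
    // subgraph is built on top of "its" output.
    var branch_false = gen_array_ops.identity(output_false);
    var branch_true = gen_array_ops.identity(output_true);
    // Merge forwards whichever branch actually ran; this is how cond() joins the branches.
    var (result, value_index) = gen_control_flow_ops.merge(new Tensor[] { branch_false, branch_true });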

@@ -84,6 +84,10 @@ namespace TensorFlowNET.UnitTest.ops_test
control_flow_ops.cond(x < 10, true_fn, () => x);
var op = g.get_operation_by_name("cond/myop");
tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta.txt", as_text: true);
tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false);
self.assertIsNotNone(op);
self.assertEqual(op.name, "cond/myop");
self.assertEqual(op.type, "Identity");