@@ -9,7 +9,7 @@
[Documentation Status](https://tensorflownet.readthedocs.io/en/latest/?badge=latest)
[996.icu](https://996.icu/#/en_US)
TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp). <a href="http://scisharpstack.org"><img src="https://github.com/SciSharp/SciSharp/blob/master/art/scisharp_badge.png" width="200" height="200" align="right" /></a>
TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp).

| @@ -26,6 +26,13 @@ In comparison to other projects, like for instance TensorFlowSharp which only pr | |||
| ### How to use | |||
| TensorFlow  | tf 1.13 | tf 1.14 | tf 1.15 | tf 2.0 |
| ----------- | ------- | ------- | ------- | ------ |
| tf.net 0.12 |         | x       |         |        |
| tf.net 0.11 | x       | x       |         |        |
| tf.net 0.10 | x       | x       |         |        |
| tf.net 0.9  | x       |         |         |        |
Install TF.NET and the TensorFlow binary through NuGet.
```sh
### install tensorflow C# binding
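### (the package names below are assumed from the project's NuGet listing)
PM> Install-Package TensorFlow.NET
### install tensorflow binary (CPU version)
PM> Install-Package SciSharp.TensorFlow.Redist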
| @@ -37,7 +37,7 @@ namespace Tensorflow | |||
| public Operation group<T>(T[] inputs, string name = null) where T : ITensorOrOperation | |||
| => control_flow_ops.group(inputs, name: name); | |||
| public Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars, | |||
| /*public Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars, | |||
| TensorShape shape_invariants = null, | |||
| int parallel_iterations = 10, | |||
| bool back_prop = true, | |||
| @@ -52,7 +52,7 @@ namespace Tensorflow | |||
| swap_memory: swap_memory, | |||
| name: name, | |||
| maximum_iterations: maximum_iterations, | |||
| return_same_structure: return_same_structure); | |||
| return_same_structure: return_same_structure);*/ | |||
| public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) | |||
| => ops.control_dependencies(control_inputs); | |||
| @@ -63,7 +63,7 @@ namespace Tensorflow | |||
| trainable: trainable, | |||
| name: name); | |||
| return layer.apply(inputs); | |||
| return layer.apply(inputs).Item1; | |||
| } | |||
| /// <summary> | |||
| @@ -117,7 +117,7 @@ namespace Tensorflow | |||
| trainable: trainable, | |||
| name: name); | |||
| return layer.apply(inputs, training: training); | |||
| return layer.apply(inputs, training: training).Item1; | |||
| } | |||
| /// <summary> | |||
| @@ -143,7 +143,7 @@ namespace Tensorflow | |||
| data_format: data_format, | |||
| name: name); | |||
| return layer.apply(inputs); | |||
| return layer.apply(inputs).Item1; | |||
| } | |||
| /// <summary> | |||
| @@ -179,7 +179,7 @@ namespace Tensorflow | |||
| kernel_initializer: kernel_initializer, | |||
| trainable: trainable); | |||
| return layer.apply(inputs); | |||
| return layer.apply(inputs).Item1; | |||
| } | |||
| /// <summary> | |||
| @@ -76,7 +76,7 @@ namespace Tensorflow | |||
| /// <param name="swap_memory"></param> | |||
| /// <param name="time_major"></param> | |||
| /// <returns>A pair (outputs, state)</returns> | |||
| public (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs, | |||
| public (Tensor, Tensor) dynamic_rnn(RnnCell cell, Tensor inputs, | |||
| Tensor sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid, | |||
| int? parallel_iterations = null, bool swap_memory = false, bool time_major = false) | |||
| => rnn.dynamic_rnn(cell, inputs, sequence_length: sequence_length, dtype: dtype, | |||
| @@ -134,7 +134,7 @@ namespace Tensorflow | |||
| => nn_ops.max_pool(value, ksize, strides, padding, data_format: data_format, name: name); | |||
| public Tensor in_top_k(Tensor predictions, Tensor targets, int k, string name = "InTopK") | |||
| => gen_ops.in_top_k(predictions, targets, k, name); | |||
| => nn_ops.in_top_k(predictions, targets, k, name); | |||
| public Tensor[] top_k(Tensor input, int k = 1, bool sorted = true, string name = null) | |||
| => gen_nn_ops.top_kv2(input, k: k, sorted: sorted, name: name); | |||
| @@ -30,6 +30,20 @@ namespace Tensorflow | |||
| /// </summary> | |||
| public static partial class Binding | |||
| { | |||
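// Python-style dict.get(): returns default(T2) when the key is null or absent.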
| public static T2 get<T1, T2>(this Dictionary<T1, T2> dict, T1 key) | |||
| => key == null ? | |||
| default(T2) : | |||
| (dict.ContainsKey(key) ? dict[key] : default(T2)); | |||
| public static void add<T>(this IList<T> list, T element) | |||
| => list.Add(element); | |||
| public static void append<T>(this IList<T> list, T element) | |||
| => list.Add(element); | |||
| public static void extend<T>(this List<T> list, IEnumerable<T> elements) | |||
| => list.AddRange(elements); | |||
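// Example (hypothetical, not part of this diff) of how these helpers let ported
// Python code keep its original shape:
//   var d = new Dictionary<string, int>();
//   var v = d.get("missing");                 // 0 (default) instead of a KeyNotFoundException
//   var l = new List<int>(); l.append(1); l.extend(new[] { 2, 3 });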
| private static string _tostring(object obj) | |||
| { | |||
| switch (obj) | |||
| @@ -81,6 +95,9 @@ namespace Tensorflow | |||
| throw new NotImplementedException("len() not implemented for type: " + a.GetType()); | |||
| } | |||
| public static T[] list<T>(IEnumerable<T> list) | |||
| => list.ToArray(); | |||
| public static IEnumerable<int> range(int end) | |||
| { | |||
| return Enumerable.Range(0, end); | |||
| @@ -165,6 +182,12 @@ namespace Tensorflow | |||
| yield return (t1[i], t2[i]); | |||
| } | |||
| public static IEnumerable<(T1, T2, T3)> zip<T1, T2, T3>(IList<T1> t1, IList<T2> t2, IList<T3> t3) | |||
| { | |||
| for (int i = 0; i < t1.Count; i++) | |||
| yield return (t1[i], t2[i], t3[i]); | |||
| } | |||
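// e.g. zip(new[] { 1, 2 }, new[] { "a", "b" }, new[] { 0.1, 0.2 }) yields (1, "a", 0.1) and (2, "b", 0.2).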
| public static IEnumerable<(T1, T2)> zip<T1, T2>(NDArray t1, NDArray t2) | |||
| where T1: unmanaged | |||
| where T2: unmanaged | |||
| @@ -203,6 +226,7 @@ namespace Tensorflow | |||
| yield return (i, values[i]); | |||
| } | |||
| [DebuggerStepThrough] | |||
| public static Dictionary<string, object> ConvertToDict(object dyn) | |||
| { | |||
| var dictionary = new Dictionary<string, object>(); | |||
| @@ -0,0 +1,32 @@ | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Runtime.InteropServices; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class c_api | |||
| { | |||
| /// <summary> | |||
| /// Specify the device for `desc`. Defaults to empty, meaning unconstrained. | |||
| /// </summary> | |||
| /// <param name="desc"></param> | |||
| /// <param name="device"></param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_SetDevice(IntPtr desc, string device); | |||
| } | |||
| } | |||
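For orientation, a minimal usage sketch (assuming the `TF_NewOperation`/`TF_FinishOperation` bindings exposed elsewhere in this `c_api` partial class; the device string follows TensorFlow's `/device:CPU:0` convention):

```csharp
// Hypothetical sketch, not part of this diff: pin an op to a device
// before the operation description is finalized.
// `graph` and `status` are assumed pre-created native handles.
var desc = c_api.TF_NewOperation(graph, "Const", "my_const");
c_api.TF_SetDevice(desc, "/device:CPU:0");
var op = c_api.TF_FinishOperation(desc, status);
```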
| @@ -45,7 +45,24 @@ namespace Tensorflow.Gradients | |||
| switch (op_ctxt) | |||
| { | |||
| case WhileContext cwhile: | |||
| throw new NotImplementedException("_SwitchGrad WhileContext"); | |||
| { | |||
| var merge_grad = grad_ctxt.grad_state.switch_map.get(op); | |||
| if (merge_grad != null) | |||
| { | |||
| if (grads[1] != null) | |||
| control_flow_ops._AddNextAndBackEdge(merge_grad, grads[1], | |||
| enforce_shape_invariant: false); | |||
| return new Tensor[] { null, null }; | |||
| } | |||
| else if (grads[0] != null) | |||
| { | |||
| merge_grad = merge(new[] { grads[0], grads[0] }, name: "b_switch")[0]; | |||
| grad_ctxt.grad_state.switch_map[op] = merge_grad; | |||
| return new Tensor[] { merge_grad, null }; | |||
| } | |||
| else | |||
| return new Tensor[] { null, null }; | |||
| } | |||
| case CondContext ccond: | |||
| { | |||
| var zero_grad = grads[1 - op_ctxt.branch]; | |||
| @@ -74,7 +91,7 @@ namespace Tensorflow.Gradients | |||
| /// <param name="inputs"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| internal static Tensor[] merge(Tensor[] inputs, string name = null) | |||
| internal static MergeOutput merge(Tensor[] inputs, string name = null) | |||
| { | |||
| return tf_with(ops.name_scope(name, "Merge", inputs), scope => | |||
| { | |||
| @@ -146,7 +163,7 @@ namespace Tensorflow.Gradients | |||
| } | |||
| [RegisterGradient("RefMerge")] | |||
| public Tensor[] _RefMergeGrad(Operation op, Tensor[] grads) | |||
| public static Tensor[] _RefMergeGrad(Operation op, Tensor[] grads) | |||
| { | |||
| return _MergeGrad(op, grads); | |||
| } | |||
| @@ -155,43 +172,32 @@ namespace Tensorflow.Gradients | |||
| /// Gradients for an exit op are calculated using an Enter op. | |||
| /// </summary> | |||
| [RegisterGradient("Exit")] | |||
| public Tensor[] _ExitGrad(Operation op, Tensor[] grads) | |||
| public static Tensor[] _ExitGrad(Operation op, Tensor[] grads) | |||
| { | |||
| throw new NotImplementedException("_ExitGrad"); | |||
| // graph = ops.get_default_graph() | |||
| //# pylint: disable=protected-access | |||
| // op_ctxt = op._get_control_flow_context() | |||
| // grad_ctxt = graph._get_control_flow_context() | |||
| // # pylint: enable=protected-access | |||
| // if not grad_ctxt.back_prop: | |||
| // # The flag `back_prop` is set by users to suppress gradient | |||
| // # computation for this loop. If the attribute `back_prop` is false, | |||
| // # no gradient computation. | |||
| // return None | |||
| var grad = grads[0]; | |||
| var graph = ops.get_default_graph(); | |||
| var op_ctxt = op._get_control_flow_context(); | |||
| var grad_ctxt = graph._get_control_flow_context() as WhileContext; | |||
| // The flag `back_prop` is set by users to suppress gradient | |||
| // computation for this loop. If the attribute `back_prop` is false, | |||
// no gradient is computed.
| if (!grad_ctxt.back_prop) | |||
| return null; | |||
| // if op_ctxt.grad_state: | |||
| // raise TypeError("Second-order gradient for while loops not supported.") | |||
| if (op_ctxt.grad_state != null) | |||
| throw new TypeError("Second-order gradient for while loops not supported."); | |||
| // if isinstance(grad, ops.Tensor) : | |||
| // grad_ctxt.AddName(grad.name) | |||
| // else: | |||
| // if not isinstance(grad, (ops.IndexedSlices, sparse_tensor.SparseTensor)): | |||
| // raise TypeError("Type %s not supported" % type(grad)) | |||
| // grad_ctxt.AddName(grad.values.name) | |||
| // grad_ctxt.AddName(grad.indices.name) | |||
| // dense_shape = grad.dense_shape | |||
| // if dense_shape is not None: | |||
| // grad_ctxt.AddName(dense_shape.name) | |||
| // grad_ctxt.Enter() | |||
| // # pylint: disable=protected-access | |||
| // result = control_flow_ops._Enter( | |||
| // grad, grad_ctxt.name, is_constant=False, | |||
| // parallel_iterations=grad_ctxt.parallel_iterations, | |||
| // name="b_exit") | |||
| // # pylint: enable=protected-access | |||
| // grad_ctxt.loop_enters.append(result) | |||
| // grad_ctxt.Exit() | |||
| // return result | |||
| grad_ctxt.AddName(grad.name); | |||
| grad_ctxt.Enter(); | |||
| var result = control_flow_ops._Enter( | |||
| grad, grad_ctxt.name, is_constant: false, | |||
| parallel_iterations: grad_ctxt.parallel_iterations, | |||
| name: "b_exit"); | |||
| grad_ctxt.loop_enters.append(result); | |||
| grad_ctxt.Exit(); | |||
| return new[] { result }; | |||
| } | |||
| /// <summary> | |||
| @@ -200,15 +206,15 @@ namespace Tensorflow.Gradients | |||
| /// Note that the backprop next_iteration is added in switch grad. | |||
| /// </summary> | |||
| [RegisterGradient("NextIteration")] | |||
| public Tensor[] _NextIterationGrad(object _, Tensor[] grad) | |||
| public static Tensor[] _NextIterationGrad(Operation op, Tensor[] grads) | |||
| { | |||
| return grad; | |||
| return grads; | |||
| } | |||
| [RegisterGradient("RefNextIteration")] | |||
| public Tensor[] _RefNextIterationGrad(object _, Tensor[] grad) | |||
| public static Tensor[] _RefNextIterationGrad(Operation op, Tensor[] grads) | |||
| { | |||
| return grad; | |||
| return grads; | |||
| } | |||
| /// <summary> | |||
| @@ -218,33 +224,31 @@ namespace Tensorflow.Gradients | |||
| /// For loop invariants, we need to add an accumulator loop. | |||
| /// </summary> | |||
| [RegisterGradient("Enter")] | |||
| public Tensor[] _EnterGrad(Tensor op, Tensor[] grad) | |||
| public static Tensor[] _EnterGrad(Operation op, Tensor[] grads) | |||
| { | |||
| throw new NotImplementedException("_EnterGrad"); | |||
| // graph = ops.get_default_graph() | |||
| //# pylint: disable=protected-access | |||
| // grad_ctxt = graph._get_control_flow_context() | |||
| // # pylint: enable=protected-access | |||
| // if not grad_ctxt.back_prop: | |||
| // # Skip gradient computation, if the attribute `back_prop` is false. | |||
| // return grad | |||
| // if grad_ctxt.grad_state is None: | |||
| // # Pass the gradient through if we are not in a gradient while context. | |||
| // return grad | |||
| // if op.get_attr("is_constant"): | |||
| // # Add a gradient accumulator for each loop invariant. | |||
| // if isinstance(grad, ops.Tensor) : | |||
| // result = grad_ctxt.AddBackpropAccumulator(op, grad) | |||
| // elif isinstance(grad, ops.IndexedSlices) : | |||
| // result = grad_ctxt.AddBackpropIndexedSlicesAccumulator(op, grad) | |||
| // else: | |||
| // # TODO(yuanbyu, lukasr): Add support for SparseTensor. | |||
| // raise TypeError("Type %s not supported" % type(grad)) | |||
| // else: | |||
| // result = exit(grad) | |||
| // grad_ctxt.loop_exits.append(result) | |||
| // grad_ctxt.ExitResult([result]) | |||
| // return result | |||
| Tensor result = null; | |||
| var grad = grads[0]; | |||
| var graph = ops.get_default_graph(); | |||
| var grad_ctxt = graph._get_control_flow_context() as WhileContext; | |||
| if (!grad_ctxt.back_prop) | |||
| // Skip gradient computation, if the attribute `back_prop` is false. | |||
| return grads; | |||
| if (grad_ctxt.grad_state == null) | |||
| // Pass the gradient through if we are not in a gradient while context. | |||
| return grads; | |||
| if (op.get_attr<bool>("is_constant")) | |||
| { | |||
| // Add a gradient accumulator for each loop invariant. | |||
| result = grad_ctxt.AddBackpropAccumulator(op, grad); | |||
| } | |||
| else | |||
| { | |||
| result = control_flow_ops.exit(grad); | |||
| grad_ctxt.loop_exits.append(result); | |||
| grad_ctxt.ExitResult(new[] { result }); | |||
| } | |||
| return new Tensor[] { result }; | |||
| } | |||
| @@ -17,6 +17,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Operations.ControlFlows; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| @@ -54,6 +55,9 @@ namespace Tensorflow | |||
| * is more than one. | |||
| **/ | |||
| var grads = new Dictionary<string, List<List<Tensor>>>(); | |||
| Operation[] reachable_to_ops = null; | |||
| ControlFlowState loop_state = null; | |||
| Dictionary<string, int> pending_count = null; | |||
| tf_with(ops.name_scope(name, "gradients", | |||
| values: ys.Concat(xs).Concat(stop_gradients).Concat(grad_ys)), scope => | |||
| @@ -80,8 +84,9 @@ namespace Tensorflow | |||
| var to_ops = ys.Select(x => x.op).ToList(); | |||
| var from_ops = xs.Select(x => x.op).ToList(); | |||
| var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList(); | |||
| var (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs); | |||
| (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs); | |||
| // Add the initial gradients for the ys. | |||
| foreach (var (y, grad_y) in zip(ys, grad_ys)) | |||
| _SetGrad(grads, y, grad_y); | |||
| @@ -103,6 +108,16 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| if(loop_state != null) | |||
| { | |||
| var loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set); | |||
| foreach(var y in loop_exits) | |||
| { | |||
| //if(IsTrainable(y)) | |||
| throw new NotImplementedException(""); | |||
| } | |||
| } | |||
| var stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs); | |||
| while (queue.Count > 0) | |||
| { | |||
| @@ -110,45 +125,48 @@ namespace Tensorflow | |||
| var op = queue.Dequeue(); | |||
| _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops); | |||
| //if (loop_state != null) | |||
| //loop_state.EnterGradWhileContext(op, before: true); | |||
| var out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, aggregation_method); | |||
| Tensor[] in_grads = null; | |||
| var is_partitioned_call = _IsPartitionedCall(op); | |||
| var is_func_call = false; | |||
| var has_out_grads = out_grads.Exists(x => x != null); | |||
| if (has_out_grads && !stop_ops.Contains(op)) | |||
| { | |||
| // A grad_fn must be defined, either as a function or as None | |||
| // for ops that do not have gradients. | |||
| if (loop_state != null) | |||
| loop_state.EnterGradWhileContext(op, before: true); | |||
| var out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, aggregation_method); | |||
| if (loop_state != null) | |||
| loop_state.ExitGradWhileContext(op, before: true); | |||
| Tensor[] in_grads = null; | |||
| Func<Operation, Tensor[], Tensor[]> grad_fn = null; | |||
| try | |||
| { | |||
| grad_fn = ops.get_gradient_function(op); | |||
| } | |||
| catch (LookupError) | |||
| var is_partitioned_call = _IsPartitionedCall(op); | |||
| var is_func_call = false; | |||
| var has_out_grads = out_grads.Exists(x => x != null); | |||
| if (has_out_grads && !stop_ops.Contains(op)) | |||
| { | |||
| if (is_func_call) | |||
| // A grad_fn must be defined, either as a function or as None | |||
| // for ops that do not have gradients. | |||
| try | |||
| { | |||
| if (is_partitioned_call) | |||
| grad_fn = ops.get_gradient_function(op); | |||
| } | |||
| catch (LookupError) | |||
| { | |||
| if (is_func_call) | |||
| { | |||
| if (is_partitioned_call) | |||
| { | |||
| } | |||
| else | |||
| { | |||
| } | |||
| } | |||
| else | |||
| { | |||
| throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})"); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})"); | |||
| } | |||
| } | |||
| // if (loop_state) | |||
| //loop_state.EnterGradWhileContext(op, before: false); | |||
| if (loop_state != null) | |||
| loop_state.EnterGradWhileContext(op, before: false); | |||
| if ((is_func_call || grad_fn != null) && has_out_grads) | |||
| { | |||
| @@ -164,7 +182,7 @@ namespace Tensorflow | |||
| // will use SymbolicGradient get a zero gradient. Gradient | |||
| // functions should ignore the gradient for other outputs. | |||
| if (loop_state != null) | |||
| ; | |||
| out_grads[i] = new List<Tensor> { loop_state.ZerosLike(op, i) }; | |||
| else | |||
| out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) }; | |||
| } | |||
| @@ -198,33 +216,34 @@ namespace Tensorflow | |||
| // just propagate a list of None backwards. | |||
| in_grads = new Tensor[_NonEagerInputs(op, xs).Count()]; | |||
| } | |||
| } | |||
| else | |||
| { | |||
| in_grads = new Tensor[_NonEagerInputs(op, xs).Count()]; | |||
| } | |||
| var inputs = _NonEagerInputs(op, xs).ToList(); | |||
| foreach (var (t_in, in_grad) in zip(inputs, in_grads)) | |||
| { | |||
| if (in_grad != null) | |||
| var inputs = _NonEagerInputs(op, xs).ToList(); | |||
| foreach (var (t_in, in_grad) in zip(inputs, in_grads)) | |||
| { | |||
| if (!(in_grad is null) && | |||
| in_grad.Tag == null && // maybe a IndexedSlice | |||
| t_in.dtype != TF_DataType.TF_RESOURCE) | |||
| if (in_grad != null) | |||
| { | |||
| in_grad.set_shape(t_in.TensorShape); | |||
| } | |||
| if (!(in_grad is null) && | |||
| in_grad.Tag == null && // maybe a IndexedSlice | |||
| t_in.dtype != TF_DataType.TF_RESOURCE) | |||
| { | |||
| in_grad.set_shape(t_in.TensorShape); | |||
| } | |||
| _SetGrad(grads, t_in, in_grad); | |||
| _SetGrad(grads, t_in, in_grad); | |||
| } | |||
| } | |||
| } | |||
| if (loop_state != null) | |||
| loop_state.ExitGradWhileContext(op, before: false); | |||
| } | |||
| // Update pending count for the inputs of op and enqueue ready ops. | |||
| _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state, xs); | |||
| } | |||
| }); | |||
| if (loop_state != null) | |||
| loop_state.PostProcessing(); | |||
| return xs.Select(x => _GetGrad(grads, x)).ToArray(); | |||
| } | |||
| @@ -275,7 +294,7 @@ namespace Tensorflow | |||
| /// <param name="colocate_gradients_with_ops"></param> | |||
| /// <param name="func_graphs"></param> | |||
| /// <param name="xs"></param> | |||
| private static (Operation[], Dictionary<string, int>, object) _PendingCount(List<Operation> to_ops, List<Operation> from_ops, bool colocate_gradients_with_ops, List<object> func_graphs, Tensor[] xs) | |||
| private static (Operation[], Dictionary<string, int>, ControlFlowState) _PendingCount(List<Operation> to_ops, List<Operation> from_ops, bool colocate_gradients_with_ops, List<object> func_graphs, Tensor[] xs) | |||
| { | |||
| // Mark reachable ops from from_ops. | |||
| var reached_ops = new List<Operation>(); | |||
| @@ -308,6 +327,7 @@ namespace Tensorflow | |||
| // 'loop_state' is None if there are no while loops. | |||
| var loop_state = control_flow_ops.MaybeCreateControlFlowState(between_op_list, between_ops, colocate_gradients_with_ops); | |||
| // Initialize pending count for between ops. | |||
| var pending_count = new Dictionary<string, int>(); | |||
| foreach (var op in between_op_list) | |||
| { | |||
| @@ -342,7 +362,11 @@ namespace Tensorflow | |||
| grads[op.name] = op_grads; | |||
| } | |||
| var t_grads = op_grads[t.value_index]; | |||
| t_grads.Add(grad); | |||
| if (t_grads.Count > 0 && | |||
| control_flow_util.IsLoopSwitch(op)) | |||
| op_grads[t.value_index][0] = grad; | |||
| else | |||
| t_grads.Add(grad); | |||
| } | |||
| private static IEnumerable<Tensor> _NonEagerInputs(Operation op, Tensor[] xs) | |||
| @@ -351,7 +375,8 @@ namespace Tensorflow | |||
| yield return op.inputs[i]; | |||
| } | |||
| private static List<List<Tensor>> _AggregatedGrads(Dictionary<string, List<List<Tensor>>> grads, Operation op, string gradient_uid, object loop_state, int aggregation_method = 0) | |||
| private static List<List<Tensor>> _AggregatedGrads(Dictionary<string, List<List<Tensor>>> grads, Operation op, string gradient_uid, | |||
| ControlFlowState loop_state, int aggregation_method = 0) | |||
| { | |||
| var out_grads = _GetGrads(grads, op); | |||
| @@ -359,7 +384,10 @@ namespace Tensorflow | |||
| { | |||
| if (loop_state != null) | |||
| { | |||
| if (out_grads.Count > 1 && | |||
| out_grads[1].Count > 0 && | |||
| control_flow_util.IsLoopSwitch(op)) | |||
| continue; | |||
| } | |||
| // Aggregate multiple gradients, and convert [] to None. | |||
| @@ -550,7 +578,7 @@ namespace Tensorflow | |||
| Operation op, | |||
| Queue<Operation> queue, | |||
| Dictionary<string, int> pending_count, | |||
| object loop_state, | |||
| ControlFlowState loop_state, | |||
| Tensor[] xs) | |||
| { | |||
| foreach (var x in _NonEagerInputs(op, xs)) | |||
| @@ -564,14 +592,49 @@ namespace Tensorflow | |||
| if (loop_state != null && !ready) | |||
| { | |||
| ready = pending_count[x.op.name] > 0 && control_flow_util.IsLoopSwitch(x.op); | |||
| } | |||
| if (ready) | |||
| { | |||
// If x is an exit without a real gradient, defer processing it.
| if (control_flow_util.IsLoopExit(x.op)) | |||
| { | |||
| var grad_state = loop_state.GetGradState(x.op, before: false); | |||
| grad_state.deferred_exits.append(x); | |||
| grad_state.pending_exits_count -= 1; | |||
| // We now have all the exits so process them. | |||
| if (grad_state.pending_exits_count == 0) | |||
| { | |||
| var has_not_none_grad = false; | |||
| foreach(var y in grad_state.deferred_exits) | |||
| { | |||
| if (_HasAnyNotNoneGrads(grads, y.op)) | |||
| { | |||
| has_not_none_grad = true; | |||
| queue.Enqueue(y.op); | |||
| } | |||
| else | |||
| grad_state.unused_exits.append(y); | |||
| } | |||
| if (has_not_none_grad) | |||
| { | |||
| // For an unused exit, if it has trainable outputs, backprop | |||
| // a zero gradient. Otherwise, just ignore it. | |||
| foreach (var y in grad_state.unused_exits) | |||
| { | |||
| if (IsTrainable(y)) | |||
| _SetGrad(grads, y, loop_state.ZerosLikeForExit(y)); | |||
| queue.Enqueue(y.op); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| // All exits are "unused" so use None as gradient. | |||
| foreach (var y in grad_state.unused_exits) | |||
| queue.Enqueue(y.op); | |||
| } | |||
| } | |||
| } | |||
| else | |||
| { | |||
| @@ -581,6 +644,32 @@ namespace Tensorflow | |||
| } | |||
| } | |||
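/// <summary>
/// Mirrors the Python helper of the same name: only floating-point, complex,
/// resource and variant dtypes are considered trainable.
/// </summary>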
| private static bool IsTrainable(Tensor tensor) | |||
| { | |||
| var dtype = tensor.dtype.as_base_dtype(); | |||
| return new TF_DataType[] { dtypes.float16, dtypes.float32, dtypes.float64, | |||
| dtypes.complex64, dtypes.complex128, | |||
| dtypes.resource, dtypes.variant}.Contains(dtype); | |||
| } | |||
| /// <summary> | |||
/// Return true if op has a real gradient.
| /// </summary> | |||
| /// <param name="grads"></param> | |||
| /// <param name="op"></param> | |||
| /// <returns></returns> | |||
| private static bool _HasAnyNotNoneGrads(Dictionary<string, List<List<Tensor>>> grads, Operation op) | |||
| { | |||
| var out_grads = _GetGrads(grads, op); | |||
| foreach(var out_grad in out_grads) | |||
| { | |||
| if (out_grad.Exists(g => g != null)) | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func<Operation, Tensor[], Tensor[]> grad_fn) | |||
| { | |||
| scope = scope.EndsWith("/") ? scope.Substring(0, scope.Length - 1) : scope; | |||
| @@ -589,6 +678,9 @@ namespace Tensorflow | |||
| private static void _VerifyGeneratedGradients(Tensor[] grads, Operation op) | |||
| { | |||
| if (op.type == "While" || op.type == "StatelessWhile") | |||
| return; | |||
| if (grads.Count() != op.inputs._inputs.Count()) | |||
| throw new ValueError($"Num gradients {grads.Length} generated for op {op.node_def} do not match num " + | |||
| $"inputs {op.inputs._inputs.Count()}"); | |||
| @@ -18,6 +18,7 @@ using System.Collections.Generic; | |||
| using System.Diagnostics.CodeAnalysis; | |||
| using System.Linq; | |||
| using Tensorflow.Operations; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -75,7 +75,10 @@ namespace Tensorflow | |||
| /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | |||
| /// </summary> | |||
| /// <remarks>https://www.tensorflow.org/guide/graphs <br></br>https://www.tensorflow.org/api_docs/python/tf/Graph</remarks> | |||
| public partial class Graph : DisposableObject//, IEnumerable<Operation> | |||
public partial class Graph : DisposableObject
#if !SERIALIZABLE
, IEnumerable<Operation>
#endif
| { | |||
| private Dictionary<int, ITensorOrOperation> _nodes_by_id; | |||
| public Dictionary<string, ITensorOrOperation> _nodes_by_name; | |||
| @@ -259,15 +262,11 @@ namespace Tensorflow | |||
| if (string.IsNullOrEmpty(name)) | |||
| name = op_type; | |||
// If a name ends with a '/', it is a "name scope" and we use it as-is,
// after removing the trailing '/'.
| name = name.EndsWith("/") ? ops.name_from_scope_name(name) : unique_name(name); | |||
| var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); | |||
| if (name.Contains("define_loss/bigger_box_loss/mul_13")) | |||
| { | |||
| } | |||
| var input_ops = inputs.Select(x => x.op).ToArray(); | |||
| var control_inputs = _control_dependencies_for_inputs(input_ops); | |||
| @@ -374,7 +373,11 @@ namespace Tensorflow | |||
| /// <returns>A string to be passed to `create_op()` that will be used | |||
| /// to name the operation being created.</returns> | |||
| public string unique_name(string name, bool mark_as_used = true) | |||
| { | |||
| { | |||
| if (name.EndsWith("basic_r_n_n_cell")) | |||
| { | |||
| } | |||
| if (!String.IsNullOrEmpty(_name_stack)) | |||
| name = _name_stack + "/" + name; | |||
| // For the sake of checking for names in use, we treat names as case | |||
| @@ -402,7 +405,7 @@ namespace Tensorflow | |||
| // Return the new name with the original capitalization of the given name. | |||
| name = $"{name}_{i-1}"; | |||
| } | |||
| } | |||
| return name; | |||
| } | |||
| @@ -524,17 +527,19 @@ namespace Tensorflow | |||
| } | |||
| return debugString;*/ | |||
| } | |||
| /*private IEnumerable<Operation> GetEnumerable() | |||
| } | |||
| #if !SERIALIZABLE | |||
| private IEnumerable<Operation> GetEnumerable() | |||
| => c_api_util.tf_operations(this); | |||
| IEnumerator<Operation> IEnumerable<Operation>.GetEnumerator() | |||
| => GetEnumerable().GetEnumerator(); | |||
| IEnumerator IEnumerable.GetEnumerator() | |||
| => throw new NotImplementedException();*/ | |||
| => throw new NotImplementedException(); | |||
| #endif | |||
| public static implicit operator IntPtr(Graph graph) | |||
| { | |||
| return graph._handle; | |||
| @@ -16,6 +16,7 @@ | |||
| using System.Collections.Generic; | |||
| using Tensorflow.Operations; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -68,7 +69,9 @@ namespace Tensorflow | |||
| _new_stack = false; | |||
| } | |||
| _seen_nodes = new List<ITensorOrOperation>(); | |||
| _seen_nodes = new List<ITensorOrOperation>(); | |||
| _old_stack = null; | |||
| _old_control_flow_context = null; | |||
| } | |||
| public void add_op(ITensorOrOperation op) | |||
| @@ -0,0 +1,11 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public interface ICanBeFlattened | |||
| { | |||
| object[] Flatten(); | |||
| } | |||
| } | |||
| @@ -0,0 +1,11 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public interface IPackable<T> | |||
| { | |||
| T Pack(object[] sequences); | |||
| } | |||
| } | |||
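As a hypothetical illustration (not part of this diff), a composite value would implement both interfaces so nest-style utilities can take it apart and rebuild it:

```csharp
// Sketch: a pair that can be flattened into parts and packed back together.
public class ObjectPair : ICanBeFlattened, IPackable<ObjectPair>
{
    public object First;
    public object Second;

    // Decompose into an ordered sequence of components.
    public object[] Flatten() => new[] { First, Second };

    // Rebuild a pair from a flat sequence, in the order Flatten produced.
    public ObjectPair Pack(object[] sequences)
        => new ObjectPair { First = sequences[0], Second = sequences[1] };
}
```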
| @@ -14,13 +14,14 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System.Runtime.InteropServices; | |||
| namespace Tensorflow.Sessions | |||
| namespace Tensorflow | |||
| { | |||
| [StructLayout(LayoutKind.Sequential)] | |||
| public struct TF_DeprecatedSession | |||
/// <summary>
/// Marker interface that constrains a function's return value
/// to be either a Tensor or a TensorArray.
/// </summary>
| public interface ITensorOrTensorArray | |||
| { | |||
| Session session; | |||
| } | |||
| } | |||
| @@ -139,14 +139,14 @@ namespace Tensorflow.Keras.Layers | |||
| built = true; | |||
| } | |||
| protected override Tensor call(Tensor inputs, Tensor training = null) | |||
| protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) | |||
| { | |||
| Tensor outputs = null; | |||
| if (fused) | |||
| { | |||
| outputs = _fused_batch_norm(inputs, training: training); | |||
| return outputs; | |||
| return new[] { outputs, outputs }; | |||
| } | |||
| throw new NotImplementedException("BatchNormalization call"); | |||
| @@ -108,7 +108,7 @@ namespace Tensorflow.Keras.Layers | |||
| built = true; | |||
| } | |||
| protected override Tensor call(Tensor inputs, Tensor training = null) | |||
| protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) | |||
| { | |||
| var outputs = _convolution_op.__call__(inputs, kernel); | |||
| if (use_bias) | |||
| @@ -124,9 +124,9 @@ namespace Tensorflow.Keras.Layers | |||
| } | |||
| if (activation != null) | |||
| return activation.Activate(outputs); | |||
| outputs = activation.Activate(outputs); | |||
| return outputs; | |||
| return new[] { outputs, outputs }; | |||
| } | |||
| } | |||
| } | |||
| @@ -72,7 +72,7 @@ namespace Tensorflow.Keras.Layers | |||
| built = true; | |||
| } | |||
| protected override Tensor call(Tensor inputs, Tensor training = null) | |||
| protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) | |||
| { | |||
| Tensor outputs = null; | |||
| var rank = inputs.rank; | |||
| @@ -88,9 +88,9 @@ namespace Tensorflow.Keras.Layers | |||
| if (use_bias) | |||
| outputs = tf.nn.bias_add(outputs, bias); | |||
| if (activation != null) | |||
| return activation.Activate(outputs); | |||
| outputs = activation.Activate(outputs); | |||
| return outputs; | |||
| return new[] { outputs, outputs }; | |||
| } | |||
| } | |||
| } | |||
| @@ -50,14 +50,14 @@ namespace Tensorflow.Keras.Layers | |||
| built = true; | |||
| } | |||
| protected override Tensor call(Tensor inputs, Tensor training = null) | |||
| protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) | |||
| { | |||
| var dtype = inputs.dtype; | |||
| if (dtype != tf.int32 && dtype != tf.int64) | |||
| inputs = math_ops.cast(inputs, tf.int32); | |||
| var @out = embedding_ops.embedding_lookup(embeddings, inputs); | |||
| return @out; | |||
| return new[] { @out, @out }; | |||
| } | |||
| } | |||
| } | |||
| @@ -52,6 +52,7 @@ namespace Tensorflow.Keras.Layers | |||
| protected InputSpec input_spec; | |||
| protected bool supports_masking; | |||
| protected List<VariableV1> _trainable_weights; | |||
| protected List<VariableV1> _non_trainable_weights; | |||
| private string _name; | |||
| public string name => _name; | |||
| protected string _base_name; | |||
| @@ -84,6 +85,7 @@ namespace Tensorflow.Keras.Layers | |||
| _init_set_name(name); | |||
| _trainable_weights = new List<VariableV1>(); | |||
| _non_trainable_weights = new List<VariableV1>(); | |||
| _compute_previous_mask = false; | |||
| _updates = new List<Operation>(); | |||
| @@ -101,13 +103,14 @@ namespace Tensorflow.Keras.Layers | |||
| _inbound_nodes = new List<Node>(); | |||
| } | |||
| public Tensor __call__(Tensor[] inputs, | |||
| public Tensor[] __call__(Tensor[] inputs, | |||
| Tensor training = null, | |||
| Tensor state = null, | |||
| VariableScope scope = null) | |||
| { | |||
| var input_list = inputs; | |||
| var input = inputs[0]; | |||
| Tensor outputs = null; | |||
| Tensor[] outputs = null; | |||
| // We will attempt to build a TF graph if & only if all inputs are symbolic. | |||
| // This is always the case in graph mode. It can also be the case in eager | |||
| @@ -139,7 +142,10 @@ namespace Tensorflow.Keras.Layers | |||
| // overridden). | |||
| _maybe_build(inputs[0]); | |||
| outputs = call(inputs[0], training: training); | |||
| outputs = call(inputs[0], | |||
| training: training, | |||
| state: state); | |||
| (input, outputs) = _set_connectivity_metadata_(input, outputs); | |||
| _handle_activity_regularization(inputs[0], outputs); | |||
| _set_mask_metadata(inputs[0], outputs, null); | |||
| @@ -149,13 +155,13 @@ namespace Tensorflow.Keras.Layers | |||
| return outputs; | |||
| } | |||
| private (Tensor, Tensor) _set_connectivity_metadata_(Tensor inputs, Tensor outputs) | |||
| private (Tensor, Tensor[]) _set_connectivity_metadata_(Tensor inputs, Tensor[] outputs) | |||
| { | |||
| //_add_inbound_node(input_tensors: inputs, output_tensors: outputs); | |||
| return (inputs, outputs); | |||
| } | |||
| private void _handle_activity_regularization(Tensor inputs, Tensor outputs) | |||
| private void _handle_activity_regularization(Tensor inputs, Tensor[] outputs) | |||
| { | |||
| //if(_activity_regularizer != null) | |||
| { | |||
| @@ -163,7 +169,7 @@ namespace Tensorflow.Keras.Layers | |||
| } | |||
| } | |||
| private void _set_mask_metadata(Tensor inputs, Tensor outputs, Tensor previous_mask) | |||
| private void _set_mask_metadata(Tensor inputs, Tensor[] outputs, Tensor previous_mask) | |||
| { | |||
| } | |||
| @@ -173,9 +179,9 @@ namespace Tensorflow.Keras.Layers | |||
| return null; | |||
| } | |||
| protected virtual Tensor call(Tensor inputs, Tensor training = null) | |||
| protected virtual Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) | |||
| { | |||
| return inputs; | |||
| throw new NotImplementedException(""); | |||
| } | |||
| protected virtual string _name_scope() | |||
| @@ -233,7 +239,10 @@ namespace Tensorflow.Keras.Layers | |||
| initializer: initializer, | |||
| trainable: trainable.Value); | |||
| //backend.track_variable(variable); | |||
| _trainable_weights.Add(variable); | |||
| if (trainable == true) | |||
| _trainable_weights.Add(variable); | |||
| else | |||
| _non_trainable_weights.Add(variable); | |||
| return variable; | |||
| } | |||
| @@ -43,7 +43,7 @@ namespace Tensorflow.Keras.Layers | |||
| this.input_spec = new InputSpec(ndim: 4); | |||
| } | |||
| protected override Tensor call(Tensor inputs, Tensor training = null) | |||
| protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) | |||
| { | |||
| int[] pool_shape; | |||
| if (data_format == "channels_last") | |||
| @@ -64,7 +64,7 @@ namespace Tensorflow.Keras.Layers | |||
| padding: padding.ToUpper(), | |||
| data_format: conv_utils.convert_data_format(data_format, 4)); | |||
| return outputs; | |||
| return new[] { outputs, outputs }; | |||
| } | |||
| } | |||
| } | |||
| @@ -43,17 +43,20 @@ namespace Tensorflow.Layers | |||
| // Avoid an incorrect lint error | |||
| _trainable_weights = new List<VariableV1>(); | |||
| _non_trainable_weights = new List<VariableV1>(); | |||
| this.built = false; | |||
| _keras_style = false; | |||
| } | |||
| public virtual Tensor apply(Tensor inputs, Tensor training = null) | |||
| public virtual (Tensor, Tensor) apply(Tensor inputs, Tensor training = null) | |||
| { | |||
| return __call__(inputs, training: training); | |||
| var results = __call__(inputs, training: training); | |||
| return (results[0], results[1]); | |||
| } | |||
| public Tensor __call__(Tensor inputs, | |||
| public Tensor[] __call__(Tensor inputs, | |||
| Tensor training = null, | |||
| Tensor state = null, | |||
| VariableScope scope = null) | |||
| { | |||
| _set_scope(scope); | |||
| @@ -71,12 +74,14 @@ namespace Tensorflow.Layers | |||
| auxiliary_name_scope: false); | |||
| } | |||
| Tensor outputs = null; | |||
| Tensor[] outputs = null; | |||
| tf_with(scope_context_manager, scope2 => | |||
| { | |||
| _current_scope = scope2; | |||
| // Actually call layer | |||
| outputs = base.__call__(new Tensor[] { inputs }, training: training); | |||
| outputs = base.__call__(new Tensor[] { inputs }, | |||
| training: training, | |||
| state: state); | |||
| }); | |||
| @@ -121,6 +126,11 @@ namespace Tensorflow.Layers | |||
| Graph init_graph = null; | |||
| VariableV1[] existing_variables = null; | |||
| if (synchronization == VariableSynchronization.OnRead) | |||
| trainable = false; | |||
| else if (!trainable.HasValue) | |||
| trainable = true; | |||
| if (default_graph.building_function) | |||
| { | |||
| throw new NotImplementedException("add_weight"); | |||
| @@ -16,18 +16,23 @@ | |||
| using System; | |||
| using Tensorflow.Keras.Engine; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| public class BasicRNNCell : LayerRNNCell | |||
| public class BasicRnnCell : LayerRnnCell | |||
| { | |||
| int _num_units; | |||
| Func<Tensor, string, Tensor> _activation; | |||
| public override int state_size => _num_units; | |||
| public override int output_size => _num_units; | |||
| public VariableV1 _kernel; | |||
| string _WEIGHTS_VARIABLE_NAME = "kernel"; | |||
| public VariableV1 _bias; | |||
| string _BIAS_VARIABLE_NAME = "bias"; | |||
| public BasicRNNCell(int num_units, | |||
| public BasicRnnCell(int num_units, | |||
| Func<Tensor, string, Tensor> activation = null, | |||
| bool? reuse = null, | |||
| string name = null, | |||
| @@ -44,5 +49,31 @@ namespace Tensorflow | |||
| else | |||
| _activation = activation; | |||
| } | |||
| protected override void build(TensorShape inputs_shape) | |||
| { | |||
| var input_depth = inputs_shape.dims[inputs_shape.ndim - 1]; | |||
| _kernel = add_weight( | |||
| _WEIGHTS_VARIABLE_NAME, | |||
| shape: new[] { input_depth + _num_units, _num_units }); | |||
| _bias = add_weight( | |||
| _BIAS_VARIABLE_NAME, | |||
| shape: new[] { _num_units }, | |||
| initializer: tf.zeros_initializer); | |||
| built = true; | |||
| } | |||
| protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) | |||
| { | |||
| // Most basic RNN: output = new_state = act(W * input + U * state + B). | |||
| var concat = array_ops.concat(new[] { inputs, state }, 1); | |||
| var gate_inputs = math_ops.matmul(concat, _kernel as RefVariable); | |||
| gate_inputs = nn_ops.bias_add(gate_inputs, _bias as RefVariable); | |||
| var output = _activation(gate_inputs, null); | |||
| return new[] { output, output }; | |||
| } | |||
| } | |||
| } | |||
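A usage sketch tying the cell to the `dynamic_rnn` binding shown earlier (hypothetical; the exact `tf.nn` surface and the dtype constant are assumptions):

```csharp
// Sketch, not part of this diff: unroll a basic RNN over a batch of sequences.
var cell = new BasicRnnCell(num_units: 128);
// inputs: [batch, time, features] when time_major is false.
var (outputs, state) = tf.nn.dynamic_rnn(cell, inputs, dtype: TF_DataType.TF_FLOAT);
```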
| @@ -19,6 +19,8 @@ using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Operations.ControlFlows; | |||
| using static Tensorflow.ControlFlowContextDef; | |||
| using static Tensorflow.Binding; | |||
| using util = Tensorflow.control_flow_util; | |||
| namespace Tensorflow.Operations | |||
| { | |||
| @@ -72,6 +74,7 @@ namespace Tensorflow.Operations | |||
| public ControlFlowContext() | |||
| { | |||
| _context_stack = new Stack<ControlFlowContext>(); | |||
| _external_values = new Dictionary<string, ITensorOrOperation>(); | |||
| } | |||
| public string name { get => _name; } | |||
| @@ -134,27 +137,6 @@ namespace Tensorflow.Operations | |||
| graph._set_control_flow_context(this); | |||
| } | |||
| protected virtual Tensor _Enter(Tensor data, string frame_name, | |||
| bool is_constant = false, | |||
| int parallel_iterations = 10, | |||
| bool use_ref = true, | |||
| bool use_input_shape = true, | |||
| string name = null) | |||
| { | |||
| Tensor result; | |||
| data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref: true); | |||
| if (data.dtype.is_ref_dtype() && use_ref) | |||
| throw new NotImplementedException("_Enter"); | |||
| else | |||
| result = gen_control_flow_ops.enter( | |||
| data, frame_name, is_constant, parallel_iterations, name: name); | |||
| if (use_input_shape) | |||
| result.set_shape(data.TensorShape); | |||
| return result; | |||
| } | |||
| /// <summary> | |||
| /// Exit this control flow context. | |||
| /// </summary> | |||
| @@ -165,10 +147,18 @@ namespace Tensorflow.Operations | |||
| graph._set_control_flow_context(last_context); | |||
| } | |||
| public void ExitResult(Tensor[] result) | |||
| { | |||
| if(_outer_context != null) | |||
| { | |||
| throw new NotImplementedException("ExitResult"); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Add `op` to the current context. | |||
| /// </summary> | |||
| public void AddOp(Operation op) | |||
| public virtual void AddOp(Operation op) | |||
| { | |||
| _AddOpInternal(op); | |||
| } | |||
| @@ -180,12 +170,22 @@ namespace Tensorflow.Operations | |||
| public virtual bool back_prop => throw new NotImplementedException("abstract method"); | |||
| /// <summary> | |||
| /// Add `val` to the current context and its outer context recursively. | |||
| /// </summary> | |||
| /// <param name="val"></param> | |||
| /// <returns></returns> | |||
| public virtual Tensor AddValue(Tensor val) | |||
| { | |||
| // to be overridden | |||
| return null; | |||
| } | |||
| public void AddName(string name) | |||
| { | |||
| _values.Add(name); | |||
| } | |||
| /// <summary> | |||
| /// Notifies a scope about an operator added to an inner scope. | |||
| /// </summary> | |||
| @@ -203,7 +203,20 @@ namespace Tensorflow.Operations | |||
| /// </summary> | |||
| protected virtual void _AddOpInternal(Operation op) | |||
| { | |||
| if(op == null) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| } | |||
| else | |||
| { | |||
| foreach(var index in range(len(op.inputs))) | |||
| { | |||
| var x = op.inputs[index]; | |||
| var real_x = AddValue(x); | |||
| if (real_x != x) | |||
| op._update_input(index, real_x); | |||
| } | |||
| } | |||
| } | |||
| protected bool OpInContext(Operation op) | |||
| @@ -230,9 +243,36 @@ namespace Tensorflow.Operations | |||
| throw new NotImplementedException("_IsInOuterContext"); | |||
| } | |||
| protected virtual void _RemoveExternalControlEdges(Operation op) | |||
| /// <summary> | |||
| /// Remove any external control dependency on this op. | |||
| /// </summary> | |||
| /// <param name="op"></param> | |||
| protected virtual (Operation[], Operation[]) _RemoveExternalControlEdges(Operation op) | |||
| { | |||
| var internal_control_inputs = op.control_inputs; | |||
| var while_ctxt = GetWhileContext(); | |||
| var internal_control_inputs = new List<Operation>(); | |||
| // A control input of `op` is internal if it is in the same while | |||
| // loop context as the enclosing while loop context of self. | |||
| if (while_ctxt == null) | |||
| { | |||
| internal_control_inputs = op.control_inputs.ToList(); | |||
| } | |||
| else | |||
| { | |||
| foreach(Operation x in op.control_inputs) | |||
| { | |||
| var ctxt = util.GetOutputContext(x); | |||
| if (ctxt != null && ctxt.GetWhileContext() == while_ctxt) | |||
| internal_control_inputs.append(x); | |||
| } | |||
| } | |||
| var external_control_inputs = new List<Operation>(); | |||
| if (len(internal_control_inputs) != len(op.control_inputs)) | |||
| throw new NotImplementedException(""); | |||
| return (internal_control_inputs.ToArray(), external_control_inputs.ToArray()); | |||
| } | |||
| /// <summary> | |||
| @@ -264,6 +304,14 @@ namespace Tensorflow.Operations | |||
| throw new NotImplementedException($"Unknown ControlFlowContextDef field: {context_def.CtxtCase}"); | |||
| } | |||
| public virtual bool IsWhileContext() | |||
| { | |||
| throw new NotImplementedException("IsWhileContext"); | |||
| } | |||
| public virtual bool IsCondContext() | |||
| => false; | |||
| public object to_proto() | |||
| { | |||
| throw new NotImplementedException(); | |||
| @@ -14,6 +14,12 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Linq; | |||
| using System.Collections.Generic; | |||
| using util = Tensorflow.control_flow_util; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Operations.ControlFlows | |||
| { | |||
| /// <summary> | |||
| @@ -21,6 +27,7 @@ namespace Tensorflow.Operations.ControlFlows | |||
| /// </summary> | |||
| public class ControlFlowState | |||
| { | |||
| Dictionary<ControlFlowContext, GradLoopState> _map; | |||
| //class ControlFlowState(object): | |||
| // """Maintain the mapping from the loops to their grad states.""" | |||
| @@ -40,57 +47,74 @@ namespace Tensorflow.Operations.ControlFlows | |||
| // return self._map.get(forward_ctxt) | |||
| // return None | |||
| // def ProcessUnusedLoopExits(self, pending_count, to_ops_set): | |||
| // """Process all the "unused" loop exits. | |||
| // The "unused" exits of the loops are added to `unused_exits`. An exit is | |||
| // unused if its pending_count is 0. If there is an exit with real gradient, | |||
| // all these deferred exits will enter the backprop loop with zero gradient. | |||
| // Otherwise, they will enter the backprop loop with None. As an example, | |||
| // people often write: | |||
| // ```python | |||
| // v1, _ = tf.while_loop(p, b, [x1, x2]) | |||
| // result = gradients(v1, x1) | |||
| // ``` | |||
| // The exit node for x2 is not included by the betweenness analysis. But we | |||
| // need to backprop x2 if x2 is involved in computing v1. | |||
| // Args: | |||
| // pending_count: The number of backprop inputs for every op. | |||
| // to_ops_set: The set of ops for ys in gradients(ys, xs) | |||
| // Returns: | |||
| // The set of unused loop exits that we know at this point we need | |||
| // to backprop. | |||
| // """ | |||
| // loop_exits = [] | |||
| // for grad_state in self._map.values(): | |||
| // for y in grad_state.forward_loop_exits: | |||
| // if pending_count[y.op] == 0: | |||
| // grad_state.pending_exits_count -= 1 | |||
| // if y.op not in to_ops_set: | |||
| // grad_state.unused_exits.append(y) | |||
| // if grad_state.pending_exits_count == 0: | |||
| // loop_exits.extend(grad_state.unused_exits) | |||
| // # Need to include Enters in backprop for higher-order gradients. | |||
| // for y in grad_state.forward_context.loop_enters: | |||
| // if pending_count[y.op] == 0: | |||
| // pending_count[y.op] = 1 | |||
| // return loop_exits | |||
| // def EnterGradWhileContext(self, op, before): | |||
| // """Enter the WhileContext for gradient computation.""" | |||
| // grad_state = self.GetGradState(op, before) | |||
| // if grad_state: | |||
| // grad_state.grad_context.Enter() | |||
| // def ExitGradWhileContext(self, op, before): | |||
| // """Exit the WhileContext for gradient computation.""" | |||
| // grad_state = self.GetGradState(op, before) | |||
| // if grad_state: | |||
| // grad_state.grad_context.Exit() | |||
| public ControlFlowState() | |||
| { | |||
| _map = new Dictionary<ControlFlowContext, GradLoopState>(); | |||
| } | |||
| /// <summary> | |||
| /// Return the grad state for this op if it's in a forward loop context. | |||
| /// </summary> | |||
| /// <param name="op"></param> | |||
| /// <param name="before"></param> | |||
| /// <returns></returns> | |||
| public GradLoopState GetGradState(Operation op, bool before) | |||
| { | |||
| ControlFlowContext forward_ctxt = null; | |||
| if (before && util.IsLoopExit(op)) | |||
| { | |||
| forward_ctxt = op._get_control_flow_context(); | |||
| forward_ctxt = forward_ctxt.outer_context; | |||
| if (forward_ctxt != null) | |||
| forward_ctxt = forward_ctxt.GetWhileContext(); | |||
| } | |||
| else | |||
| forward_ctxt = util.GetWhileContext(op); | |||
| if (forward_ctxt != null) | |||
| return _map.get(forward_ctxt); | |||
| return null; | |||
| } | |||
| public Tensor[] ProcessUnusedLoopExits(Dictionary<string, int> pending_count, List<Operation> to_ops_set) | |||
| { | |||
| var loop_exits = new List<Tensor>(); | |||
| foreach(var grad_state in _map.Values) | |||
| { | |||
| foreach(var y in grad_state.forward_loop_exits) | |||
| { | |||
| if(!pending_count.ContainsKey(y.op.name)) | |||
| { | |||
| grad_state.pending_exits_count -= 1; | |||
| if (!to_ops_set.Contains(y.op)) | |||
| grad_state.unused_exits.append(y); | |||
| if (grad_state.pending_exits_count == 0) | |||
| loop_exits.extend(grad_state.unused_exits); | |||
| } | |||
| } | |||
| foreach(var y in grad_state.forward_context.loop_enters) | |||
| { | |||
| if (!pending_count.ContainsKey(y.op.name)) | |||
| pending_count[y.op.name] = 1; | |||
| } | |||
| } | |||
| return loop_exits.ToArray(); | |||
| } | |||
| public void EnterGradWhileContext(Operation op, bool before) | |||
| { | |||
| var grad_state = GetGradState(op, before); | |||
| if (grad_state != null) | |||
| grad_state.grad_context.Enter(); | |||
| } | |||
| public void ExitGradWhileContext(Operation op, bool before) | |||
| { | |||
| var grad_state = GetGradState(op, before); | |||
| if (grad_state != null) | |||
| grad_state.grad_context.Exit(); | |||
| } | |||
| // def AddWhileContext(self, op, between_op_list, between_ops): | |||
| // """Add the grad state for the while loop that op belongs to. | |||
| @@ -118,6 +142,32 @@ namespace Tensorflow.Operations.ControlFlows | |||
| // if loop_exit.op not in between_ops: | |||
| // between_ops.add(loop_exit.op) | |||
| // between_op_list.append(loop_exit.op) | |||
| public void AddWhileContext(Operation op, List<Operation> between_op_list, List<Operation> between_ops) | |||
| { | |||
| var forward_ctxt = op.GetWhileContext(); | |||
| var grad_state = _map.ContainsKey(forward_ctxt) ? _map[forward_ctxt] : null; | |||
| if(grad_state == null) | |||
| { | |||
| GradLoopState outer_grad_state = null; | |||
| var outer_forward_ctxt = forward_ctxt.outer_context; | |||
| if (outer_forward_ctxt != null) | |||
| outer_forward_ctxt = outer_forward_ctxt.GetWhileContext(); | |||
| if (outer_forward_ctxt != null) | |||
| outer_grad_state = _map[outer_forward_ctxt]; | |||
| grad_state = new GradLoopState(forward_ctxt, outer_grad_state); | |||
| _map[forward_ctxt] = grad_state; | |||
| // We need to include all exits of a loop for backprop. | |||
| foreach (var loop_exit in grad_state.forward_loop_exits) | |||
| { | |||
| if(!between_ops.Contains(loop_exit.op)) | |||
| { | |||
| between_ops.add(loop_exit.op); | |||
| between_op_list.append(loop_exit.op); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| // def ZerosLikeForExit(self, val): | |||
| // """Create zeros_like gradient for a loop exit. | |||
| @@ -174,116 +224,101 @@ namespace Tensorflow.Operations.ControlFlows | |||
| // result = array_ops.zeros_like(val, optimize=False) | |||
| // return result | |||
| // def ZerosLike(self, op, index): | |||
| // """Create zeros_like for the specified output of an op. | |||
| // If op is in a while loop that is part of gradients(), this method | |||
| // must be called in its grad loop context. | |||
| // Args: | |||
| // op: A tensorflow operation. | |||
| // index: the index for a specific output of the op. | |||
| // Returns: | |||
| // A zero tensor of the same shape of op.outputs[index]. | |||
| // """ | |||
| // if util.IsLoopSwitch(op): | |||
| // return None | |||
| // if op.graph._building_function: # pylint: disable=protected-access | |||
| // # The optimization here is tricky to apply to functions | |||
| // return array_ops.zeros_like(op.outputs[index]) | |||
| // dead_branch = util.IsSwitch(op) | |||
| // forward_ctxt = _GetWhileContext(op) | |||
| // grad_state = self._map.get(forward_ctxt) | |||
| // if grad_state is None: | |||
| // # op is not in a while loop that is part of gradients(). | |||
| // return ZerosLikeOutsideLoop(op, index) | |||
| // op_ctxt = op._get_control_flow_context() | |||
| // val = ops.convert_to_tensor(op.outputs[index], name="tensor") | |||
| // shape = val.get_shape() | |||
| // if shape.is_fully_defined(): | |||
| // # If the shape is known statically, just create a zero tensor with | |||
| // # the right shape in the grad loop context. | |||
| // result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype) | |||
| // if dead_branch: | |||
| // # op is a cond switch. Guard the zero tensor with a switch. | |||
| // pred = grad_state.history_map.get(op_ctxt.pred.name) | |||
| // branch = op_ctxt.branch | |||
| // result = _SwitchRefOrTensor(result, pred)[1 - branch] | |||
| // else: | |||
| // # Unknown shape so keep a history of the shape at runtime. | |||
| // if dead_branch: | |||
| // # Need to add a special switch to guard the value. | |||
| // pred = op_ctxt.pred | |||
| // branch = op_ctxt.branch | |||
| // op_ctxt.outer_context.Enter() | |||
| // val = _SwitchRefOrTensor(op.inputs[0], pred)[1 - branch] | |||
| // zeros_shape = array_ops.shape_internal(val, optimize=False) | |||
| // op_ctxt.outer_context.Exit() | |||
| // val.op._set_control_flow_context(op_ctxt) | |||
| // zeros_shape.op._set_control_flow_context(op_ctxt) | |||
| // else: | |||
| // op_ctxt.Enter() | |||
| // zeros_shape = array_ops.shape_internal(val, optimize=False) | |||
| // op_ctxt.Exit() | |||
| // # Add forward accumulator for shape. | |||
| // grad_state.grad_context.Exit() | |||
| // history_zeros_shape = grad_state.AddForwardAccumulator( | |||
| // zeros_shape, dead_branch=dead_branch) | |||
| // grad_state.grad_context.Enter() | |||
| // # Create a zero tensor with the right shape. | |||
| // shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape, | |||
| // zeros_shape, dead_branch) | |||
| // result = array_ops.zeros(shape, val.dtype) | |||
| // return result | |||
| // def PostProcessing(self): | |||
| // """Perform postprocessing at the end of gradients(). | |||
| // We have created the gradient graph at this point. So this function | |||
| // can be used to perform any postprocessing on the gradient graph. | |||
| // We currently perform the following postprocessing: | |||
| // 1. Patch the gradient graph if the output of a loop variable | |||
| // doesn't depend on its input. | |||
| // """ | |||
| // for _, grad_state in self._map.items(): | |||
| // for _, b_merge in grad_state.switch_map.items(): | |||
| // if b_merge.op.inputs[0] == b_merge.op.inputs[1]: | |||
| // # The value of this loop variable at iteration i+1 doesn't | |||
| // # depend on its value at iteration i. So use zeros as the | |||
| // # gradients for all iterations > 0. | |||
| // dtype = b_merge.op.inputs[0].dtype | |||
| // shape = b_merge.op.inputs[0].get_shape() | |||
| // # pylint: disable=protected-access | |||
| // if shape.is_fully_defined(): | |||
| // grad_state.grad_context.Enter() | |||
| // # Create a zeros and use it for iterations > 0. | |||
| // grad_val = constant_op.constant(0, dtype=dtype, shape=shape) | |||
| // next_grad_val = _NextIteration(grad_val) | |||
| // grad_state.grad_context.Exit() | |||
| // else: | |||
| // # Create a zeros in the outer grad context. | |||
| // outer_grad_ctxt = grad_state.grad_context.outer_context | |||
| // if outer_grad_ctxt: | |||
| // outer_grad_ctxt.Enter() | |||
| // enter_grad_op = b_merge.op.inputs[0].op | |||
| // enter_grad = enter_grad_op.inputs[0] | |||
| // grad_shape = array_ops.shape_internal(enter_grad, optimize=False) | |||
| // grad_val = array_ops.zeros(grad_shape) | |||
| // if outer_grad_ctxt: | |||
| // outer_grad_ctxt.Exit() | |||
| // # Use the zeros for iterations > 0. | |||
| // grad_state.grad_context.Enter() | |||
| // next_grad_val = _NextIteration(grad_val) | |||
| // grad_state.grad_context.Exit() | |||
| // b_merge.op._update_input(1, next_grad_val) | |||
| // # pylint: enable=protected-access | |||
| public Tensor ZerosLike(Operation op, int index) | |||
| { | |||
| if (util.IsLoopSwitch(op)) | |||
| return null; | |||
| if (op.graph.building_function) | |||
| return array_ops.zeros_like(op.outputs[index]); | |||
| var dead_branch = util.IsSwitch(op); | |||
| var forward_ctxt = util.GetWhileContext(op); | |||
| var grad_state = _map.get(forward_ctxt); | |||
| // op is not in a while loop that is part of gradients(). | |||
| if (grad_state == null) | |||
| return ZerosLikeOutsideLoop(op, index); | |||
| throw new NotImplementedException("ZerosLike"); | |||
| } | |||
| public Tensor ZerosLikeOutsideLoop(Operation op, int index) | |||
| { | |||
| var val = op.outputs[index]; | |||
| if (!util.IsSwitch(op)) | |||
| { | |||
| if (val.dtype == dtypes.resource) | |||
| throw new NotImplementedException("ZerosLikeOutsideLoop"); | |||
| /*return array_ops.zeros( | |||
| gen_resource_variable_ops.variable_shape(val), | |||
| dtype: default_gradient.get_zeros_dtype(val));*/ | |||
| return array_ops.zeros_like(val, optimize: false); | |||
| } | |||
| else | |||
| throw new NotImplementedException("ZerosLikeOutsideLoop"); | |||
| } | |||
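| // Usage sketch (hypothetical op): for an op that is not inside a while loop | |||
| // participating in gradients(), ZerosLike falls through to | |||
| // ZerosLikeOutsideLoop and just builds a zeros tensor for the chosen output: | |||
| // var zero = ZerosLike(some_op, 0); // same shape/dtype as some_op.outputs[0] | |||
| // Loop-switch ops return null instead, and the in-loop case is still TODO. | |||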
| /// <summary> | |||
| /// Create zeros_like gradient for a loop exit. | |||
| /// </summary> | |||
| /// <param name="val"></param> | |||
| /// <returns></returns> | |||
| public Tensor ZerosLikeForExit(Tensor val) | |||
| { | |||
| Tensor result = null; | |||
| var val_shape = val.TensorShape; | |||
| var forward_ctxt = val.op._get_control_flow_context(); | |||
| var outer_forward_ctxt = forward_ctxt.outer_context; | |||
| if (outer_forward_ctxt != null) | |||
| outer_forward_ctxt = outer_forward_ctxt.GetWhileContext(); | |||
| GradLoopState outer_grad_state = null; | |||
| if (outer_forward_ctxt != null) | |||
| outer_grad_state = _map.get(outer_forward_ctxt); | |||
| // This is a nested loop. | |||
| if (outer_grad_state != null) | |||
| { | |||
| throw new NotImplementedException("ZerosLikeForExit"); | |||
| } | |||
| else | |||
| { | |||
| // If the shape is known statically, just create a zero tensor | |||
| // with the right shape. | |||
| if (val_shape.is_fully_defined()) | |||
| result = array_ops.zeros(val_shape.dims, val.dtype); | |||
| else | |||
| result = array_ops.zeros_like(val, optimize: false); | |||
| } | |||
| return result; | |||
| } | |||
| public void PostProcessing() | |||
| { | |||
| foreach(var grad_state in _map.Values) | |||
| { | |||
| foreach(var b_merge in grad_state.switch_map.Values) | |||
| { | |||
| if(b_merge.op.inputs[0] == b_merge.op.inputs[1]) | |||
| { | |||
| Tensor next_grad_val = null; | |||
| // The value of this loop variable at iteration i+1 doesn't | |||
| // depend on its value at iteration i. So use zeros as the | |||
| // gradients for all iterations > 0. | |||
| var dtype = b_merge.op.inputs[0].dtype; | |||
| var shape = b_merge.op.inputs[0].TensorShape; | |||
| if (shape.is_fully_defined()) | |||
| { | |||
| grad_state.grad_context.Enter(); | |||
| // Create a zeros and use it for iterations > 0. | |||
| var grad_val = constant_op.constant(0, dtype: dtype, shape: shape); | |||
| next_grad_val = control_flow_ops._NextIteration(grad_val); | |||
| grad_state.grad_context.Exit(); | |||
| } | |||
| else | |||
| { | |||
| throw new NotImplementedException("PostProcessing shape is not fully defined."); | |||
| } | |||
| b_merge.op._update_input(1, next_grad_val); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -16,41 +16,18 @@ | |||
| using System; | |||
| using System.Collections; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using static Tensorflow.Binding; | |||
| using util = Tensorflow.control_flow_util; | |||
| namespace Tensorflow.Operations.ControlFlows | |||
| { | |||
| /// <summary> | |||
| /// The state used for constructing the gradient graph for a while loop. | |||
| /// </summary> | |||
| public class GradLoopState | |||
| { | |||
| //class GradLoopState(object): | |||
| // """The state used for constructing the gradient graph for a while loop. | |||
| // We create a GradLoopState for each while loop in forward and its | |||
| // corresponding while loop in backprop. This gives us access to both | |||
| // the forward and the backprop WhileContexts. | |||
| // During the construction of gradient graph, any time when we detect | |||
| // a forward value that is needed for backprop, we create a history | |||
| // accumulator and add it to `history_map`. Any time when we backprop | |||
| // a loop switch op (in _SwitchGrad), we add the grad merge op in | |||
| // `switch_map`. | |||
| // """ | |||
| // def __init__(self, forward_ctxt, outer_grad_state): | |||
| // # The grad loop state for the outer while loop. | |||
| // self._outer_grad_state = None | |||
| // # The while loop context for forward. | |||
| // self._forward_context = None | |||
| // # The loop counter added by AddForwardLoopCounter. It is the value | |||
| // # of the loop counter for the next iteration. | |||
| // self._forward_index = None | |||
| // # A sync op for forward. | |||
| // self._forward_sync = None | |||
| // # The while loop context for backprop. | |||
| private WhileContext _grad_context = null; | |||
| public WhileContext grad_context => _grad_context; | |||
| @@ -65,156 +42,112 @@ namespace Tensorflow.Operations.ControlFlows | |||
| // # Information needed by backprop. | |||
| private Hashtable _history_map = new Hashtable(); | |||
| public Hashtable history_map => _history_map; | |||
| private Hashtable _switch_map = new Hashtable(); | |||
| public Hashtable switch_map => _switch_map; | |||
| // self._unused_exits = [] | |||
| // self._deferred_exits = [] | |||
| // self._forward_loop_exits = list(forward_ctxt.loop_exits) | |||
| // self._pending_exits_count = len(forward_ctxt.loop_exits) | |||
| // self._outer_grad_state = outer_grad_state | |||
| // if outer_grad_state: | |||
| // outer_forward_ctxt = outer_grad_state.forward_context | |||
| // else: | |||
| // if not hasattr(forward_ctxt, "outer_context"): | |||
| // raise ValueError("Failed to call gradients on a while loop without" | |||
| // "properly serializing graph via MetaGraphDef") | |||
| // outer_forward_ctxt = forward_ctxt.outer_context | |||
| // # Add the forward loop counter. | |||
| // with forward_ctxt._graph.as_default(): # pylint: disable=protected-access | |||
| // if outer_forward_ctxt: | |||
| // outer_forward_ctxt.Enter() | |||
| // cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state) | |||
| // if outer_forward_ctxt: | |||
| // outer_forward_ctxt.Exit() | |||
| // self._forward_context = forward_ctxt | |||
| // self._forward_index = forward_index | |||
| // # Add the backprop WhileContext, and the backprop loop counter. | |||
| // if outer_grad_state: | |||
| // # This is a nested loop. Remember the iteration counts for each | |||
| // # execution of this inner loop. | |||
| // outer_forward_ctxt.AddName(cnt.name) | |||
| // history_cnt = outer_grad_state.AddForwardAccumulator(cnt) | |||
| // outer_grad_ctxt = outer_grad_state.grad_context | |||
| // outer_grad_ctxt.Enter() | |||
| // self._grad_context = WhileContext( | |||
| // maximum_iterations=forward_ctxt.maximum_iterations, | |||
| // parallel_iterations=forward_ctxt.parallel_iterations, | |||
| // back_prop=forward_ctxt.back_prop, | |||
| // swap_memory=forward_ctxt.swap_memory, | |||
| // name=forward_ctxt.name, | |||
| // grad_state=self) | |||
| // real_cnt = outer_grad_state.AddBackpropAccumulatedValue(history_cnt, cnt) | |||
| // self._grad_index = self._grad_context.AddBackpropLoopCounter( | |||
| // real_cnt, outer_grad_state) | |||
| // outer_grad_ctxt.Exit() | |||
| // else: | |||
| // if outer_forward_ctxt: | |||
| // outer_forward_ctxt.Enter() | |||
| // self._grad_context = WhileContext( | |||
| // maximum_iterations=forward_ctxt.maximum_iterations, | |||
| // parallel_iterations=forward_ctxt.parallel_iterations, | |||
| // back_prop=forward_ctxt.back_prop, | |||
| // swap_memory=forward_ctxt.swap_memory, | |||
| // name=forward_ctxt.name, | |||
| // grad_state=self) | |||
| // self._grad_index = self._grad_context.AddBackpropLoopCounter( | |||
| // cnt, outer_grad_state) | |||
| // if outer_forward_ctxt: | |||
| // outer_forward_ctxt.Exit() | |||
| // @property | |||
| // def outer_grad_state(self): | |||
| // """The grad loop state for outer loop.""" | |||
| // return self._outer_grad_state | |||
| Dictionary<Operation, Tensor> _switch_map = new Dictionary<Operation, Tensor>(); | |||
| public Dictionary<Operation, Tensor> switch_map => _switch_map; | |||
| // @property | |||
| // def forward_context(self): | |||
| // """The while loop context for forward.""" | |||
| // return self._forward_context | |||
| // @property | |||
| // def forward_index(self): | |||
| // """The loop index of forward loop.""" | |||
| // return self._forward_index | |||
| // @property | |||
| // def forward_sync(self): | |||
| // """A control trigger node for synchronization in the forward loop. | |||
| // One main use is to keep the push ops of a stack executed in the | |||
| // iteration order. | |||
| // """ | |||
| // if self._forward_sync is None: | |||
| // with ops.control_dependencies(None): | |||
| // self._forward_sync = control_trigger(name="f_sync") | |||
| // self._forward_sync._set_control_flow_context(self._forward_context) | |||
| // self._forward_index.op._add_control_input(self._forward_sync) | |||
| // return self._forward_sync | |||
| // @property | |||
| // def grad_context(self): | |||
| // """The corresponding WhileContext for gradient.""" | |||
| // return self._grad_context | |||
| // @property | |||
| // def grad_index(self): | |||
| // """The loop index of backprop loop.""" | |||
| // return self._grad_index | |||
| // @property | |||
| // def grad_sync(self): | |||
| // """A control trigger node for synchronization in the grad loop. | |||
| /// <summary> | |||
| /// The while loop context for forward. | |||
| /// </summary> | |||
| WhileContext _forward_context; | |||
| public WhileContext forward_context => _forward_context; | |||
| // One main use is to keep the pop ops of a stack executed in the | |||
| // iteration order. | |||
| // """ | |||
| // if self._grad_sync is None: | |||
| // with ops.control_dependencies(None): | |||
| // self._grad_sync = control_trigger(name="b_sync") | |||
| // self._grad_sync._set_control_flow_context(self._grad_context) | |||
| // self._grad_index.op._add_control_input(self._grad_sync) | |||
| // if self._grad_context.outer_context: | |||
| // self._grad_context.outer_context.AddInnerOp(self._grad_sync) | |||
| // return self._grad_sync | |||
| /// <summary> | |||
| /// The grad loop state for the outer while loop. | |||
| /// </summary> | |||
| GradLoopState _outer_grad_state; | |||
| public GradLoopState outer_grad_state => _outer_grad_state; | |||
| // @property | |||
| // def history_map(self): | |||
| // """The map that records all the tensors needed for backprop.""" | |||
| // return self._history_map | |||
| Tensor _forward_index; | |||
| public Tensor forward_index => _forward_index; | |||
| Tensor _grad_index; | |||
| // @property | |||
| // def switch_map(self): | |||
| // """The map that records all the Switch ops for the while loop.""" | |||
| // return self._switch_map | |||
| Tensor[] _forward_loop_exits; | |||
| /// <summary> | |||
| /// The list of exits of the forward loop. | |||
| /// </summary> | |||
| public Tensor[] forward_loop_exits => _forward_loop_exits; | |||
| // @property | |||
| // def unused_exits(self): | |||
| // """The list of "unused" exits.""" | |||
| // return self._unused_exits | |||
| List<Tensor> _deferred_exits; | |||
| public List<Tensor> deferred_exits => _deferred_exits; | |||
| // @property | |||
| // def deferred_exits(self): | |||
| // """The list of "deferred" exits.""" | |||
| // return self._deferred_exits | |||
| List<Tensor> _unused_exits; | |||
| public List<Tensor> unused_exits => _unused_exits; | |||
| // @property | |||
| // def forward_loop_exits(self): | |||
| // """The list of exits of the forward loop.""" | |||
| // return self._forward_loop_exits | |||
| /// <summary> | |||
| /// The number of exits we expect to see but haven't. | |||
| /// </summary> | |||
| public int pending_exits_count { get; set; } | |||
| // @property | |||
| // def pending_exits_count(self): | |||
| // """The number of exits we expect to see but haven't.""" | |||
| // return self._pending_exits_count | |||
| Operation _grad_sync; | |||
| public Operation grad_sync | |||
| { | |||
| get | |||
| { | |||
| if(_grad_sync == null) | |||
| { | |||
| tf_with(ops.control_dependencies(null), delegate | |||
| { | |||
| _grad_sync = gen_control_flow_ops.control_trigger(name: "b_sync"); | |||
| }); | |||
| _grad_sync._set_control_flow_context(_grad_context); | |||
| _grad_index.op._add_control_input(_grad_sync); | |||
| if (_grad_context.outer_context != null) | |||
| _grad_context.outer_context.AddInnerOp(_grad_sync); | |||
| } | |||
| return _grad_sync; | |||
| } | |||
| } | |||
| // @pending_exits_count.setter | |||
| // def pending_exits_count(self, cnt): | |||
| // """Set the pending count to cnt.""" | |||
| // self._pending_exits_count = cnt | |||
| public GradLoopState(WhileContext forward_ctxt, GradLoopState outer_grad_state_) | |||
| { | |||
| // Information needed by backprop. | |||
| _unused_exits = new List<Tensor>(); | |||
| _deferred_exits = new List<Tensor>(); | |||
| _forward_loop_exits = list(forward_ctxt.loop_exits); | |||
| pending_exits_count = len(forward_ctxt.loop_exits); | |||
| _outer_grad_state = outer_grad_state_; | |||
| ControlFlowContext outer_forward_ctxt = null; | |||
| if (outer_grad_state_ != null) | |||
| outer_forward_ctxt = outer_grad_state_.forward_context; | |||
| // Add the forward loop counter. | |||
| // with forward_ctxt._graph.as_default(): | |||
| Tensor cnt, forward_index; | |||
| { | |||
| if (outer_forward_ctxt != null) | |||
| outer_forward_ctxt.Enter(); | |||
| (cnt, forward_index) = forward_ctxt.AddForwardLoopCounter(outer_grad_state); | |||
| if (outer_forward_ctxt != null) | |||
| outer_forward_ctxt.Exit(); | |||
| } | |||
| _forward_context = forward_ctxt; | |||
| _forward_index = forward_index; | |||
| // Add the backprop WhileContext, and the backprop loop counter. | |||
| if (outer_grad_state != null) | |||
| { | |||
| // This is a nested loop. Remember the iteration counts for each | |||
| // execution of this inner loop. | |||
| throw new NotImplementedException("GradLoopState"); | |||
| } | |||
| else | |||
| { | |||
| if (outer_forward_ctxt != null) | |||
| outer_forward_ctxt.Enter(); | |||
| _grad_context = new WhileContext( | |||
| maximum_iterations: forward_ctxt.maximum_iterations, | |||
| parallel_iterations: forward_ctxt.parallel_iterations, | |||
| back_prop: forward_ctxt.back_prop, | |||
| swap_memory: forward_ctxt.swap_memory, | |||
| name: forward_ctxt.name, | |||
| grad_state: this); | |||
| _grad_index = _grad_context.AddBackpropLoopCounter(cnt, outer_grad_state); | |||
| if (outer_forward_ctxt != null) | |||
| outer_forward_ctxt.Exit(); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Add an accumulator for each forward tensor that is needed in backprop. | |||
| @@ -242,63 +175,52 @@ namespace Tensorflow.Operations.ControlFlows | |||
| /// <returns>The stack that contains the accumulated history of the tensor.</returns> | |||
| public Tensor AddForwardAccumulator(Tensor value, bool dead_branch = false) | |||
| { | |||
| throw new NotImplementedException("AddForwardAccumulator"); | |||
| // # curr_ctxt is the context that tf.gradients was called in. | |||
| // with self._forward_index.graph.as_default(): | |||
| // curr_ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access | |||
| // with ops.control_dependencies(None): | |||
| // if curr_ctxt: | |||
| // curr_ctxt.Enter() | |||
| // with ops.colocate_with(value): | |||
| // # We only need to pass maximum_iterations to the stack if | |||
| // # we're inside an XLA context. | |||
| // if not util.IsInXLAContext(value.op): | |||
| // max_size = constant_op.constant(-1, dtypes.int32) | |||
| // else: | |||
| // max_size = GetMaxSizeFromNestedMaximumIterations( | |||
| // value, self.forward_context) | |||
| // acc = gen_data_flow_ops.stack_v2( | |||
| // max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc") | |||
| // if curr_ctxt: | |||
| // curr_ctxt.Exit() | |||
| // # Make acc available in the forward context. | |||
| // enter_acc = self.forward_context.AddValue(acc) | |||
| // # Add the stack_push op in the context of value.op. | |||
| // swap_enabled = self.forward_context.swap_memory | |||
| // value_ctxt = util.GetOutputContext(value.op) | |||
| // if value_ctxt == self.forward_context: | |||
| // # value is not nested in the forward context. | |||
| // self.forward_context.Enter() | |||
| // push = gen_data_flow_ops.stack_push_v2( | |||
| // enter_acc, value, swap_memory=swap_enabled) | |||
| // self.forward_context.Exit() | |||
| // # Protect stack push and order it before forward_index. | |||
| // self.forward_index.op._add_control_input(push.op) | |||
| // else: | |||
| // # value is in a cond context within the forward context. | |||
| // if not isinstance(value_ctxt, CondContext): | |||
| // raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt) | |||
| // if dead_branch: | |||
| // # The special case for creating a zero tensor for a dead | |||
| // # branch of a switch. See ControlFlowState.ZerosLike(). | |||
| // value_ctxt.outer_context.Enter() | |||
| // push = gen_data_flow_ops.stack_push_v2( | |||
| // enter_acc, value, swap_memory=swap_enabled) | |||
| // value_ctxt.outer_context.Exit() | |||
| // push.op._set_control_flow_context(value_ctxt) | |||
| // else: | |||
| // value_ctxt.Enter() | |||
| // push = gen_data_flow_ops.stack_push_v2( | |||
| // enter_acc, value, swap_memory=swap_enabled) | |||
| // value_ctxt.Exit() | |||
| // # Protect stack push and order it before forward_sync. | |||
| // self.forward_sync._add_control_input(push.op) | |||
| // # Order stack push after the successor of forward_index | |||
| // add_op = self.forward_index.op.inputs[0].op | |||
| // push.op._add_control_input(add_op) | |||
| // return acc | |||
| _forward_index.graph.as_default(); | |||
| { | |||
| var curr_ctxt = ops.get_default_graph()._get_control_flow_context(); | |||
| return tf_with(ops.control_dependencies(null), delegate | |||
| { | |||
| Tensor acc = null; | |||
| Tensor push = null; | |||
| if (curr_ctxt != null) | |||
| curr_ctxt.Enter(); | |||
| ops.colocate_with(value); | |||
| { | |||
| // We only need to pass maximum_iterations to the stack if | |||
| // we're inside an XLA context. | |||
| var max_size = constant_op.constant(-1, dtypes.int32); | |||
| acc = gen_data_flow_ops.stack_v2( | |||
| max_size: max_size, elem_type: value.dtype.as_base_dtype(), name: "f_acc"); | |||
| } | |||
| if (curr_ctxt != null) | |||
| curr_ctxt.Exit(); | |||
| // Make acc available in the forward context. | |||
| var enter_acc = forward_context.AddValue(acc); | |||
| // Add the stack_push op in the context of value.op. | |||
| var swap_enabled = forward_context.swap_memory; | |||
| var value_ctxt = util.GetOutputContext(value.op); | |||
| if(value_ctxt == forward_context) | |||
| { | |||
| // value is not nested in the forward context. | |||
| forward_context.Enter(); | |||
| push = gen_data_flow_ops.stack_push_v2(enter_acc, value, swap_memory: swap_enabled); | |||
| forward_context.Exit(); | |||
| // Protect stack push and order it before forward_index. | |||
| forward_index.op._add_control_input(push.op); | |||
| } | |||
| else | |||
| { | |||
| throw new NotImplementedException("AddForwardAccumulator"); | |||
| } | |||
| // Order stack push after the successor of forward_index | |||
| var add_op = forward_index.op.inputs[0].op; | |||
| push.op._add_control_input(add_op); | |||
| return acc; | |||
| }); | |||
| } | |||
| } | |||
| // """Add the getter for an accumulated value in the grad context. | |||
| @@ -315,98 +237,99 @@ namespace Tensorflow.Operations.ControlFlows | |||
| // Returns: | |||
| // The current value (the top of the stack). | |||
| // """ | |||
| public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bool dead_branch= false) | |||
| { | |||
| throw new NotImplementedException(); | |||
| // history_ctxt = history_value.op._get_control_flow_context() | |||
| // # Find the cond context that controls history_value if any. | |||
| // cond_ctxt = None | |||
| // value_ctxt = value.op._get_control_flow_context() | |||
| // while value_ctxt and value_ctxt != history_ctxt: | |||
| // if isinstance(value_ctxt, CondContext): | |||
| // cond_ctxt = value_ctxt | |||
| // break | |||
| // value_ctxt = value_ctxt.outer_context | |||
| // with ops.control_dependencies(None): | |||
| // self.grad_context.Enter() | |||
| // if cond_ctxt: | |||
| // # Guard stack pop with a switch if it is controlled by a cond. | |||
| // grad_state = self | |||
| // pred = None | |||
| // while pred is None and grad_state: | |||
| // pred = grad_state.history_map.get(cond_ctxt.pred.name) | |||
| // grad_state = grad_state.outer_grad_state | |||
| // if pred is None: | |||
| // pred = cond_ctxt.pred | |||
| // branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch | |||
| // history_value = _SwitchRefOrTensor(history_value, pred)[branch] | |||
| // pop = gen_data_flow_ops.stack_pop_v2(history_value, | |||
| // value.dtype.base_dtype) | |||
| // pop.set_shape(value.get_shape()) | |||
| // self.grad_context.Exit() | |||
| // parallel_iterations = self.grad_context.parallel_iterations | |||
| // if parallel_iterations > 1: | |||
| // # All pops are ordered after pivot_for_body and before grad_sync. | |||
| // self.grad_sync._add_control_input(pop.op) | |||
| // return pop | |||
| var history_ctxt = history_value.op._get_control_flow_context(); | |||
| // Find the cond context that controls history_value if any. | |||
| CondContext cond_ctxt = null; | |||
| Tensor pop = null; | |||
| var value_ctxt = value.op._get_control_flow_context(); | |||
| while(value_ctxt != null && value_ctxt != history_ctxt) | |||
| { | |||
| if (value_ctxt is CondContext cc) | |||
| { | |||
| cond_ctxt = cc; | |||
| break; | |||
| } | |||
| value_ctxt = value_ctxt.outer_context; | |||
| } | |||
| tf_with(ops.control_dependencies(null), delegate | |||
| { | |||
| grad_context.Enter(); | |||
| if(cond_ctxt != null) | |||
| { | |||
| throw new NotImplementedException("AddBackpropAccumulatedValue"); | |||
| } | |||
| pop = gen_data_flow_ops.stack_pop_v2(history_value, value.dtype.as_base_dtype()); | |||
| pop.set_shape(value.TensorShape); | |||
| grad_context.Exit(); | |||
| }); | |||
| var parallel_iterations = grad_context.parallel_iterations; | |||
| if (parallel_iterations > 1) | |||
| // All pops are ordered after pivot_for_body and before grad_sync. | |||
| grad_sync._add_control_input(pop.op); | |||
| return pop; | |||
| } | |||
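| // The two accumulator methods form a push/pop pair, as GetRealValue below | |||
| // shows: AddForwardAccumulator creates a stack and pushes `value` once per | |||
| // forward iteration; AddBackpropAccumulatedValue pops it back in the grad | |||
| // loop. A minimal sketch with a hypothetical forward tensor `v`: | |||
| // grad_context.Exit(); | |||
| // var history = AddForwardAccumulator(v); // stack_v2 + stack_push_v2 | |||
| // grad_context.Enter(); | |||
| // var v_i = AddBackpropAccumulatedValue(history, v); // stack_pop_v2 | |||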
| // def GetRealValue(self, value): | |||
| // """Get the real value of `value`. | |||
| // If backprop "uses" a value produced by forward inference, an accumulator | |||
| // is added in the forward loop to accumulate its values. We use the | |||
| // accumulated value. This method must be called in the grad loop context. | |||
| // `value` must be in forward and needed for backprop. | |||
| // Args: | |||
| // value: A tensor to be captured. | |||
| // Returns: | |||
| // The same tensor obtained from the saved history. | |||
| // """ | |||
| // assert value.op.type not in ["Variable", "VariableV2"] | |||
| // real_value = self._history_map.get(value.name) | |||
| // if real_value is None: | |||
| // cur_value = value | |||
| // cur_grad_state = self | |||
| // while True: | |||
| // enter_op = util.GetLoopConstantEnter(cur_value) | |||
| // if enter_op: | |||
| // # Special case: cur_value comes from a constant Enter node. | |||
| // cur_value = enter_op.inputs[0] | |||
| // cur_grad_state = cur_grad_state.outer_grad_state | |||
| // if cur_grad_state is None: | |||
| // # We are now outside all nested loops for this gradient(), | |||
| // # so `value` is a loop invariant and there is no need to | |||
| // # save the history of value. Just make cur_value to enter | |||
| // # the right control flow context. | |||
| // real_value = self._grad_context.AddValue(cur_value) | |||
| // break | |||
| // elif constant_op.is_constant(cur_value): | |||
| // # If the value to be forwarded is a constant, clone the constant in | |||
| // # the gradient loop rather than using a stack. | |||
| // # TODO(phawkins): consider hoisting the constant out of the loop | |||
| // # instead. | |||
| // real_value = constant_op.constant( | |||
| // tensor_util.constant_value(cur_value), dtype=cur_value.dtype) | |||
| // break | |||
| // else: | |||
| // # Record the history of this value in forward_ctxt. | |||
| // self._grad_context.Exit() | |||
| // history_value = cur_grad_state.AddForwardAccumulator(cur_value) | |||
| // self._grad_context.Enter() | |||
| // break | |||
| // if real_value is None: | |||
| // # Add the stack pop op in the grad context. | |||
| // real_value = cur_grad_state.AddBackpropAccumulatedValue( | |||
| // history_value, cur_value) | |||
| // if cur_grad_state != self: | |||
| // real_value = self._grad_context.AddValue(real_value) | |||
| // self._history_map[value.name] = real_value | |||
| // return real_value | |||
| /// <summary> | |||
| /// Get the real value of `value`. | |||
| /// </summary> | |||
| /// <param name="value">A tensor to be captured.</param> | |||
| /// <returns>The same tensor obtained from the saved history.</returns> | |||
| public Tensor GetRealValue(Tensor value) | |||
| { | |||
| // Reuse the accumulated value if this tensor was captured before. | |||
| var real_value = _history_map[value.name] as Tensor; | |||
| if (real_value == null) | |||
| { | |||
| var cur_value = value; | |||
| var cur_grad_state = this; | |||
| Tensor history_value = null; | |||
| while (true) | |||
| { | |||
| var enter_op = util.GetLoopConstantEnter(cur_value); | |||
| if(enter_op != null) | |||
| { | |||
| // Special case: cur_value comes from a constant Enter node. | |||
| cur_value = enter_op.inputs[0]; | |||
| cur_grad_state = cur_grad_state.outer_grad_state; | |||
| if(cur_grad_state == null) | |||
| { | |||
| // We are now outside all nested loops for this gradient(), | |||
| // so `value` is a loop invariant and there is no need to | |||
| // save the history of value. Just make cur_value enter | |||
| // the right control flow context. | |||
| real_value = _grad_context.AddValue(cur_value); | |||
| break; | |||
| } | |||
| } | |||
| else if (constant_op.is_constant(cur_value)) | |||
| { | |||
| // If the value to be forwarded is a constant, clone the constant | |||
| // in the gradient loop rather than using a stack. | |||
| real_value = constant_op.constant( | |||
| tensor_util.constant_value(cur_value), dtype: cur_value.dtype); | |||
| break; | |||
| } | |||
| else | |||
| { | |||
| // Record the history of this value in forward_ctxt. | |||
| _grad_context.Exit(); | |||
| history_value = cur_grad_state.AddForwardAccumulator(cur_value); | |||
| _grad_context.Enter(); | |||
| break; | |||
| } | |||
| } | |||
| if(real_value == null) | |||
| { | |||
| // Add the stack pop op in the grad context. | |||
| real_value = cur_grad_state.AddBackpropAccumulatedValue(history_value, cur_value); | |||
| if (cur_grad_state != this) | |||
| real_value = _grad_context.AddValue(real_value); | |||
| } | |||
| _history_map[value.name] = real_value; | |||
| } | |||
| return real_value; | |||
| } | |||
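| // Summary of the three capture paths above, for a hypothetical forward | |||
| // tensor `h` requested while the grad loop context is active: | |||
| // var real_h = GetRealValue(h); | |||
| // // loop invariant -> re-entered via _grad_context.AddValue | |||
| // // constant -> cloned with constant_op.constant | |||
| // // anything else -> saved/restored through a forward accumulator | |||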
| } | |||
| } | |||
| @@ -0,0 +1,43 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| namespace Tensorflow.Operations | |||
| { | |||
| internal class LoopVar<TItem> : ICanBeFlattened, IPackable<LoopVar<TItem>> | |||
| { | |||
| public Tensor Counter { get; set; } | |||
| public TItem Item { get; set; } | |||
| public LoopVar(Tensor counter, TItem item) | |||
| { | |||
| Counter = counter; | |||
| Item = item; | |||
| } | |||
| public object[] Flatten() | |||
| { | |||
| var elements = new List<object> { Counter }; | |||
| if (typeof(TItem).GetInterface(typeof(ICanBeFlattened).Name) != null) | |||
| elements.AddRange((Item as ICanBeFlattened).Flatten()); | |||
| else | |||
| elements.Add(Item); | |||
| return elements.ToArray(); | |||
| } | |||
| public LoopVar<TItem> Pack(object[] sequences) | |||
| { | |||
| var counter = sequences[0] as Tensor; | |||
| var item = default(TItem); | |||
| if (typeof(TItem).GetInterface(typeof(IPackable<TItem>).Name) != null) | |||
| item = (Item as IPackable<TItem>).Pack(sequences.Skip(1).ToArray()); | |||
| else | |||
| // Mirror Flatten: a non-packable item occupies a single slot. | |||
| item = (TItem)sequences[1]; | |||
| return new LoopVar<TItem>(counter, item); | |||
| } | |||
| public static implicit operator (Tensor, TItem)(LoopVar<TItem> loopVar) | |||
| { | |||
| return (loopVar.Counter, loopVar.Item); | |||
| } | |||
| } | |||
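| // Example (a sketch; BodyItemInRnnWhileLoop is the TItem used elsewhere in | |||
| // this change): Flatten emits the counter followed by the item's flattened | |||
| // parts, and Pack rebuilds a LoopVar from the same flat layout. | |||
| // var lv = new LoopVar<BodyItemInRnnWhileLoop>(counter, item); | |||
| // object[] flat = lv.Flatten(); // [counter, time, output_ta_t..., state] | |||
| // var repacked = lv.Pack(flat); // equivalent LoopVar instance | |||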
| } | |||
| @@ -0,0 +1,36 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Operations | |||
| { | |||
| public class MergeOutput | |||
| { | |||
| Tensor output; | |||
| Tensor value_index; | |||
| public MergeOutput(Tensor[] values) | |||
| { | |||
| output = values[0]; | |||
| value_index = values[1]; | |||
| } | |||
| public Tensor this[int idx] | |||
| { | |||
| get | |||
| { | |||
| switch(idx) | |||
| { | |||
| case 0: | |||
| return output; | |||
| case 1: | |||
| return value_index; | |||
| default: | |||
| return null; | |||
| } | |||
| } | |||
| } | |||
| public static implicit operator Tensor(MergeOutput merge) | |||
| => merge.output; | |||
| } | |||
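| // Usage sketch: MergeOutput wraps the two outputs of a Merge op so callers | |||
| // can index them or take the merged value via the implicit conversion. | |||
| // Assuming a hypothetical MergeOutput `m` built from a Merge op's outputs: | |||
| // Tensor output = m; // same as m[0] | |||
| // Tensor value_index = m[1]; // which input produced the output | |||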
| } | |||
| @@ -32,17 +32,22 @@ namespace Tensorflow.Operations | |||
| bool _back_prop=true; | |||
| GradLoopState _grad_state =null; | |||
| Tensor _maximum_iterations; | |||
| public Tensor maximum_iterations => _maximum_iterations; | |||
| int _parallel_iterations; | |||
| public int parallel_iterations => _parallel_iterations; | |||
| bool _swap_memory; | |||
| public bool swap_memory => _swap_memory; | |||
| Tensor _pivot_for_pred; | |||
| Tensor _pivot_for_body; | |||
| List<Tensor> _loop_exits; | |||
| public List<Tensor> loop_exits => _loop_exits; | |||
| List<Tensor> _loop_enters; | |||
| public List<Tensor> loop_enters => _loop_enters; | |||
| Graph _graph; | |||
| public override GradLoopState grad_state => _grad_state; | |||
| public override bool back_prop => _back_prop; | |||
| public WhileContext(int? maximum_iterations = null, | |||
| public WhileContext(Tensor maximum_iterations = null, | |||
| int parallel_iterations = 10, | |||
| bool back_prop = true, | |||
| bool swap_memory = false, | |||
| @@ -64,13 +69,15 @@ namespace Tensorflow.Operations | |||
| _grad_state = grad_state; | |||
| } | |||
| private void _init_from_args(int? maximum_iterations, | |||
| private void _init_from_args(Tensor maximum_iterations, | |||
| int parallel_iterations, | |||
| bool back_prop, | |||
| bool swap_memory, | |||
| string name) | |||
| { | |||
| _name = ops.get_default_graph().unique_name(name); | |||
| _maximum_iterations = maximum_iterations; | |||
| _parallel_iterations = parallel_iterations; | |||
| _back_prop = back_prop; | |||
| _swap_memory = swap_memory; | |||
| _loop_exits = new List<Tensor>(); | |||
| @@ -107,37 +114,75 @@ namespace Tensorflow.Operations | |||
| /// <summary> | |||
| /// Add the loop termination condition and body to the graph. | |||
| /// </summary> | |||
| public Tensor[] BuildLoop(Func<Tensor, Tensor> pred, | |||
| Func<Tensor, Tensor> body, | |||
| Tensor[] loop_vars, | |||
| TensorShape shape_invariants, | |||
| internal LoopVar<TItem> BuildLoop<TItem>(Func<LoopVar<TItem>, Tensor> pred, | |||
| Func<LoopVar<TItem>, LoopVar<TItem>> body, | |||
| LoopVar<TItem> loop_vars, | |||
| TensorShape[] shape_invariants, | |||
| bool return_same_structure) | |||
| { | |||
| // Keep original_loop_vars to identify which are TensorArrays | |||
| var original_loop_vars = loop_vars; | |||
| // Convert TensorArrays to their flow variables | |||
| var loop_vars_tensors = nest.flatten2(loop_vars) | |||
| .Select(x => _convert_tensorarray_to_flow(x)) | |||
| .ToArray(); | |||
| if (shape_invariants == null) | |||
| shape_invariants = loop_vars_tensors | |||
| .Select(x => _get_shape_invariant(x as Tensor)) | |||
| .ToArray(); | |||
| Enter(); | |||
| var(original_body_result, exit_vars) = _BuildLoop( | |||
| pred, body, original_loop_vars, loop_vars, shape_invariants); | |||
| pred, body, original_loop_vars, loop_vars_tensors, shape_invariants); | |||
| Exit(); | |||
| var flat_result = original_body_result; | |||
| var flat_result = nest.flatten2(original_body_result) | |||
| .Select(x => x as ITensorOrTensorArray) | |||
| .ToArray(); | |||
| var exit_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_result, exit_vars); | |||
| var packed_exit_vars = nest.pack_sequence_as( | |||
| var packed_exit_vars = nest.pack_sequence_as2( | |||
| structure: original_body_result, | |||
| flat_sequence: exit_vars_with_tensor_arrays); | |||
| return packed_exit_vars as Tensor[]; | |||
| return packed_exit_vars; | |||
| } | |||
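| // Shape of the delegates BuildLoop expects (a sketch; `n` and `step` are | |||
| // hypothetical): pred maps the packed loop variables to a scalar bool | |||
| // tensor, body produces the loop variables for the next iteration. | |||
| // Func<LoopVar<TItem>, Tensor> pred = lv => lv.Counter < n; | |||
| // Func<LoopVar<TItem>, LoopVar<TItem>> body = | |||
| // lv => new LoopVar<TItem>(lv.Counter + 1, step(lv.Item)); | |||
| // var exit_vars = BuildLoop(pred, body, loop_vars, null, return_same_structure: true); | |||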
| private (Tensor[], Tensor[]) _BuildLoop(Func<Tensor, Tensor> pred, | |||
| Func<Tensor, Tensor> body, | |||
| Tensor[] original_loop_vars, | |||
| private Tensor _convert_tensorarray_to_flow(object tensor_or_tensor_array) | |||
| { | |||
| if (tensor_or_tensor_array is TensorArray tensor_array) | |||
| return tensor_array.flow; | |||
| else if (tensor_or_tensor_array is Tensor tensor) | |||
| return tensor; | |||
| throw new NotImplementedException("_convert_tensorarray_to_flow"); | |||
| } | |||
| private TensorShape _get_shape_invariant(Tensor var, int[] shape = null) | |||
| { | |||
| // Simplified port: the invariant is taken from the variable's own | |||
| // static shape; the explicit `shape` argument is not used yet. | |||
| return var.TensorShape; | |||
| } | |||
| /// <summary> | |||
| /// Add the loop termination condition and body to the graph. | |||
| /// </summary> | |||
| /// <typeparam name="TItem"></typeparam> | |||
| /// <param name="pred"></param> | |||
| /// <param name="body"></param> | |||
| /// <param name="original_loop_vars"></param> | |||
| /// <param name="loop_vars"></param> | |||
| /// <param name="shape_invariants"></param> | |||
| /// <returns></returns> | |||
| private (LoopVar<TItem>, Tensor[]) _BuildLoop<TItem>(Func<LoopVar<TItem>, Tensor> pred, | |||
| Func<LoopVar<TItem>, LoopVar<TItem>> body, | |||
| LoopVar<TItem> original_loop_vars, | |||
| Tensor[] loop_vars, | |||
| TensorShape shape_invariants) | |||
| TensorShape[] shape_invariants) | |||
| { | |||
| var flat_loop_vars = original_loop_vars; | |||
| var flat_loop_vars = nest.flatten2(original_loop_vars) | |||
| .Select(x => (ITensorOrTensorArray)x) | |||
| .ToArray(); | |||
| // Let the context know the loop variables so the loop variables | |||
| // would be added in the outer contexts properly. | |||
| @@ -146,14 +191,14 @@ namespace Tensorflow.Operations | |||
| Tensor[] enter_vars = null; | |||
| tf_with(ops.control_dependencies(null), delegate | |||
| { | |||
| enter_vars = real_vars.Select(x => _Enter(x, | |||
| enter_vars = real_vars.Select(x => control_flow_ops._Enter(x, | |||
| _name, | |||
| is_constant: false, | |||
| parallel_iterations: _parallel_iterations, | |||
| use_input_shape: shape_invariants == null)) | |||
| .ToArray(); | |||
| foreach(var x in enter_vars) | |||
| foreach (var x in enter_vars) | |||
| { | |||
| x.graph.prevent_feeding(x); | |||
| if (_outer_context != null) | |||
| @@ -163,7 +208,13 @@ namespace Tensorflow.Operations | |||
| // Finds the closest enclosing non-None control pivot. | |||
| var outer_context = _outer_context; | |||
| while (outer_context != null) | |||
| object control_pivot = null; | |||
| while (outer_context != null && control_pivot == null) | |||
| { | |||
| // The pivot lookup itself is not ported yet; walk outward so the | |||
| // search terminates instead of spinning forever. | |||
| outer_context = outer_context.outer_context; | |||
| } | |||
| if (control_pivot != null) | |||
| { | |||
| } | |||
| @@ -177,31 +228,42 @@ namespace Tensorflow.Operations | |||
| var merge_vars = enter_vars | |||
| .Select(x => merge(new[] { x, x })) | |||
| .Select(m => (Tensor)m) | |||
| .ToArray(); | |||
| _pivot_for_pred = merge_vars[0]; | |||
| // Build the graph for pred. | |||
| var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars); | |||
| // var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays); | |||
| var c = ops.convert_to_tensor(pred(merge_vars_with_tensor_arrays[0])); | |||
| //var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays, expand_composites: true); | |||
| var packed_vars = new LoopVar<TItem>((Tensor)merge_vars_with_tensor_arrays[0], | |||
| (TItem)(object)new BodyItemInRnnWhileLoop((Tensor)merge_vars_with_tensor_arrays[1], | |||
| new[] { (TensorArray)merge_vars_with_tensor_arrays[2] }, | |||
| (Tensor)merge_vars_with_tensor_arrays[3])); | |||
| var pp = pred(packed_vars); | |||
| var c = ops.convert_to_tensor(pp); | |||
| _pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond"); | |||
| var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot)) | |||
| .ToArray(); | |||
| // Build the graph for body. | |||
| var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray(); | |||
| _pivot_for_body = vars_for_body[0]; | |||
| // Convert TensorArray flow variables inside the context back into | |||
| // their associated TensorArrays for calling the body. | |||
| var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); | |||
| var body_result = body(packed_vars_for_body[0]); | |||
| var vars_for_body_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); | |||
| var packed_vars_for_body = nest.pack_sequence_as2(original_loop_vars, vars_for_body_with_tensor_arrays); | |||
| var pre_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||
| var body_result = body(packed_vars_for_body); | |||
| var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||
| // Store body_result to keep track of TensorArrays returned by body | |||
| var original_body_result = new[] { body_result }; | |||
| var original_body_result = body_result; | |||
| // Convert TensorArrays returned by body into their flow variables | |||
| var result = new[] { body_result }; | |||
| var result = nest.flatten2(body_result) | |||
| .Select(x => _convert_tensorarray_to_flow(x)) | |||
| .ToArray(); | |||
| // result = ops.convert_n_to_tensor_or_composite(result); | |||
| var next_vars = new List<Tensor>(); | |||
| foreach (var (m, v) in zip(merge_vars, result)) | |||
| next_vars.Add(_AddNextAndBackEdge(m, v)); | |||
| @@ -218,20 +280,45 @@ namespace Tensorflow.Operations | |||
| private void _FixControlInputsAndContext(Tensor[] enters) | |||
| { | |||
| var graph = ops.get_default_graph(); | |||
| foreach(var e in enters) | |||
| foreach(var x in enters) | |||
| { | |||
| var inp_op = e.op.inputs[0].op; | |||
| var inp_op = x.op.inputs[0].op; | |||
| var control_inputs = graph._control_dependencies_for_inputs(new[] { inp_op }); | |||
| var outer_control_inputs = new List<Operation>(); | |||
| foreach(Operation op in control_inputs) | |||
| { | |||
| // We need to keep control inputs that are in any ancestor | |||
| // ControlFlowContext, and within outer WhileContext. | |||
| var keep_as_control_input = true; | |||
| var op_ctxt = control_flow_util.GetOutputContext(op); | |||
| var outer_ctxt = outer_context; | |||
| var outer_while_context = outer_ctxt == null ? null : outer_ctxt.GetWhileContext(); | |||
| while (outer_ctxt != op_ctxt) | |||
| { | |||
| if (outer_ctxt == null || outer_ctxt == outer_while_context) | |||
| { | |||
| keep_as_control_input = false; | |||
| break; | |||
| } | |||
| outer_ctxt = outer_ctxt.outer_context; | |||
| } | |||
| if (keep_as_control_input) | |||
| outer_control_inputs.append(op); | |||
| } | |||
| // op for op in control_inputs if self._IsInOuterContext(op) | |||
| var outer_control_inputs = control_inputs.Where(x => _IsInOuterContext(x.op)) | |||
| /*var outer_control_inputs = control_inputs.Where(x => _IsInOuterContext(x.op)) | |||
| .Select(x => x.op) | |||
| .ToArray(); | |||
| e.op._set_control_flow_context(this); | |||
| e.op._add_control_inputs(outer_control_inputs); | |||
| graph._record_op_seen_by_control_dependencies(e.op); | |||
| .ToArray();*/ | |||
| x.op._set_control_flow_context(this); | |||
| x.op._add_control_inputs(outer_control_inputs.ToArray()); | |||
| graph._record_op_seen_by_control_dependencies(x.op); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Makes the values known to this context. | |||
| /// </summary> | |||
| /// <param name="values"></param> | |||
| private void _InitializeValues(Tensor[] values) | |||
| { | |||
| _values = new HashSet<string>(); | |||
| @@ -239,6 +326,332 @@ namespace Tensorflow.Operations | |||
| _values.Add(x.name); | |||
| } | |||
| protected override void _AddOpInternal(Operation op) | |||
| { | |||
| Operation[] external_inputs = new Operation[0]; | |||
| Operation[] control_inputs = new Operation[0]; | |||
| if (op.inputs.Length == 0) | |||
| { | |||
| // Remove any external control dependency on this op | |||
| (control_inputs, external_inputs) = _RemoveExternalControlEdges(op); | |||
| if (control_inputs.Length == 0) | |||
| op._add_control_input(GetControlPivot().op); | |||
| foreach (var x in op.outputs) | |||
| _values.Add(x.name); | |||
| } | |||
| else | |||
| { | |||
| foreach (var index in range(len(op.inputs))) | |||
| { | |||
| var x = op.inputs[index]; | |||
| var real_x = AddValue(x); | |||
| if (real_x != x) | |||
| op._update_input(index, real_x); | |||
| } | |||
| // Remove any external control dependency on this op. | |||
| (_, external_inputs) = _RemoveExternalControlEdges(op); | |||
| // Add a control dependency to prevent loop invariants from | |||
| // enabling ops that should not be executed. | |||
| _MaybeAddControlDependency(op); | |||
| foreach (Tensor x in op.outputs) | |||
| _values.Add(x.name); | |||
| } | |||
| if (external_inputs.Length > 0) | |||
| { | |||
| throw new NotImplementedException("external_inputs.Length > 0"); | |||
| } | |||
| if (_outer_context != null || !IsLoopExit(op)) | |||
| foreach (Tensor x in op.outputs) | |||
| op.graph.prevent_feeding(x); | |||
| if (_outer_context != null) | |||
| _outer_context.AddInnerOp(op); | |||
| } | |||
| protected void _MaybeAddControlDependency(Operation op) | |||
| { | |||
| // Determines if `op` needs a control dependency. | |||
| Func<Operation, bool> _IsOpFree = (op1) => | |||
| { | |||
| if (op1.control_inputs.Length > 0) | |||
| return false; | |||
| if (op1.type == "SymbolicGradient") | |||
| return true; | |||
| foreach (Tensor x in op1.inputs) | |||
| if (!control_flow_util.IsLoopConstantEnter(x.op)) | |||
| return false; | |||
| return true; | |||
| }; | |||
| if (_IsOpFree(op)) | |||
| op._add_control_input(GetControlPivot().op); | |||
| } | |||
| private Tensor GetControlPivot() | |||
| { | |||
| if (_pivot_for_body != null) | |||
| return _pivot_for_body; | |||
| return _pivot_for_pred; | |||
| } | |||
| public override void AddOp(Operation op) | |||
| { | |||
| _AddOpInternal(op); | |||
| } | |||
| /// <summary> | |||
| /// Adds a loop that counts the number of iterations. | |||
| /// </summary> | |||
| /// <param name="outer_grad_state">The outer grad state. None if not nested.</param> | |||
| /// <returns>The number of iterations taken by the forward loop and the loop index.</returns> | |||
| public (Tensor, Tensor) AddForwardLoopCounter(GradLoopState outer_grad_state) | |||
| { | |||
| var n = constant_op.constant(0, name: "f_count"); | |||
| if (outer_grad_state != null) | |||
| throw new NotImplementedException("AddForwardLoopCounter"); | |||
| Enter(); | |||
| AddName(n.name); | |||
| var enter_n = _Enter(n, | |||
| _name, | |||
| is_constant: false, | |||
| parallel_iterations: _parallel_iterations, | |||
| name: "f_count"); | |||
| _loop_enters.Add(enter_n); | |||
| var m1 = merge(new[] { enter_n, enter_n }); | |||
| var merge_n = m1[0]; | |||
| var switch_n = @switch (merge_n, _pivot); | |||
| var index = math_ops.add(switch_n[1], 1); | |||
| var next_n = _NextIteration(index); | |||
| merge_n.op._update_input(1, next_n); | |||
| var total_iterations = exit(switch_n[0], name: "f_count"); | |||
| loop_exits.append(total_iterations); | |||
| ExitResult(new[] { total_iterations }); | |||
| Exit(); | |||
| return (total_iterations, next_n); | |||
| } | |||
| /// <summary> | |||
| /// Add an accumulation loop for every loop invariant. | |||
| /// </summary> | |||
| /// <param name="op">The Enter op for a loop invariant.</param> | |||
| /// <param name="grad">The partial gradient of an iteration for a loop invariant.</param> | |||
| /// <returns>The gradient for a loop invariant.</returns> | |||
| public Tensor AddBackpropAccumulator(Operation op, Tensor grad) | |||
| { | |||
| Tensor acc = null; | |||
| Exit(); | |||
| // Create a zeros tensor with the right shape for acc. If we don't | |||
| // know the full shape statically, we will have to get the shape | |||
| // dynamically from the forward inference. Getting the shape right | |||
| // for the zeros is only needed for the base case when the loop exits | |||
| // without running any iterations. | |||
| var shape = grad.TensorShape; | |||
| if (shape.is_fully_defined()) | |||
| { | |||
| if (outer_context != null) | |||
| outer_context.Enter(); | |||
| acc = constant_op.constant(0, grad.dtype, shape: shape, name: "b_acc"); | |||
| if (outer_context != null) | |||
| outer_context.Exit(); | |||
| } | |||
| else | |||
| { | |||
| var value = op.inputs[0]; | |||
| if(outer_context is WhileContext wc) | |||
| { | |||
| // We are in a nested while loop. | |||
| var forward_ctxt = grad_state.forward_context; | |||
| forward_ctxt.outer_context.Enter(); | |||
| var zeros_shape = array_ops.shape_internal(value, optimize: false); | |||
| forward_ctxt.outer_context.Exit(); | |||
| var outer_grad_state = grad_state.outer_grad_state; | |||
| var history_zeros_shape = outer_grad_state.AddForwardAccumulator(zeros_shape); | |||
| outer_context.Enter(); | |||
| var real_shape = outer_grad_state.AddBackpropAccumulatedValue( | |||
| history_zeros_shape, zeros_shape); | |||
| acc = array_ops.zeros(real_shape, grad.dtype); | |||
| outer_context.Exit(); | |||
| } | |||
| else | |||
| { | |||
| if (outer_context != null) | |||
| outer_context.Enter(); | |||
| var zeros_shape = array_ops.shape_internal(value, optimize: false); | |||
| acc = array_ops.zeros(zeros_shape, grad.dtype); | |||
| if (outer_context != null) | |||
| outer_context.Exit(); | |||
| } | |||
| throw new NotImplementedException("AddBackpropAccumulator"); | |||
| } | |||
| Enter(); | |||
| AddName(acc.name); | |||
| var enter_acc = _Enter( | |||
| acc, | |||
| _name, | |||
| is_constant: false, | |||
| parallel_iterations: _parallel_iterations, | |||
| name: "b_acc"); | |||
| loop_enters.append(enter_acc); | |||
| var merge_acc = merge(new[] { enter_acc, enter_acc }, name: "b_acc")[0]; | |||
| var switch_result = @switch(merge_acc, _pivot); | |||
| var (switch_acc_false, switch_acc_true) = (switch_result[0], switch_result[1]); | |||
| var add_acc = math_ops.add(switch_acc_true, grad); | |||
| var next_acc = _NextIteration(add_acc); | |||
| merge_acc.op._update_input(1, next_acc); | |||
| var result_acc = exit(switch_acc_false, name: "b_acc"); | |||
| loop_exits.append(result_acc); | |||
| ExitResult(new[] { result_acc }); | |||
| return result_acc; | |||
| } | |||
| /// <summary> | |||
| /// Add the backprop loop that controls the iterations. | |||
| /// </summary> | |||
| /// <param name="count">The number of iterations for backprop.</param> | |||
| /// <param name="outer_grad_state">The outer grad state. None if not nested.</param> | |||
| /// <returns>The loop index.</returns> | |||
| public Tensor AddBackpropLoopCounter(Tensor count, GradLoopState outer_grad_state) | |||
| { | |||
| Tensor one = null; | |||
| var in_separate_functions = count.graph != ops.get_default_graph(); | |||
| if (in_separate_functions) | |||
| // Brings the count into this graph | |||
| count = array_ops.identity(count); | |||
| else | |||
| one = constant_op.constant(1, name: "b_count"); | |||
| Enter(); | |||
| AddName(count.name); | |||
| var enter_count = _Enter( | |||
| count, | |||
| _name, | |||
| is_constant: false, | |||
| parallel_iterations: _parallel_iterations, | |||
| name: "b_count"); | |||
| loop_enters.append(enter_count); | |||
| var merge_count = merge(new[] { enter_count, enter_count })[0]; | |||
| _pivot_for_pred = merge_count; | |||
| if (in_separate_functions) | |||
| one = constant_op.constant(1, name: "b_count"); | |||
| var pred = math_ops.greater_equal(merge_count, one); | |||
| _pivot = gen_control_flow_ops.loop_cond(pred, name: "b_count"); | |||
| var switch_count = @switch(merge_count, _pivot); | |||
| var index = math_ops.subtract(switch_count[1], one); | |||
| _pivot_for_body = index; | |||
| var next_count = _NextIteration(index); | |||
| merge_count.op._update_input(1, next_count); | |||
| var final_zero = exit(switch_count[0], name: "b_count"); | |||
| loop_exits.append(final_zero); | |||
| // Force the stack pops of i-th execution of an inner loop to be ordered | |||
| // before the pops of (i+1)-th execution of the same inner loop. | |||
| if (outer_grad_state != null) | |||
| throw new NotImplementedException("outer_grad_state"); | |||
| //outer_grad_state.grad_sync._add_control_input(final_zero.op); | |||
| ExitResult(new[] { final_zero }); | |||
| Exit(); | |||
| return next_count; | |||
| } | |||
| /// <summary> | |||
| /// Add `val` to the current context and its outer context recursively. | |||
| /// </summary> | |||
| /// <param name="val"></param> | |||
| /// <returns></returns> | |||
| public override Tensor AddValue(Tensor val) | |||
| { | |||
| var result = val; | |||
| var new_value = !_values.Contains(val.name); | |||
| new_value &= val.op._get_control_flow_context() != this; | |||
| if (new_value) | |||
| { | |||
| _values.Add(val.name); | |||
| // If we are in a grad context and val is from its forward context, | |||
| // use GetRealValue(), which adds the logic to save the history of | |||
| // val in forward. | |||
| var grad_ctxt = ops.get_default_graph()._get_control_flow_context(); | |||
| if(grad_ctxt != null) | |||
| { | |||
| grad_ctxt = grad_ctxt.GetWhileContext(); | |||
| if (grad_ctxt.grad_state != null) | |||
| { | |||
| var forward_ctxt = val.op.GetWhileContext(); | |||
| if (control_flow_util.IsLoopExit(val.op)) | |||
| { | |||
| forward_ctxt = forward_ctxt.outer_context as WhileContext; | |||
| if (forward_ctxt != null) | |||
| forward_ctxt = forward_ctxt.GetWhileContext(); | |||
| throw new NotImplementedException("control_flow_util.IsLoopExit"); | |||
| } | |||
| if(forward_ctxt == grad_ctxt.grad_state.forward_context) | |||
| { | |||
| var real_val = grad_ctxt.grad_state.GetRealValue(val); | |||
| _external_values[val.name] = real_val; | |||
| return real_val; | |||
| } | |||
| } | |||
| } | |||
| if (_outer_context != null) | |||
| result = _outer_context.AddValue(val); | |||
| // Create an Enter to make `result` known to this loop context. | |||
| Tensor enter = null; | |||
| tf_with(ops.control_dependencies(null), delegate | |||
| { | |||
| enter = control_flow_ops._Enter( | |||
| result, | |||
| _name, | |||
| is_constant: true, | |||
| parallel_iterations: _parallel_iterations); | |||
| enter.graph.prevent_feeding(enter); | |||
| if (_outer_context != null) | |||
| _outer_context.AddInnerOp(enter.op); | |||
| }); | |||
| // Fix the control inputs and control flow context of these enter ops. | |||
| _FixControlInputsAndContext(new[] { enter }); | |||
| // Add `enter` in this context. | |||
| _values.Add(enter.name); | |||
| _external_values[val.name] = enter; | |||
| result = enter; | |||
| } | |||
| else | |||
| { | |||
| var actual_val = _external_values.ContainsKey(val.name) ? _external_values[val.name] : null; | |||
| if (actual_val != null) | |||
| result = actual_val as Tensor; | |||
| } | |||
| return result; | |||
| } | |||
| public override bool IsWhileContext() | |||
| => true; | |||
| public override WhileContext GetWhileContext() | |||
| { | |||
| return this; | |||
| @@ -16,6 +16,7 @@ | |||
| using System; | |||
| using System.Linq; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Operations.Initializers | |||
| { | |||
| @@ -16,9 +16,9 @@ | |||
| namespace Tensorflow | |||
| { | |||
| public class LayerRNNCell : RNNCell | |||
| public class LayerRnnCell : RnnCell | |||
| { | |||
| public LayerRNNCell(bool? _reuse = null, | |||
| public LayerRnnCell(bool? _reuse = null, | |||
| string name = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: _reuse, | |||
| name: name, | |||
| @@ -0,0 +1,49 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Operations | |||
| { | |||
| internal class BodyItemInRnnWhileLoop : ICanBeFlattened, IPackable<BodyItemInRnnWhileLoop> | |||
| { | |||
| /// <summary> | |||
| /// int32 scalar Tensor. | |||
| /// </summary> | |||
| public Tensor time { get; set; } | |||
| /// <summary> | |||
| /// List of `TensorArray`s that represent the output. | |||
| /// </summary> | |||
| public TensorArray[] output_ta_t { get; set; } | |||
| /// <summary> | |||
| /// nested tuple of vector tensors that represent the state. | |||
| /// </summary> | |||
| public Tensor state { get; set; } | |||
| public BodyItemInRnnWhileLoop(Tensor time, TensorArray[] output_ta_t, Tensor state) | |||
| { | |||
| this.time = time; | |||
| this.output_ta_t = output_ta_t; | |||
| this.state = state; | |||
| } | |||
| public static implicit operator (Tensor, TensorArray[], Tensor)(BodyItemInRnnWhileLoop item) | |||
| => (item.time, item.output_ta_t, item.state); | |||
| public object[] Flatten() | |||
| { | |||
| var elements = new List<object> { time }; | |||
| elements.AddRange(output_ta_t); | |||
| elements.Add(state); | |||
| return elements.ToArray(); | |||
| } | |||
| public BodyItemInRnnWhileLoop Pack(object[] sequences) | |||
| { | |||
| time = sequences[0] as Tensor; | |||
| output_ta_t = new[] { sequences[1] as TensorArray }; | |||
| state = sequences[2] as Tensor; | |||
| return new BodyItemInRnnWhileLoop(time, output_ta_t, state); | |||
| } | |||
| } | |||
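| // Sketch: this item is the TItem carried through the RNN while loop; `time` | |||
| // advances by one per iteration while each TensorArray accumulates outputs. | |||
| // Flatten/Pack keep the flat layout [time, output_ta_t..., state] expected | |||
| // by LoopVar, e.g. (hypothetical tensors): | |||
| // var item = new BodyItemInRnnWhileLoop(time, new[] { output_ta }, state); | |||
| // object[] flat = item.Flatten(); // [time, output_ta, state] | |||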
| } | |||
| @@ -244,7 +244,27 @@ namespace Tensorflow.Operations | |||
| logits | |||
| }); | |||
| return _op.outputs[0]; | |||
| return _op.output; | |||
| } | |||
| /// <summary> | |||
| /// Says whether the targets are in the top `K` predictions. | |||
| /// </summary> | |||
| /// <param name="predictions"></param> | |||
| /// <param name="targets"></param> | |||
| /// <param name="k"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns>A `Tensor` of type `bool`.</returns> | |||
| public static Tensor in_top_kv2(Tensor predictions, Tensor targets, int k, string name = null) | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("InTopKV2", name: name, args: new | |||
| { | |||
| predictions, | |||
| targets, | |||
| k | |||
| }); | |||
| return _op.output; | |||
| } | |||
| public static Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null) | |||
| @@ -14,6 +14,7 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using NumSharp; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| @@ -24,12 +25,12 @@ namespace Tensorflow.Operations | |||
| { | |||
| internal class rnn | |||
| { | |||
| public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs_tensor, | |||
| public static (Tensor, Tensor) dynamic_rnn(RnnCell cell, Tensor inputs_tensor, | |||
| Tensor sequence_length = null, Tensor initial_state = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| int? parallel_iterations = null, bool swap_memory = false, bool time_major = false) | |||
| { | |||
| tf_with(tf.variable_scope("rnn"), scope => | |||
| return tf_with(tf.variable_scope("rnn"), scope => | |||
| { | |||
| VariableScope varscope = scope; | |||
| var flat_input = nest.flatten(inputs_tensor); | |||
| @@ -63,9 +64,12 @@ namespace Tensorflow.Operations | |||
| swap_memory: swap_memory, | |||
| sequence_length: sequence_length, | |||
| dtype: dtype); | |||
| }); | |||
| throw new NotImplementedException(""); | |||
| if (!time_major) | |||
| outputs = nest.map_structure(_transpose_batch_time, outputs); | |||
| return (outputs, final_state); | |||
| }); | |||
| } | |||
| /// <summary> | |||
| @@ -79,7 +83,7 @@ namespace Tensorflow.Operations | |||
| /// <param name="sequence_length"></param> | |||
| /// <param name="dtype"></param> | |||
| /// <returns></returns> | |||
| private static (Tensor, Tensor) _dynamic_rnn_loop(RNNCell cell, Tensor inputs, Tensor initial_state, | |||
| private static (Tensor, Tensor) _dynamic_rnn_loop(RnnCell cell, Tensor inputs, Tensor initial_state, | |||
| int parallel_iterations, bool swap_memory, Tensor sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid) | |||
| { | |||
| var state = initial_state; | |||
| @@ -145,7 +149,7 @@ namespace Tensorflow.Operations | |||
| { | |||
| var ta = new TensorArray(dtype: dtype_, | |||
| size: time_steps, | |||
| element_shape: new[] { element_shape }, | |||
| element_shape: element_shape, | |||
| tensor_array_name: base_name + name); | |||
| return ta; | |||
| }; | |||
| @@ -170,29 +174,86 @@ namespace Tensorflow.Operations | |||
| flat_input_i.dtype)); | |||
| } | |||
| for (int i = 0; i < input_ta.Count; i++) | |||
| { | |||
| var (ta, input_) = (input_ta[0], flat_input[0]); | |||
| } | |||
| input_ta = zip(input_ta, flat_input).Select(x => | |||
| { | |||
| var (ta, input_) = (x.Item1, x.Item2); | |||
| return ta.unstack(input_); | |||
| }).ToList(); | |||
| } | |||
| // Make sure that we run at least 1 step, if necessary, to ensure | |||
| // the TensorArrays pick up the dynamic shape. | |||
| Tensor loop_bound; | |||
| Tensor loop_bound = null; | |||
| if (in_graph_mode) | |||
| loop_bound = math_ops.minimum( | |||
| time_steps, math_ops.maximum(1, max_sequence_length)); | |||
| /*Func<Tensor, Tensor> cond = (ctime) => | |||
| Func<BodyItemInRnnWhileLoop, Tensor> cond = (item) => | |||
| { | |||
| return null; | |||
| return item.time < loop_bound; | |||
| }; | |||
| // Take a time step of the dynamic RNN. | |||
| Func<BodyItemInRnnWhileLoop, BodyItemInRnnWhileLoop> _time_step = (item) => | |||
| { | |||
| Tensor[] input_t = null; | |||
| var (time1, output_ta_t, state1) = (item.time, item.output_ta_t, item.state); | |||
| if (in_graph_mode) | |||
| { | |||
| input_t = input_ta.Select(ta => ta.read(time1)).ToArray(); | |||
| // Restore some shape information | |||
| foreach (var (input_, shape) in zip(input_t, inputs_got_shape)) | |||
| input_.set_shape(shape[new Slice(1)]); | |||
| } | |||
| else | |||
| { | |||
| // input_t = tuple(ta[time.numpy()] for ta in input_ta) | |||
| } | |||
| var input_t_t = nest.pack_sequence_as2(structure: inputs, flat_sequence: input_t); | |||
| // Keras RNN cells only accept state as list, even if it's a single tensor. | |||
| // var is_keras_rnn_cell = _is_keras_rnn_cell(cell); | |||
| Tensor[] outputs = null; | |||
| if (sequence_length != null) | |||
| throw new NotImplementedException("sequence_length != null"); | |||
| else | |||
| outputs = cell.__call__(input_t_t, state: state1); | |||
| var (output, new_state) = (outputs[0], outputs[1]); | |||
| // Keras cells always wrap state as list, even if it's a single tensor. | |||
| // if(is_keras_rnn_cell && len(new_state)) == 1 | |||
| // Pack state if using state tuples | |||
| outputs = nest.flatten2(output).Select(x => x as Tensor).ToArray(); | |||
| output_ta_t = zip(output_ta_t, outputs).Select(x => | |||
| { | |||
| var (ta, @out) = (x.Item1, x.Item2); | |||
| return ta.write(item.time, @out); | |||
| }).ToArray(); | |||
| return new BodyItemInRnnWhileLoop(item.time + 1, output_ta_t, new_state); | |||
| }; | |||
| control_flow_ops.while_loop( | |||
| var while_loop_result = control_flow_ops.while_loop( | |||
| cond: cond, | |||
| body = );*/ | |||
| body: _time_step, | |||
| loop_vars: new BodyItemInRnnWhileLoop(time, output_ta.ToArray(), state), | |||
| parallel_iterations: parallel_iterations, | |||
| maximum_iterations: time_steps, | |||
| swap_memory: swap_memory); | |||
| (_, TensorArray[] output_final_ta, Tensor final_state) = (while_loop_result.time, while_loop_result.output_ta_t, while_loop_result.state); | |||
| // Unpack final output if not using output tuples. | |||
| var final_outputs = output_final_ta.Select(ta => ta.stack()).ToArray(); | |||
| // Restore some shape information | |||
| foreach (var (output, output_size) in zip(final_outputs, flat_output_size)) | |||
| { | |||
| var shape = rnn_cell_impl._concat(new[] { const_time_steps, const_batch_size }, output_size, @static: true); | |||
| output.set_shape(shape); | |||
| } | |||
| throw new NotImplementedException(""); | |||
| return (final_outputs[0], final_state); | |||
| } | |||
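| // Usage sketch for the loop above via the public dynamic_rnn entry point; | |||
| // the cell size, input shape and dtype are illustrative assumptions: | |||
| // | |||
| // var cell = new BasicRnnCell(4); | |||
| // var X = tf.placeholder(tf.float32, shape: new TensorShape(-1, 10, 8)); | |||
| // var (outputs, state) = tf.nn.dynamic_rnn(cell, X, dtype: tf.float32); | |||
| // // outputs: [batch, time, 4], batch-major since time_major defaults to false | |||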
| private static TensorShape _maybe_tensor_shape_from_tensor(Tensor shape) | |||
| @@ -20,8 +20,8 @@ namespace Tensorflow.Operations | |||
| { | |||
| public class rnn_cell_impl | |||
| { | |||
| public BasicRNNCell BasicRNNCell(int num_units) | |||
| => new BasicRNNCell(num_units); | |||
| public BasicRnnCell BasicRNNCell(int num_units) | |||
| => new BasicRnnCell(num_units); | |||
| public static Tensor _concat(Tensor prefix, int suffix, bool @static = false) | |||
| { | |||
| @@ -53,5 +53,34 @@ namespace Tensorflow.Operations | |||
| return array_ops.concat(new[] { p, s }, 0); | |||
| } | |||
| } | |||
| public static TensorShape _concat(int[] prefix, int suffix, bool @static = false) | |||
| { | |||
| var p = new TensorShape(prefix); | |||
| var p_static = prefix; | |||
| var p_tensor = p.is_fully_defined() ? constant_op.constant(p.as_list(), dtype: dtypes.int32) : null; | |||
| var s_tensor_shape = new TensorShape(suffix); | |||
| var s_static = s_tensor_shape.ndim > -1 ? | |||
| s_tensor_shape.dims : | |||
| null; | |||
| var s_tensor = s_tensor_shape.is_fully_defined() ? | |||
| constant_op.constant(s_tensor_shape.dims, dtype: dtypes.int32) : | |||
| null; | |||
| if (@static) | |||
| { | |||
| if (p_static is null) return null; | |||
| var shape = new TensorShape(p_static).concatenate(s_static); | |||
| return shape; | |||
| } | |||
| else | |||
| { | |||
| if (p_tensor is null || s_tensor is null) | |||
| throw new ValueError($"Provided a prefix or suffix of None: {prefix} and {suffix}"); | |||
| // return array_ops.concat(new[] { p_tensor, s_tensor }, 0); | |||
| throw new NotImplementedException(""); | |||
| } | |||
| } | |||
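| // Example of the @static branch above (hypothetical sizes): | |||
| // var shape = _concat(new[] { 20, 32 }, 4, @static: true); | |||
| // // => TensorShape (20, 32, 4): the [time, batch] prefix plus the output size | |||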
| } | |||
| } | |||
| @@ -228,6 +228,15 @@ namespace Tensorflow | |||
| output_types.AddRange(types); | |||
| } | |||
| // We add an explicit colocation constraint between | |||
| // the newly created op and any of its reference-typed inputs. | |||
| var must_colocate_inputs = zip(op_def.InputArg, inputs) | |||
| .Where(x => x.Item1.IsRef) | |||
| .Select(x => x.Item2) | |||
| .ToArray(); | |||
| _MaybeColocateWith(must_colocate_inputs); | |||
| // Add Op to graph | |||
| var op = g.create_op(op_type_name, | |||
| inputs.ToArray(), | |||
| @@ -241,6 +250,11 @@ namespace Tensorflow | |||
| }); | |||
| } | |||
| private void _MaybeColocateWith(ITensorOrOperation[] inputs) | |||
| { | |||
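| // Intentionally a no-op for now: colocation constraints for | |||
| // reference-typed inputs are not yet ported. | |||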
| } | |||
| private void SetAttrs(string op_type_name, | |||
| ArgDef input_arg, | |||
| OpDef op_def, | |||
| @@ -15,6 +15,7 @@ | |||
| ******************************************************************************/ | |||
| using Tensorflow.Operations; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -30,11 +31,8 @@ namespace Tensorflow | |||
| /// </summary> | |||
| public void _control_flow_post_processing() | |||
| { | |||
| foreach(var input_tensor in inputs) | |||
| { | |||
| //TODO: implement below code dependency | |||
| //control_flow_util.CheckInputFromValidContext(this, input_tensor.op); | |||
| } | |||
| foreach(Tensor input_tensor in inputs) | |||
| control_flow_util.CheckInputFromValidContext(this, input_tensor.op); | |||
| if (_control_flow_context != null) | |||
| _control_flow_context.AddOp(this); | |||
| @@ -54,6 +52,10 @@ namespace Tensorflow | |||
| public void _set_control_flow_context(ControlFlowContext ctx) | |||
| { | |||
| if (name.Contains("gradients/rnn/while/basic_rnn_cell/Tanh_grad/TanhGrad/f_acc")) | |||
| { | |||
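| // intentionally empty: a convenient breakpoint target while debugging | |||
| // this specific gradient-accumulator op of the RNN while-loop | |||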
| } | |||
| _control_flow_context = ctx; | |||
| } | |||
| @@ -61,5 +63,10 @@ namespace Tensorflow | |||
| { | |||
| return _control_flow_context; | |||
| } | |||
| public WhileContext GetWhileContext() | |||
| { | |||
| return _control_flow_context as WhileContext; | |||
| } | |||
| } | |||
| } | |||
| @@ -14,10 +14,12 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using Newtonsoft.Json; | |||
| using System; | |||
| using System.Linq; | |||
| using System.Runtime.InteropServices; | |||
| #if SERIALIZABLE | |||
| using Newtonsoft.Json; | |||
| #endif | |||
| namespace Tensorflow | |||
| { | |||
| @@ -42,14 +44,14 @@ namespace Tensorflow | |||
| [JsonIgnore] | |||
| #endif | |||
| public int NumInputs => c_api.TF_OperationNumInputs(_handle); | |||
| private TF_DataType[] _input_types => _inputs._inputs.Select(x => x.dtype).ToArray(); | |||
| private TF_DataType[] _input_types => _inputs_val._inputs.Select(x => x.dtype).ToArray(); | |||
| private InputList _inputs; | |||
| private InputList _inputs_val; | |||
| public InputList inputs | |||
| { | |||
| get | |||
| { | |||
| if (_inputs == null) | |||
| if (_inputs_val == null) | |||
| { | |||
| var retval = new Tensor[NumInputs]; | |||
| @@ -60,10 +62,10 @@ namespace Tensorflow | |||
| retval[i] = op.outputs[tf_output.index]; | |||
| } | |||
| _inputs = new InputList(retval); | |||
| _inputs_val = new InputList(retval); | |||
| } | |||
| return _inputs; | |||
| return _inputs_val; | |||
| } | |||
| } | |||
| @@ -15,17 +15,14 @@ | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Linq; | |||
| using System.Collections.Generic; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class Operation | |||
| { | |||
| // cache the mapping between managed and unmanaged op | |||
| // some data is stored in the managed instance, so an | |||
| // Operation created from an IntPtr alone would lose that data. | |||
| private static Dictionary<IntPtr, Operation> OpInstances = new Dictionary<IntPtr, Operation>(); | |||
| /// <summary> | |||
| /// Get operation by handle | |||
| /// </summary> | |||
| @@ -33,9 +30,17 @@ namespace Tensorflow | |||
| /// <returns></returns> | |||
| public Operation GetOperation(IntPtr handle) | |||
| { | |||
| return OpInstances.ContainsKey(handle) ? | |||
| OpInstances[handle] : | |||
| new Operation(handle); | |||
| var nodes = tf.get_default_graph()._nodes_by_name; | |||
| foreach(var node in nodes.Values) | |||
| { | |||
| if (node is Operation op) | |||
| { | |||
| if (op == handle) | |||
| return op; | |||
| } | |||
| } | |||
| return null; | |||
| } | |||
| } | |||
| } | |||
| @@ -14,10 +14,12 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using Newtonsoft.Json; | |||
| using System; | |||
| using System.Linq; | |||
| using System.Runtime.InteropServices; | |||
| #if SERIALIZABLE | |||
| using Newtonsoft.Json; | |||
| #endif | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| @@ -21,8 +21,9 @@ using Newtonsoft.Json; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using System.Linq; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -65,7 +66,7 @@ namespace Tensorflow | |||
| #if SERIALIZABLE | |||
| [JsonIgnore] | |||
| #endif | |||
| public int _id_value; | |||
| public int _id_value { get; set; } | |||
| #if SERIALIZABLE | |||
| [JsonIgnore] | |||
| #endif | |||
| @@ -77,6 +78,7 @@ namespace Tensorflow | |||
| #if SERIALIZABLE | |||
| [JsonIgnore] | |||
| #endif | |||
| bool _is_stateful; | |||
| public NodeDef node_def | |||
| { | |||
| get | |||
| @@ -104,7 +106,6 @@ namespace Tensorflow | |||
| _control_flow_context = _graph._get_control_flow_context(); | |||
| // Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor. | |||
| OpInstances[_handle] = this; | |||
| } | |||
| /*public Operation(Graph g, string opType, string oper_name) | |||
| @@ -172,16 +173,19 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| _id_value = _graph._next_id(); | |||
| // Dict mapping op name to file and line information for op colocation | |||
| // context managers. | |||
| _control_flow_context = graph._get_control_flow_context(); | |||
| _control_flow_context = graph._get_control_flow_context(); | |||
| // This will be set by self.inputs. | |||
| if (op_def == null) | |||
| op_def = g.GetOpDef(node_def.Op); | |||
| var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | |||
| _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); | |||
| _is_stateful = op_def.IsStateful; | |||
| // Initialize self._outputs. | |||
| output_types = new TF_DataType[NumOutputs]; | |||
| @@ -196,8 +200,6 @@ namespace Tensorflow | |||
| if (_handle != IntPtr.Zero) | |||
| _control_flow_post_processing(); | |||
| OpInstances[_handle] = this; | |||
| } | |||
| public void run(FeedItem[] feed_dict = null, Session session = null) | |||
| @@ -304,7 +306,7 @@ namespace Tensorflow | |||
| var output = tensor._as_tf_output(); | |||
| // Reset cached inputs. | |||
| _inputs = null; | |||
| _inputs_val = null; | |||
| // after the c_api call, the next time _inputs is accessed | |||
| // the updated inputs are reloaded from the c_api | |||
| lock (Locks.ProcessWide) | |||
| @@ -42,7 +42,7 @@ namespace Tensorflow | |||
| /// matching structure of Tensors having shape `[batch_size].concatenate(s)` | |||
| /// for each `s` in `self.state_size`. | |||
| /// </summary> | |||
| public abstract class RNNCell : Layers.Layer | |||
| public abstract class RnnCell : Layers.Layer | |||
| { | |||
| /// <summary> | |||
| /// Attribute that indicates whether the cell is a TF RNN cell, due to the slight | |||
| @@ -53,7 +53,7 @@ namespace Tensorflow | |||
| public virtual int output_size { get; } | |||
| public RNNCell(bool trainable = true, | |||
| public RnnCell(bool trainable = true, | |||
| string name = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| bool? _reuse = null) : base(trainable: trainable, | |||
| @@ -22,9 +22,10 @@ using static Tensorflow.Binding; | |||
| namespace Tensorflow.Operations | |||
| { | |||
| internal class _GraphTensorArray | |||
| public class _GraphTensorArray | |||
| { | |||
| internal TF_DataType _dtype; | |||
| public TF_DataType dtype => _dtype; | |||
| /// <summary> | |||
| /// Used to keep track of what tensors the TensorArray should be | |||
| @@ -32,19 +33,22 @@ namespace Tensorflow.Operations | |||
| /// first tensor written to it. | |||
| /// </summary> | |||
| bool _colocate_with_first_write_call; | |||
| public bool colocate_with_first_write_call => _colocate_with_first_write_call; | |||
| bool _infer_shape; | |||
| bool _dynamic_size; | |||
| List<TensorShape> _element_shape; | |||
| public bool infer_shape => _infer_shape; | |||
| public bool _dynamic_size; | |||
| public List<TensorShape> _element_shape; | |||
| List<Tensor> _colocate_with; | |||
| public List<Tensor> _colocate_with; | |||
| internal Tensor _handle; | |||
| public Tensor handle => _handle; | |||
| internal Tensor _flow; | |||
| public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null, | |||
| bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | |||
| bool infer_shape = true, TensorShape[] element_shape = null, | |||
| bool infer_shape = true, TensorShape element_shape = null, | |||
| bool colocate_with_first_write_call = true, string name = null) | |||
| { | |||
| clear_after_read = clear_after_read ?? true; | |||
| @@ -68,7 +72,7 @@ namespace Tensorflow.Operations | |||
| else | |||
| { | |||
| _infer_shape = true; | |||
| _element_shape = new List<TensorShape> { }; | |||
| _element_shape = new List<TensorShape> { element_shape }; | |||
| } | |||
| tf_with(ops.name_scope(name, "TensorArray", new { handle, size, flow }), scope => | |||
| @@ -135,7 +139,7 @@ namespace Tensorflow.Operations | |||
| var ta = new TensorArray(_dtype, | |||
| infer_shape:_infer_shape, | |||
| element_shape: _element_shape.ToArray(), | |||
| element_shape: _element_shape[0], | |||
| dynamic_size: _dynamic_size, | |||
| handle: _handle, | |||
| flow: flow_out, | |||
| @@ -155,5 +159,72 @@ namespace Tensorflow.Operations | |||
| { | |||
| _colocate_with.Add(value); | |||
| } | |||
| public Tensor read(Tensor index, string name = null) | |||
| { | |||
| var value = gen_data_flow_ops.tensor_array_read_v3( | |||
| handle: _handle, | |||
| index: index, | |||
| flow_in: _flow, | |||
| dtype: _dtype, | |||
| name: name); | |||
| if (_element_shape != null) | |||
| value.set_shape(_element_shape[0].dims); | |||
| return value; | |||
| } | |||
| public TensorArray write(Tensor index, Tensor value, string name = null) | |||
| { | |||
| return tf_with(ops.name_scope(name, "TensorArrayWrite", new { _handle, index, value }), delegate | |||
| { | |||
| value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | |||
| _maybe_colocate_with(value); | |||
| var flow_out = gen_data_flow_ops.tensor_array_write_v3( | |||
| handle: _handle, | |||
| index: index, | |||
| value: value, | |||
| flow_in: _flow, | |||
| name: name); | |||
| return tensor_array_ops.build_ta_with_new_flow(this, flow_out); | |||
| }); | |||
| } | |||
| private Tensor size(string name = null) | |||
| { | |||
| return gen_data_flow_ops.tensor_array_size_v3(_handle, _flow, name: name); | |||
| } | |||
| public Tensor stack(string name = null) | |||
| { | |||
| ops.colocate_with(_handle); | |||
| return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate | |||
| { | |||
| return gather(math_ops.range(0, size()), name: name); | |||
| }); | |||
| } | |||
| public Tensor gather(Tensor indices, string name = null) | |||
| { | |||
| var element_shape = new TensorShape(); | |||
| if (_element_shape.Count > 0) | |||
| element_shape = _element_shape[0]; | |||
| var value = gen_data_flow_ops.tensor_array_gather_v3( | |||
| handle: _handle, | |||
| indices: indices, | |||
| flow_in: _flow, | |||
| dtype: _dtype, | |||
| name: name, | |||
| element_shape: element_shape); | |||
| //if (element_shape != null) | |||
| //value.set_shape(-1, element_shape.dims); | |||
| return value; | |||
| } | |||
| } | |||
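| // A minimal sketch of the read/write surface added above; size, dtype and | |||
| // tensors are illustrative, and TensorArray is assumed to delegate here: | |||
| // | |||
| // var ta = new TensorArray(TF_DataType.TF_FLOAT, size: constant_op.constant(3)); | |||
| // ta = ta.write(constant_op.constant(0), value); // new TA with updated flow | |||
| // var first = ta.read(constant_op.constant(0)); // reads element 0 back | |||
| // var stacked = ta.stack(); // gathers elements 0..size-1 into one tensor | |||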
| } | |||
| @@ -21,6 +21,7 @@ using Tensorflow.Operations; | |||
| using Tensorflow.Operations.ControlFlows; | |||
| using util = Tensorflow.control_flow_util; | |||
| using static Tensorflow.Binding; | |||
| using Tensorflow.Util; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -150,27 +151,50 @@ namespace Tensorflow | |||
| /// <param name="colocate_gradients_with_ops"></param> | |||
| public static ControlFlowState MaybeCreateControlFlowState(List<Operation> between_op_list, List<Operation> between_ops, bool colocate_gradients_with_ops) | |||
| { | |||
| var flag = new List<Operation>(); | |||
| ControlFlowState loop_state = null; | |||
| foreach (var op in between_op_list) | |||
| int pos = 0; | |||
| while(pos < between_op_list.Count) | |||
| { | |||
| var op = between_op_list[pos]; | |||
| if (IsLoopExit(op)) | |||
| { | |||
| if(loop_state == null) | |||
| if (loop_state == null) | |||
| { | |||
| loop_state = new ControlFlowState(); | |||
| } | |||
| if (colocate_gradients_with_ops) | |||
| ops.colocate_with(op); | |||
| loop_state.AddWhileContext(op, between_op_list, between_ops); | |||
| } | |||
| pos++; | |||
| } | |||
| return loop_state; | |||
| } | |||
| public static bool IsLoopExit(Operation op) | |||
| => op.OpType == "Exit" || op.OpType == "RefExit"; | |||
| public static bool IsLoopSwitch(Operation op) | |||
| { | |||
| if(IsSwitch(op)) | |||
| { | |||
| var ctxt = op._get_control_flow_context(); | |||
| return ctxt != null && ctxt.IsWhileContext() && !IsCondSwitch(op); | |||
| } | |||
| return false; | |||
| } | |||
| public static bool IsCondSwitch(Operation op) | |||
| { | |||
| return op.OpType == "Exit" || op.OpType == "RefExit"; | |||
| throw new NotImplementedException("IsCondSwitch"); | |||
| } | |||
| public static bool IsSwitch(Operation op) | |||
| => op.type == "Switch" || op.type == "RefSwitch"; | |||
| public static Tensor[] tuple(Tensor[] tensors, string name = null, Operation[] control_inputs = null) | |||
| { | |||
| return tf_with(ops.name_scope(name, "tuple", tensors), scope => | |||
| @@ -223,15 +247,10 @@ namespace Tensorflow | |||
| //TODO: missing original code | |||
| //if context.executing_eagerly(): | |||
| // return output_tensor | |||
| var values = new List<object>(); | |||
| values.AddRange(dependencies); | |||
| values.Add(output_tensor); | |||
| return tf_with(ops.name_scope(name, "control_dependency", values), scope => | |||
| return tf_with(ops.name_scope(name, "control_dependency", new { dependencies, output_tensor }), scope => | |||
| { | |||
| name = scope; | |||
| // TODO: missing original code | |||
| //with ops.colocate_with(output_tensor): | |||
| ops.colocate_with(output_tensor); | |||
| { | |||
| return tf_with(ops.control_dependencies(dependencies), ctl => | |||
| { | |||
| @@ -251,12 +270,16 @@ namespace Tensorflow | |||
| return gen_array_ops.identity(data, name: name); | |||
| } | |||
| public static void _SetShapeInvariants(Tensor[] input_vars, Tensor[] enter_vars, TensorShape shapes = null) | |||
| public static void _SetShapeInvariants(Tensor[] input_vars, Tensor[] enter_vars, TensorShape[] shapes = null) | |||
| { | |||
| if (shapes == null) | |||
| return; | |||
| throw new NotImplementedException("_SetShapeInvariants"); | |||
| var flat_shapes = nest.flatten2(shapes); | |||
| foreach (var (inp, enter_var, shape) in zip(input_vars, enter_vars, flat_shapes)) | |||
| { | |||
| enter_var.set_shape(shape); | |||
| } | |||
| } | |||
| /// <summary> | |||
| @@ -426,14 +449,15 @@ namespace Tensorflow | |||
| var merges = zip(res_f_flat, res_t_flat) | |||
| .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })) | |||
| .Select(m => (Tensor)m) | |||
| .ToArray(); | |||
| merges = _convert_flows_to_tensorarrays(new Tensor[] { (Tensor)orig_res_t }, merges); | |||
| var merges2 = _convert_flows_to_tensorarrays(new ITensorOrTensorArray[] { (Tensor)orig_res_t }, merges); | |||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); | |||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); | |||
| return merges[0]; | |||
| return new Tensor(IntPtr.Zero); | |||
| }); | |||
| } | |||
| @@ -473,22 +497,29 @@ namespace Tensorflow | |||
| var res_f_flat = res_f; | |||
| var merges = zip(res_f_flat, res_t_flat) | |||
| .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })) | |||
| .Select(pair => merge(new [] { pair.Item1, pair.Item2 })) | |||
| .Select(m => (Tensor)m) | |||
| .ToArray(); | |||
| merges = _convert_flows_to_tensorarrays(orig_res_t, merges); | |||
| var merges2 = _convert_flows_to_tensorarrays(orig_res_t.Select(x => (ITensorOrTensorArray)x).ToArray(), merges); | |||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); | |||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); | |||
| return merges; | |||
| return new[] { new Tensor(IntPtr.Zero) }; | |||
| }); | |||
| } | |||
| public static Tensor[] _convert_flows_to_tensorarrays<T>(T[] tensors_or_tensorarrays, Tensor[] tensors_or_flows) | |||
| public static ITensorOrTensorArray[] _convert_flows_to_tensorarrays(ITensorOrTensorArray[] tensors_or_tensorarrays, Tensor[] tensors_or_flows) | |||
| { | |||
| // zip(tensors_or_tensorarrays, tensors_or_flows).Select((ta, t_or_flow) => ta).ToArray(); | |||
| return tensors_or_flows; | |||
| return zip(tensors_or_tensorarrays, tensors_or_flows).Select(x => | |||
| { | |||
| var (ta, t_or_flow) = (x.Item1, x.Item2); | |||
| if (ta is TensorArray ta_1) | |||
| return tensor_array_ops.build_ta_with_new_flow(ta_1, t_or_flow) as ITensorOrTensorArray; | |||
| else | |||
| return t_or_flow as ITensorOrTensorArray; | |||
| }).ToArray(); | |||
| } | |||
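| // i.e. each (original, flow) pair is rebuilt only when the original was a | |||
| // TensorArray, via build_ta_with_new_flow; plain Tensors pass through as-is. | |||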
| /// <summary> | |||
| @@ -508,7 +539,7 @@ namespace Tensorflow | |||
| /// <param name="inputs">inputs: The input tensors, at most one of which is available.</param> | |||
| /// <param name="name">A name for this operation (optional).</param> | |||
| /// <returns></returns> | |||
| public static Tensor merge(Tensor[] inputs, string name = null) | |||
| public static MergeOutput merge(Tensor[] inputs, string name = null) | |||
| { | |||
| if (inputs.Any(x => x == null)) | |||
| throw new ValueError($"At least one of the merge inputs is null: {inputs}"); | |||
| @@ -518,7 +549,7 @@ namespace Tensorflow | |||
| inputs = inputs.Select(inp => | |||
| ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref: true)) | |||
| .ToArray(); | |||
| return gen_control_flow_ops.merge(inputs, name)[0]; | |||
| return gen_control_flow_ops.merge(inputs, name); | |||
| }); | |||
| } | |||
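| // MergeOutput is assumed to wrap the Merge op's two outputs, | |||
| // (output, value_index); callers that only need the data tensor cast it | |||
| // back with (Tensor)m, as the cond() branches do. | |||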
| @@ -591,18 +622,18 @@ namespace Tensorflow | |||
| /// <param name="body"></param> | |||
| /// <param name="loop_vars"></param> | |||
| /// <param name="i"></param> | |||
| public static Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars, | |||
| TensorShape shape_invariants = null, | |||
| public static TItem while_loop<TItem>(Func<TItem, Tensor> cond, Func<TItem, TItem> body, TItem loop_vars, | |||
| TensorShape[] shape_invariants = null, | |||
| int parallel_iterations = 10, | |||
| bool back_prop = true, | |||
| bool swap_memory = false, | |||
| string name = null, | |||
| int? maximum_iterations = null, | |||
| Tensor maximum_iterations = null, | |||
| bool return_same_structure = false) | |||
| { | |||
| tf_with(ops.name_scope(name, "while", loop_vars), scope => | |||
| return tf_with(ops.name_scope(name, "while", loop_vars), scope => | |||
| { | |||
| if (loop_vars == null || loop_vars.Length == 0) | |||
| if (loop_vars == null) | |||
| throw new ValueError("No loop variables provided"); | |||
| if (cond == null) | |||
| throw new ValueError("cond must be callable."); | |||
| @@ -611,6 +642,38 @@ namespace Tensorflow | |||
| if (parallel_iterations < 1) | |||
| throw new ValueError("parallel_iterations must be a positive integer."); | |||
| var try_to_pack = loop_vars is Tensor && !return_same_structure; | |||
| var counter = constant_op.constant(0, dtype: maximum_iterations.dtype, name: "iteration_counter"); | |||
| var orig_cond = cond; | |||
| var orig_body = body; | |||
| LoopVar<TItem> loop_vars_1 = null; | |||
| Func<LoopVar<TItem>, LoopVar<TItem>> body_buildloop = null; | |||
| Func<LoopVar<TItem>, Tensor> cond_buildloop = null; | |||
| if (try_to_pack) | |||
| { | |||
| } | |||
| else | |||
| { | |||
| loop_vars_1 = new LoopVar<TItem>(counter, loop_vars); | |||
| cond_buildloop = (item) => | |||
| { | |||
| var (i, lv) = (item.Counter, item.Item); | |||
| var oc = orig_cond(lv); | |||
| return math_ops.logical_and(i < maximum_iterations, oc); | |||
| }; | |||
| body_buildloop = (item) => | |||
| { | |||
| var (i, lv) = (item.Counter, item.Item); | |||
| var ob = orig_body(lv); | |||
| return new LoopVar<TItem>(i + 1, ob); | |||
| }; | |||
| } | |||
| try_to_pack = false; | |||
| var loop_context = new WhileContext( | |||
| maximum_iterations: maximum_iterations, | |||
| parallel_iterations: parallel_iterations, | |||
| @@ -620,17 +683,46 @@ namespace Tensorflow | |||
| if (loop_context.outer_context == null) | |||
| ops.add_to_collection(tf.GraphKeys.WHILE_CONTEXT, loop_context); | |||
| var results = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants, | |||
| var results = loop_context.BuildLoop(cond_buildloop, body_buildloop, loop_vars_1, shape_invariants, | |||
| return_same_structure); | |||
| if (maximum_iterations != null) | |||
| return results[1]; | |||
| else | |||
| return results[0]; | |||
| //if (maximum_iterations != null) | |||
| return results.Item; | |||
| //else | |||
| //return results; | |||
| }); | |||
| throw new NotImplementedException("while_loop"); | |||
| } | |||
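| // Sketch of the wrapping above: for user cond c and body b over TItem lv, | |||
| // the loop actually iterates over LoopVar(i, lv) with | |||
| // cond'(i, lv) = logical_and(i < maximum_iterations, c(lv)) | |||
| // body'(i, lv) = LoopVar(i + 1, b(lv)) | |||
| // so the counter bounds the trip count even if c(lv) never becomes false. | |||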
| /// <summary> | |||
| /// Creates or finds a child frame, and makes `data` available to it. | |||
| /// </summary> | |||
| /// <param name="data"></param> | |||
| /// <param name="frame_name"></param> | |||
| /// <param name="is_constant"></param> | |||
| /// <param name="parallel_iterations"></param> | |||
| /// <param name="use_ref"></param> | |||
| /// <param name="use_input_shape"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public static Tensor _Enter(Tensor data, string frame_name, | |||
| bool is_constant = false, | |||
| int parallel_iterations = 10, | |||
| bool use_ref = true, | |||
| bool use_input_shape = true, | |||
| string name = null) | |||
| { | |||
| Tensor result; | |||
| data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref: true); | |||
| if (data.dtype.is_ref_dtype() && use_ref) | |||
| throw new NotImplementedException("_Enter"); | |||
| else | |||
| result = gen_control_flow_ops.enter( | |||
| data, frame_name, is_constant, parallel_iterations, name: name); | |||
| if (use_input_shape) | |||
| result.set_shape(data.TensorShape); | |||
| return result; | |||
| } | |||
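| // Hypothetical call: lift a tensor into a loop frame, keeping its shape: | |||
| // var entered = _Enter(x, "while/frame", is_constant: true); | |||
| // // entered keeps x's static shape because use_input_shape defaults to true | |||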
| } | |||
| } | |||
| @@ -14,7 +14,10 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Linq; | |||
| using Tensorflow.Operations; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -28,6 +31,26 @@ namespace Tensorflow | |||
| public static bool IsLoopExit(Operation op) | |||
| { | |||
| return op.type == "Exit" || op.type == "RefExit"; | |||
| } | |||
| /// <summary> | |||
| /// Returns true if `op` is an Enter. | |||
| /// </summary> | |||
| /// <param name="op"></param> | |||
| /// <returns></returns> | |||
| public static bool IsLoopEnter(Operation op) | |||
| { | |||
| return op.type == "Enter" || op.type == "RefEnter"; | |||
| } | |||
| /// <summary> | |||
| /// Return true iff op is a loop invariant. | |||
| /// </summary> | |||
| /// <param name="op"></param> | |||
| /// <returns></returns> | |||
| public static bool IsLoopConstantEnter(Operation op) | |||
| { | |||
| return IsLoopEnter(op) && op.get_attr<bool>("is_constant"); | |||
| } | |||
| /// <summary> | |||
| @@ -38,6 +61,45 @@ namespace Tensorflow | |||
| public static bool IsSwitch(Operation op) | |||
| { | |||
| return op.type == "Switch" || op.type == "RefSwitch"; | |||
| } | |||
| public static WhileContext GetWhileContext(Operation op) | |||
| => op.GetWhileContext(); | |||
| public static bool IsCondSwitch(Operation op) | |||
| { | |||
| if (!IsSwitch(op)) | |||
| return false; | |||
| if (op.outputs == null || op.outputs.Length == 0) | |||
| return false; | |||
| // Switch nodes are not part of the cond control flow context that they | |||
| // represent, so consider the consumers of its outputs to determine if it is a | |||
| // cond switch or not. A switch is a cond switch iff all its consumers are in | |||
| // cond contexts. | |||
| var is_cond_switch = true; | |||
| foreach(var o in op.outputs) | |||
| { | |||
| foreach(var c in o.consumers()) | |||
| { | |||
| var ctxt = c._get_control_flow_context(); | |||
| if (IsLoopEnter(c)) | |||
| ctxt = ctxt.outer_context; | |||
| is_cond_switch = is_cond_switch && (ctxt != null && ctxt.IsCondContext()); | |||
| } | |||
| } | |||
| return is_cond_switch; | |||
| } | |||
| public static bool IsLoopSwitch(Operation op) | |||
| { | |||
| if (IsSwitch(op)) | |||
| { | |||
| var ctxt = op._get_control_flow_context(); | |||
| return ctxt != null && ctxt.IsWhileContext() && !IsCondSwitch(op); | |||
| } | |||
| return false; | |||
| } | |||
| /// <summary> | |||
| @@ -53,5 +115,76 @@ namespace Tensorflow | |||
| ctxt = ctxt.outer_context; | |||
| return ctxt; | |||
| } | |||
| public static void CheckInputFromValidContext(Operation op, Operation input_op) | |||
| { | |||
| var op_ctxt = op._get_control_flow_context(); | |||
| var input_ctxt = GetOutputContext(input_op); | |||
| var valid = false; | |||
| if (input_ctxt == null) | |||
| valid = true; | |||
| else if (op_ctxt == input_ctxt) | |||
| valid = true; | |||
| else | |||
| { | |||
| var while_ctxt = GetContainingWhileContext(op_ctxt); | |||
| var input_while_ctxt = GetContainingWhileContext(input_ctxt); | |||
| if (while_ctxt == null) | |||
| { | |||
| throw new NotImplementedException("CheckInputFromValidContext"); | |||
| } | |||
| else if (IsContainingContext(while_ctxt, input_while_ctxt)) | |||
| { | |||
| // input_op is in a while loop which contains op's while loop (or not in a | |||
| // while loop at all). | |||
| valid = true; | |||
| } | |||
| else if (while_ctxt.grad_state != null && | |||
| IsContainingContext(while_ctxt.grad_state.forward_context, | |||
| input_while_ctxt)) | |||
| { | |||
| valid = true; | |||
| } | |||
| else | |||
| throw new NotImplementedException("CheckInputFromValidContext"); | |||
| } | |||
| if (!valid) | |||
| { | |||
| throw new NotImplementedException("CheckInputFromValidContext"); | |||
| } | |||
| } | |||
| public static Operation GetLoopConstantEnter(Tensor value) | |||
| { | |||
| var id_ops = new string[] { "Switch", "RefSwitch", "Identity", "RefIdentity" }; | |||
| var op = value.op; | |||
| while (id_ops.Contains(op.type)) | |||
| op = op.inputs[0].op; | |||
| return IsLoopConstantEnter(op) ? op : null; | |||
| } | |||
| public static bool IsContainingContext(WhileContext ctxt, WhileContext maybe_containing_ctxt) | |||
| { | |||
| while(ctxt != maybe_containing_ctxt) | |||
| { | |||
| if (ctxt == null) | |||
| return false; | |||
| ctxt = ctxt.outer_context as WhileContext; | |||
| } | |||
| return true; | |||
| } | |||
| public static WhileContext GetContainingWhileContext(ControlFlowContext ctxt, ControlFlowContext stop_ctxt = null) | |||
| { | |||
| while (ctxt != null) | |||
| { | |||
| if (ctxt.IsWhileContext() || ctxt == stop_ctxt) | |||
| return ctxt as WhileContext; | |||
| ctxt = ctxt.outer_context; | |||
| } | |||
| return null; | |||
| } | |||
| } | |||
| } | |||
| @@ -183,7 +183,7 @@ namespace Tensorflow | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("Identity", name, new { input }); | |||
| return _op.outputs[0]; | |||
| return _op.output; | |||
| } | |||
| public static Tensor invert_permutation(Tensor x, string name = null) | |||
| @@ -14,12 +14,23 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using Tensorflow.Operations; | |||
| namespace Tensorflow | |||
| { | |||
| public class gen_control_flow_ops | |||
| { | |||
| public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | |||
| public static Operation control_trigger(string name = null) | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("ControlTrigger", name, new | |||
| { | |||
| }); | |||
| return _op; | |||
| } | |||
| /// <summary> | |||
| /// Creates or finds a child frame, and makes `data` available to the child frame. | |||
| /// </summary> | |||
| @@ -148,18 +159,18 @@ namespace Tensorflow | |||
| return new []{_op.outputs[0], _op.outputs[1]}; | |||
| } | |||
| public static Tensor[] ref_merge(Tensor[] inputs, string name = null) | |||
| public static MergeOutput ref_merge(Tensor[] inputs, string name = null) | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("RefMerge", name, new { inputs }); | |||
| return _op.outputs; | |||
| return new MergeOutput(_op.outputs); | |||
| } | |||
| public static Tensor[] merge(Tensor[] inputs, string name = null) | |||
| public static MergeOutput merge(Tensor[] inputs, string name = null) | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("Merge", name, new { inputs }); | |||
| return _op.outputs; | |||
| return new MergeOutput(_op.outputs); | |||
| } | |||
| } | |||
| } | |||
| @@ -28,12 +28,9 @@ namespace Tensorflow | |||
| } | |||
| public static (Tensor, Tensor) tensor_array_v3<T>(T size, TF_DataType dtype = TF_DataType.DtInvalid, | |||
| TensorShape[] element_shape = null, bool dynamic_size = false, bool clear_after_read = true, | |||
| bool identical_element_shapes = false, string tensor_array_name = "tensor_array_name", string name = null) | |||
| TensorShape element_shape = null, bool dynamic_size = false, bool clear_after_read = true, | |||
| bool identical_element_shapes = false, string tensor_array_name = "", string name = null) | |||
| { | |||
| if (tensor_array_name == null) | |||
| tensor_array_name = string.Empty; | |||
| var _op = _op_def_lib._apply_op_helper("TensorArrayV3", name, new | |||
| { | |||
| size, | |||
| @@ -201,5 +198,103 @@ namespace Tensorflow | |||
| return _op.outputs; | |||
| } | |||
| /// <summary> | |||
| /// Read an element from the TensorArray into output `value`. | |||
| /// </summary> | |||
| /// <param name="handle"></param> | |||
| /// <param name="index"></param> | |||
| /// <param name="flow_in"></param> | |||
| /// <param name="dtype"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public static Tensor tensor_array_read_v3(Tensor handle, Tensor index, Tensor flow_in, TF_DataType dtype, string name = null) | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("TensorArrayReadV3", name, new | |||
| { | |||
| handle, | |||
| index, | |||
| flow_in, | |||
| dtype | |||
| }); | |||
| return _op.output; | |||
| } | |||
| public static Tensor tensor_array_write_v3(Tensor handle, Tensor index, Tensor value, Tensor flow_in, string name = null) | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("TensorArrayWriteV3", name, new | |||
| { | |||
| handle, | |||
| index, | |||
| value, | |||
| flow_in | |||
| }); | |||
| return _op.output; | |||
| } | |||
| public static Tensor tensor_array_size_v3(Tensor handle, Tensor flow_in, string name = null) | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("TensorArraySizeV3", name, new | |||
| { | |||
| handle, | |||
| flow_in | |||
| }); | |||
| return _op.output; | |||
| } | |||
| public static Tensor tensor_array_gather_v3(Tensor handle, Tensor indices, Tensor flow_in, | |||
| TF_DataType dtype, TensorShape element_shape = null, string name = null) | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("TensorArrayGatherV3", name, new | |||
| { | |||
| handle, | |||
| indices, | |||
| dtype, | |||
| element_shape, | |||
| flow_in | |||
| }); | |||
| return _op.output; | |||
| } | |||
| public static Tensor stack_v2(Tensor max_size, TF_DataType elem_type, string stack_name = "", | |||
| string name = null) | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("StackV2", name, new | |||
| { | |||
| max_size, | |||
| elem_type, | |||
| stack_name | |||
| }); | |||
| return _op.output; | |||
| } | |||
| public static Tensor stack_push_v2(Tensor handle, Tensor elem, bool swap_memory = false, | |||
| string name = null) | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("StackPushV2", name, new | |||
| { | |||
| handle, | |||
| elem, | |||
| swap_memory | |||
| }); | |||
| return _op.output; | |||
| } | |||
| public static Tensor stack_pop_v2(Tensor handle, TF_DataType elem_type, string name = null) | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("StackPopV2", name, new | |||
| { | |||
| handle, | |||
| elem_type | |||
| }); | |||
| return _op.output; | |||
| } | |||
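| // A LIFO round trip over the three stack ops above (values illustrative): | |||
| // var h = stack_v2(ops.convert_to_tensor(10), TF_DataType.TF_FLOAT); | |||
| // var pushed = stack_push_v2(h, elem); // output mirrors the pushed value | |||
| // var popped = stack_pop_v2(h, TF_DataType.TF_FLOAT); // pops elem back | |||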
| } | |||
| } | |||
| @@ -14,6 +14,8 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| public static class gen_math_ops | |||
| @@ -280,7 +282,7 @@ namespace Tensorflow | |||
| /// <param name="dy"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public static Tensor tanh_grad(Tensor y, Tensor dy, string name = "TanhGrad") | |||
| public static Tensor tanh_grad(Tensor y, Tensor dy, string name = null) | |||
| => _op_def_lib._apply_op_helper("TanhGrad", name: name, args: new { y, dy }).output; | |||
| public static Tensor floor(Tensor x, string name = null) | |||
| @@ -566,7 +568,7 @@ namespace Tensorflow | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("MatMul", name, args: new { a, b, transpose_a, transpose_b }); | |||
| return _op.outputs[0]; | |||
| return _op.output; | |||
| } | |||
| /// <summary> | |||
| @@ -159,6 +159,8 @@ namespace Tensorflow | |||
| }); | |||
| } | |||
| public static Tensor greater_equal<Tx, Ty>(Tx x, Ty y, string name = null) | |||
| => gen_math_ops.greater_equal<Tx, Ty>(x, y, name: name); | |||
| public static Tensor equal<Tx, Ty>(Tx x, Ty y, string name = null) | |||
| => gen_math_ops.equal(x, y, name: name); | |||
| @@ -543,6 +545,23 @@ namespace Tensorflow | |||
| public static Tensor maximum<Tx, Ty>(Tx x, Ty y, string name = null) | |||
| => gen_math_ops.maximum(x, y, name: name); | |||
| /// <summary> | |||
| /// Multiplies matrix `a` by matrix `b`, producing `a` * `b`. | |||
| /// </summary> | |||
| /// <param name="a"></param> | |||
| /// <param name="b"></param> | |||
| /// <param name="transpose_a">If `True`, `a` is transposed before multiplication.</param> | |||
| /// <param name="transpose_b">If `True`, `b` is transposed before multiplication.</param> | |||
| /// <param name="adjoint_a">If `True`, `a` is conjugated and transposed before multiplication.</param> | |||
| /// <param name="adjoint_b">If `True`, `b` is conjugated and transposed before multiplication.</param> | |||
| /// <param name="a_is_sparse">If `True`, `a` is treated as a sparse matrix.</param> | |||
| /// <param name="b_is_sparse">If `True`, `b` is treated as a sparse matrix.</param> | |||
| /// <param name="name">Name for the operation (optional).</param> | |||
| /// <returns> | |||
| /// A `Tensor` of the same type as `a` and `b` where each inner-most matrix is | |||
| /// the product of the corresponding matrices in `a` and `b`, e.g. if all | |||
| /// transpose or adjoint attributes are `False`: | |||
| /// </returns> | |||
| public static Tensor matmul(Tensor a, Tensor b, | |||
| bool transpose_a = false, bool transpose_b = false, | |||
| bool adjoint_a = false, bool adjoint_b = false, | |||
| @@ -111,6 +111,14 @@ namespace Tensorflow | |||
| return noise_shape; | |||
| } | |||
| public static Tensor in_top_k(Tensor predictions, Tensor targets, int k, string name = null) | |||
| { | |||
| return tf_with(ops.name_scope(name, "in_top_k"), delegate | |||
| { | |||
| return gen_nn_ops.in_top_kv2(predictions, targets, k, name: name); | |||
| }); | |||
| } | |||
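| // Usage sketch (shapes illustrative): predictions [batch, classes] and | |||
| // int targets [batch] yield a bool [batch] hit mask: | |||
| // var hits = nn_ops.in_top_k(predictions, targets, k: 5); | |||
| // var acc = math_ops.reduce_mean(math_ops.cast(hits, TF_DataType.TF_FLOAT)); | |||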
| public static Tensor log_softmax(Tensor logits, int axis = -1, string name = null) | |||
| { | |||
| return _softmax(logits, gen_nn_ops.log_softmax, axis, name); | |||
| @@ -71,7 +71,7 @@ namespace Tensorflow | |||
| return tf_with(ops.name_scope(name, "random_uniform", new { shape, minval, maxval }), scope => | |||
| { | |||
| name = scope; | |||
| var tensorShape = _ShapeTensor(shape); | |||
| var tensorShape = tensor_util.shape_tensor(shape); | |||
| var minTensor = ops.convert_to_tensor(minval, dtype: dtype, name: "min"); | |||
| var maxTensor = ops.convert_to_tensor(maxval, dtype: dtype, name: "max"); | |||
| var rnd = gen_random_ops.random_uniform(tensorShape, dtype); | |||
| @@ -0,0 +1,52 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Operations; | |||
| namespace Tensorflow | |||
| { | |||
| public class tensor_array_ops | |||
| { | |||
| /// <summary> | |||
| /// Builds a TensorArray with a new `flow` tensor. | |||
| /// </summary> | |||
| /// <param name="old_ta"></param> | |||
| /// <param name="flow"></param> | |||
| /// <returns></returns> | |||
| public static TensorArray build_ta_with_new_flow(TensorArray old_ta, Tensor flow) | |||
| { | |||
| var impl = old_ta._implementation; | |||
| var new_ta = new TensorArray( | |||
| dtype: impl.dtype, | |||
| handle: impl.handle, | |||
| flow: flow, | |||
| infer_shape: impl.infer_shape, | |||
| colocate_with_first_write_call: impl.colocate_with_first_write_call); | |||
| var new_impl = new_ta._implementation; | |||
| new_impl._dynamic_size = impl._dynamic_size; | |||
| new_impl._colocate_with = impl._colocate_with; | |||
| new_impl._element_shape = impl._element_shape; | |||
| return new_ta; | |||
| } | |||
| public static TensorArray build_ta_with_new_flow(_GraphTensorArray old_ta, Tensor flow) | |||
| { | |||
| var impl = old_ta; | |||
| var new_ta = new TensorArray( | |||
| dtype: impl.dtype, | |||
| handle: impl.handle, | |||
| flow: flow, | |||
| infer_shape: impl.infer_shape, | |||
| colocate_with_first_write_call: impl.colocate_with_first_write_call); | |||
| var new_impl = new_ta._implementation; | |||
| new_impl._dynamic_size = impl._dynamic_size; | |||
| new_impl._colocate_with = impl._colocate_with; | |||
| new_impl._element_shape = impl._element_shape; | |||
| return new_ta; | |||
| } | |||
| } | |||
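| // Both overloads copy _dynamic_size, _colocate_with and _element_shape so the | |||
| // rebuilt TensorArray behaves like the source apart from its flow tensor, e.g. | |||
| // var ta2 = tensor_array_ops.build_ta_with_new_flow(ta, flow_out); | |||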
| } | |||
| @@ -27,10 +27,10 @@ namespace Tensorflow { | |||
| "CiV0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvY29uZmlnLnByb3RvEgp0ZW5z", | |||
| "b3JmbG93Gip0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL2Nvc3RfZ3JhcGgu", | |||
| "cHJvdG8aJXRlbnNvcmZsb3cvY29yZS9mcmFtZXdvcmsvZ3JhcGgucHJvdG8a", | |||
| "KnRlbnNvcmZsb3cvY29yZS9mcmFtZXdvcmsvc3RlcF9zdGF0cy5wcm90bxok", | |||
| "dGVuc29yZmxvdy9jb3JlL3Byb3RvYnVmL2RlYnVnLnByb3RvGiZ0ZW5zb3Jm", | |||
| "bG93L2NvcmUvcHJvdG9idWYvY2x1c3Rlci5wcm90bxoudGVuc29yZmxvdy9j", | |||
| "b3JlL3Byb3RvYnVmL3Jld3JpdGVyX2NvbmZpZy5wcm90byKtBAoKR1BVT3B0", | |||
| "KnRlbnNvcmZsb3cvY29yZS9mcmFtZXdvcmsvc3RlcF9zdGF0cy5wcm90bxom", | |||
| "dGVuc29yZmxvdy9jb3JlL3Byb3RvYnVmL2NsdXN0ZXIucHJvdG8aJHRlbnNv", | |||
| "cmZsb3cvY29yZS9wcm90b2J1Zi9kZWJ1Zy5wcm90bxoudGVuc29yZmxvdy9j", | |||
| "b3JlL3Byb3RvYnVmL3Jld3JpdGVyX2NvbmZpZy5wcm90byK3BQoKR1BVT3B0", | |||
| "aW9ucxInCh9wZXJfcHJvY2Vzc19ncHVfbWVtb3J5X2ZyYWN0aW9uGAEgASgB", | |||
| "EhQKDGFsbG93X2dyb3d0aBgEIAEoCBIWCg5hbGxvY2F0b3JfdHlwZRgCIAEo", | |||
| "CRIfChdkZWZlcnJlZF9kZWxldGlvbl9ieXRlcxgDIAEoAxIbChN2aXNpYmxl", | |||
| @@ -38,89 +38,102 @@ namespace Tensorflow { | |||
| "ZWNzGAYgASgFEiQKHHBvbGxpbmdfaW5hY3RpdmVfZGVsYXlfbXNlY3MYByAB", | |||
| "KAUSHAoUZm9yY2VfZ3B1X2NvbXBhdGlibGUYCCABKAgSOQoMZXhwZXJpbWVu", | |||
| "dGFsGAkgASgLMiMudGVuc29yZmxvdy5HUFVPcHRpb25zLkV4cGVyaW1lbnRh", | |||
| "bBrmAQoMRXhwZXJpbWVudGFsEksKD3ZpcnR1YWxfZGV2aWNlcxgBIAMoCzIy", | |||
| "bBrwAgoMRXhwZXJpbWVudGFsEksKD3ZpcnR1YWxfZGV2aWNlcxgBIAMoCzIy", | |||
| "LnRlbnNvcmZsb3cuR1BVT3B0aW9ucy5FeHBlcmltZW50YWwuVmlydHVhbERl", | |||
| "dmljZXMSGgoSdXNlX3VuaWZpZWRfbWVtb3J5GAIgASgIEiMKG251bV9kZXZf", | |||
| "dG9fZGV2X2NvcHlfc3RyZWFtcxgDIAEoBRIdChVjb2xsZWN0aXZlX3Jpbmdf", | |||
| "b3JkZXIYBCABKAkaKQoOVmlydHVhbERldmljZXMSFwoPbWVtb3J5X2xpbWl0", | |||
| "X21iGAEgAygCIoUDChBPcHRpbWl6ZXJPcHRpb25zEisKI2RvX2NvbW1vbl9z", | |||
| "dWJleHByZXNzaW9uX2VsaW1pbmF0aW9uGAEgASgIEhsKE2RvX2NvbnN0YW50", | |||
| "X2ZvbGRpbmcYAiABKAgSJAocbWF4X2ZvbGRlZF9jb25zdGFudF9pbl9ieXRl", | |||
| "cxgGIAEoAxIcChRkb19mdW5jdGlvbl9pbmxpbmluZxgEIAEoCBI1CglvcHRf", | |||
| "bGV2ZWwYAyABKA4yIi50ZW5zb3JmbG93Lk9wdGltaXplck9wdGlvbnMuTGV2", | |||
| "ZWwSRQoQZ2xvYmFsX2ppdF9sZXZlbBgFIAEoDjIrLnRlbnNvcmZsb3cuT3B0", | |||
| "aW1pemVyT3B0aW9ucy5HbG9iYWxKaXRMZXZlbCIgCgVMZXZlbBIGCgJMMRAA", | |||
| "Eg8KAkwwEP///////////wEiQwoOR2xvYmFsSml0TGV2ZWwSCwoHREVGQVVM", | |||
| "VBAAEhAKA09GRhD///////////8BEggKBE9OXzEQARIICgRPTl8yEAIi7gIK", | |||
| "DEdyYXBoT3B0aW9ucxIeChZlbmFibGVfcmVjdl9zY2hlZHVsaW5nGAIgASgI", | |||
| "EjcKEW9wdGltaXplcl9vcHRpb25zGAMgASgLMhwudGVuc29yZmxvdy5PcHRp", | |||
| "bWl6ZXJPcHRpb25zEhgKEGJ1aWxkX2Nvc3RfbW9kZWwYBCABKAMSHgoWYnVp", | |||
| "bGRfY29zdF9tb2RlbF9hZnRlchgJIAEoAxIUCgxpbmZlcl9zaGFwZXMYBSAB", | |||
| "KAgSGgoScGxhY2VfcHJ1bmVkX2dyYXBoGAYgASgIEiAKGGVuYWJsZV9iZmxv", | |||
| "YXQxNl9zZW5kcmVjdhgHIAEoCBIVCg10aW1lbGluZV9zdGVwGAggASgFEjMK", | |||
| "D3Jld3JpdGVfb3B0aW9ucxgKIAEoCzIaLnRlbnNvcmZsb3cuUmV3cml0ZXJD", | |||
| "b25maWdKBAgBEAJSJXNraXBfY29tbW9uX3N1YmV4cHJlc3Npb25fZWxpbWlu", | |||
| "YXRpb24iQQoVVGhyZWFkUG9vbE9wdGlvblByb3RvEhMKC251bV90aHJlYWRz", | |||
| "GAEgASgFEhMKC2dsb2JhbF9uYW1lGAIgASgJImwKClJQQ09wdGlvbnMSJAoc", | |||
| "dXNlX3JwY19mb3JfaW5wcm9jZXNzX21hc3RlchgBIAEoCBIdChVjb21wcmVz", | |||
| "c2lvbl9hbGdvcml0aG0YAiABKAkSGQoRY29tcHJlc3Npb25fbGV2ZWwYAyAB", | |||
| "KAUi3wYKC0NvbmZpZ1Byb3RvEj4KDGRldmljZV9jb3VudBgBIAMoCzIoLnRl", | |||
| "bnNvcmZsb3cuQ29uZmlnUHJvdG8uRGV2aWNlQ291bnRFbnRyeRIkChxpbnRy", | |||
| "YV9vcF9wYXJhbGxlbGlzbV90aHJlYWRzGAIgASgFEiQKHGludGVyX29wX3Bh", | |||
| "cmFsbGVsaXNtX3RocmVhZHMYBSABKAUSHwoXdXNlX3Blcl9zZXNzaW9uX3Ro", | |||
| "cmVhZHMYCSABKAgSRwocc2Vzc2lvbl9pbnRlcl9vcF90aHJlYWRfcG9vbBgM", | |||
| "IAMoCzIhLnRlbnNvcmZsb3cuVGhyZWFkUG9vbE9wdGlvblByb3RvEhgKEHBs", | |||
| "YWNlbWVudF9wZXJpb2QYAyABKAUSFgoOZGV2aWNlX2ZpbHRlcnMYBCADKAkS", | |||
| "KwoLZ3B1X29wdGlvbnMYBiABKAsyFi50ZW5zb3JmbG93LkdQVU9wdGlvbnMS", | |||
| "HAoUYWxsb3dfc29mdF9wbGFjZW1lbnQYByABKAgSHAoUbG9nX2RldmljZV9w", | |||
| "bGFjZW1lbnQYCCABKAgSLwoNZ3JhcGhfb3B0aW9ucxgKIAEoCzIYLnRlbnNv", | |||
| "cmZsb3cuR3JhcGhPcHRpb25zEh8KF29wZXJhdGlvbl90aW1lb3V0X2luX21z", | |||
| "GAsgASgDEisKC3JwY19vcHRpb25zGA0gASgLMhYudGVuc29yZmxvdy5SUENP", | |||
| "cHRpb25zEisKC2NsdXN0ZXJfZGVmGA4gASgLMhYudGVuc29yZmxvdy5DbHVz", | |||
| "dGVyRGVmEh0KFWlzb2xhdGVfc2Vzc2lvbl9zdGF0ZRgPIAEoCBI6CgxleHBl", | |||
| "cmltZW50YWwYECABKAsyJC50ZW5zb3JmbG93LkNvbmZpZ1Byb3RvLkV4cGVy", | |||
| "aW1lbnRhbBoyChBEZXZpY2VDb3VudEVudHJ5EgsKA2tleRgBIAEoCRINCgV2", | |||
| "YWx1ZRgCIAEoBToCOAEagwEKDEV4cGVyaW1lbnRhbBIfChdjb2xsZWN0aXZl", | |||
| "X2dyb3VwX2xlYWRlchgBIAEoCRIVCg1leGVjdXRvcl90eXBlGAMgASgJEhoK", | |||
| "EnJlY3ZfYnVmX21heF9jaHVuaxgEIAEoBRIZChF1c2VfbnVtYV9hZmZpbml0", | |||
| "eRgFIAEoCEoECAIQAyLYAwoKUnVuT3B0aW9ucxI2Cgt0cmFjZV9sZXZlbBgB", | |||
| "IAEoDjIhLnRlbnNvcmZsb3cuUnVuT3B0aW9ucy5UcmFjZUxldmVsEhUKDXRp", | |||
| "bWVvdXRfaW5fbXMYAiABKAMSHAoUaW50ZXJfb3BfdGhyZWFkX3Bvb2wYAyAB", | |||
| "KAUSHwoXb3V0cHV0X3BhcnRpdGlvbl9ncmFwaHMYBSABKAgSLwoNZGVidWdf", | |||
| "b3B0aW9ucxgGIAEoCzIYLnRlbnNvcmZsb3cuRGVidWdPcHRpb25zEioKInJl", | |||
| "cG9ydF90ZW5zb3JfYWxsb2NhdGlvbnNfdXBvbl9vb20YByABKAgSOQoMZXhw", | |||
| "ZXJpbWVudGFsGAggASgLMiMudGVuc29yZmxvdy5SdW5PcHRpb25zLkV4cGVy", | |||
| "aW1lbnRhbBpKCgxFeHBlcmltZW50YWwSHAoUY29sbGVjdGl2ZV9ncmFwaF9r", | |||
| "ZXkYASABKAMSHAoUdXNlX3J1bl9oYW5kbGVyX3Bvb2wYAiABKAgiUgoKVHJh", | |||
| "Y2VMZXZlbBIMCghOT19UUkFDRRAAEhIKDlNPRlRXQVJFX1RSQUNFEAESEgoO", | |||
| "SEFSRFdBUkVfVFJBQ0UQAhIOCgpGVUxMX1RSQUNFEANKBAgEEAUilgEKC1J1", | |||
| "bk1ldGFkYXRhEikKCnN0ZXBfc3RhdHMYASABKAsyFS50ZW5zb3JmbG93LlN0", | |||
| "ZXBTdGF0cxIsCgpjb3N0X2dyYXBoGAIgASgLMhgudGVuc29yZmxvdy5Db3N0", | |||
| "R3JhcGhEZWYSLgoQcGFydGl0aW9uX2dyYXBocxgDIAMoCzIULnRlbnNvcmZs", | |||
| "b3cuR3JhcGhEZWYiOgoQVGVuc29yQ29ubmVjdGlvbhITCgtmcm9tX3RlbnNv", | |||
| "chgBIAEoCRIRCgl0b190ZW5zb3IYAiABKAkisAMKD0NhbGxhYmxlT3B0aW9u", | |||
| "cxIMCgRmZWVkGAEgAygJEg0KBWZldGNoGAIgAygJEg4KBnRhcmdldBgDIAMo", | |||
| "CRIrCgtydW5fb3B0aW9ucxgEIAEoCzIWLnRlbnNvcmZsb3cuUnVuT3B0aW9u", | |||
| "cxI3ChF0ZW5zb3JfY29ubmVjdGlvbhgFIAMoCzIcLnRlbnNvcmZsb3cuVGVu", | |||
| "c29yQ29ubmVjdGlvbhJCCgxmZWVkX2RldmljZXMYBiADKAsyLC50ZW5zb3Jm", | |||
| "bG93LkNhbGxhYmxlT3B0aW9ucy5GZWVkRGV2aWNlc0VudHJ5EkQKDWZldGNo", | |||
| "X2RldmljZXMYByADKAsyLS50ZW5zb3JmbG93LkNhbGxhYmxlT3B0aW9ucy5G", | |||
| "ZXRjaERldmljZXNFbnRyeRIXCg9mZXRjaF9za2lwX3N5bmMYCCABKAgaMgoQ", | |||
| "RmVlZERldmljZXNFbnRyeRILCgNrZXkYASABKAkSDQoFdmFsdWUYAiABKAk6", | |||
| "AjgBGjMKEUZldGNoRGV2aWNlc0VudHJ5EgsKA2tleRgBIAEoCRINCgV2YWx1", | |||
| "ZRgCIAEoCToCOAFCLQoYb3JnLnRlbnNvcmZsb3cuZnJhbWV3b3JrQgxDb25m", | |||
| "aWdQcm90b3NQAfgBAWIGcHJvdG8z")); | |||
| "b3JkZXIYBCABKAkSHQoVdGltZXN0YW1wZWRfYWxsb2NhdG9yGAUgASgIEiMK", | |||
| "G2tlcm5lbF90cmFja2VyX21heF9pbnRlcnZhbBgHIAEoBRIgChhrZXJuZWxf", | |||
| "dHJhY2tlcl9tYXhfYnl0ZXMYCCABKAUSIgoaa2VybmVsX3RyYWNrZXJfbWF4", | |||
| "X3BlbmRpbmcYCSABKAUaKQoOVmlydHVhbERldmljZXMSFwoPbWVtb3J5X2xp", | |||
| "bWl0X21iGAEgAygCIoUDChBPcHRpbWl6ZXJPcHRpb25zEisKI2RvX2NvbW1v", | |||
| "bl9zdWJleHByZXNzaW9uX2VsaW1pbmF0aW9uGAEgASgIEhsKE2RvX2NvbnN0", | |||
| "YW50X2ZvbGRpbmcYAiABKAgSJAocbWF4X2ZvbGRlZF9jb25zdGFudF9pbl9i", | |||
| "eXRlcxgGIAEoAxIcChRkb19mdW5jdGlvbl9pbmxpbmluZxgEIAEoCBI1Cglv", | |||
| "cHRfbGV2ZWwYAyABKA4yIi50ZW5zb3JmbG93Lk9wdGltaXplck9wdGlvbnMu", | |||
| "TGV2ZWwSRQoQZ2xvYmFsX2ppdF9sZXZlbBgFIAEoDjIrLnRlbnNvcmZsb3cu", | |||
| "T3B0aW1pemVyT3B0aW9ucy5HbG9iYWxKaXRMZXZlbCIgCgVMZXZlbBIGCgJM", | |||
| "MRAAEg8KAkwwEP///////////wEiQwoOR2xvYmFsSml0TGV2ZWwSCwoHREVG", | |||
| "QVVMVBAAEhAKA09GRhD///////////8BEggKBE9OXzEQARIICgRPTl8yEAIi", | |||
| "7gIKDEdyYXBoT3B0aW9ucxIeChZlbmFibGVfcmVjdl9zY2hlZHVsaW5nGAIg", | |||
| "ASgIEjcKEW9wdGltaXplcl9vcHRpb25zGAMgASgLMhwudGVuc29yZmxvdy5P", | |||
| "cHRpbWl6ZXJPcHRpb25zEhgKEGJ1aWxkX2Nvc3RfbW9kZWwYBCABKAMSHgoW", | |||
| "YnVpbGRfY29zdF9tb2RlbF9hZnRlchgJIAEoAxIUCgxpbmZlcl9zaGFwZXMY", | |||
| "BSABKAgSGgoScGxhY2VfcHJ1bmVkX2dyYXBoGAYgASgIEiAKGGVuYWJsZV9i", | |||
| "ZmxvYXQxNl9zZW5kcmVjdhgHIAEoCBIVCg10aW1lbGluZV9zdGVwGAggASgF", | |||
| "EjMKD3Jld3JpdGVfb3B0aW9ucxgKIAEoCzIaLnRlbnNvcmZsb3cuUmV3cml0", | |||
| "ZXJDb25maWdKBAgBEAJSJXNraXBfY29tbW9uX3N1YmV4cHJlc3Npb25fZWxp", | |||
| "bWluYXRpb24iQQoVVGhyZWFkUG9vbE9wdGlvblByb3RvEhMKC251bV90aHJl", | |||
| "YWRzGAEgASgFEhMKC2dsb2JhbF9uYW1lGAIgASgJImwKClJQQ09wdGlvbnMS", | |||
| "JAocdXNlX3JwY19mb3JfaW5wcm9jZXNzX21hc3RlchgBIAEoCBIdChVjb21w", | |||
| "cmVzc2lvbl9hbGdvcml0aG0YAiABKAkSGQoRY29tcHJlc3Npb25fbGV2ZWwY", | |||
| "AyABKAUisggKC0NvbmZpZ1Byb3RvEj4KDGRldmljZV9jb3VudBgBIAMoCzIo", | |||
| "LnRlbnNvcmZsb3cuQ29uZmlnUHJvdG8uRGV2aWNlQ291bnRFbnRyeRIkChxp", | |||
| "bnRyYV9vcF9wYXJhbGxlbGlzbV90aHJlYWRzGAIgASgFEiQKHGludGVyX29w", | |||
| "X3BhcmFsbGVsaXNtX3RocmVhZHMYBSABKAUSHwoXdXNlX3Blcl9zZXNzaW9u", | |||
| "X3RocmVhZHMYCSABKAgSRwocc2Vzc2lvbl9pbnRlcl9vcF90aHJlYWRfcG9v", | |||
| "bBgMIAMoCzIhLnRlbnNvcmZsb3cuVGhyZWFkUG9vbE9wdGlvblByb3RvEhgK", | |||
| "EHBsYWNlbWVudF9wZXJpb2QYAyABKAUSFgoOZGV2aWNlX2ZpbHRlcnMYBCAD", | |||
| "KAkSKwoLZ3B1X29wdGlvbnMYBiABKAsyFi50ZW5zb3JmbG93LkdQVU9wdGlv", | |||
| "bnMSHAoUYWxsb3dfc29mdF9wbGFjZW1lbnQYByABKAgSHAoUbG9nX2Rldmlj", | |||
| "ZV9wbGFjZW1lbnQYCCABKAgSLwoNZ3JhcGhfb3B0aW9ucxgKIAEoCzIYLnRl", | |||
| "bnNvcmZsb3cuR3JhcGhPcHRpb25zEh8KF29wZXJhdGlvbl90aW1lb3V0X2lu", | |||
| "X21zGAsgASgDEisKC3JwY19vcHRpb25zGA0gASgLMhYudGVuc29yZmxvdy5S", | |||
| "UENPcHRpb25zEisKC2NsdXN0ZXJfZGVmGA4gASgLMhYudGVuc29yZmxvdy5D", | |||
| "bHVzdGVyRGVmEh0KFWlzb2xhdGVfc2Vzc2lvbl9zdGF0ZRgPIAEoCBI6Cgxl", | |||
| "eHBlcmltZW50YWwYECABKAsyJC50ZW5zb3JmbG93LkNvbmZpZ1Byb3RvLkV4", | |||
| "cGVyaW1lbnRhbBoyChBEZXZpY2VDb3VudEVudHJ5EgsKA2tleRgBIAEoCRIN", | |||
| "CgV2YWx1ZRgCIAEoBToCOAEa1gIKDEV4cGVyaW1lbnRhbBIfChdjb2xsZWN0", | |||
| "aXZlX2dyb3VwX2xlYWRlchgBIAEoCRIVCg1leGVjdXRvcl90eXBlGAMgASgJ", | |||
| "EhoKEnJlY3ZfYnVmX21heF9jaHVuaxgEIAEoBRIZChF1c2VfbnVtYV9hZmZp", | |||
| "bml0eRgFIAEoCBI1Ci1jb2xsZWN0aXZlX2RldGVybWluaXN0aWNfc2VxdWVu", | |||
| "dGlhbF9leGVjdXRpb24YBiABKAgSFwoPY29sbGVjdGl2ZV9uY2NsGAcgASgI", | |||
| "EjYKLnNoYXJlX3Nlc3Npb25fc3RhdGVfaW5fY2x1c3RlcnNwZWNfcHJvcGFn", | |||
| "YXRpb24YCCABKAgSHwoXZGlzYWJsZV90aHJlYWRfc3Bpbm5pbmcYCSABKAgS", | |||
| "KAogc2hhcmVfY2x1c3Rlcl9kZXZpY2VzX2luX3Nlc3Npb24YCiABKAhKBAgC", | |||
| "EAMi2AMKClJ1bk9wdGlvbnMSNgoLdHJhY2VfbGV2ZWwYASABKA4yIS50ZW5z", | |||
| "b3JmbG93LlJ1bk9wdGlvbnMuVHJhY2VMZXZlbBIVCg10aW1lb3V0X2luX21z", | |||
| "GAIgASgDEhwKFGludGVyX29wX3RocmVhZF9wb29sGAMgASgFEh8KF291dHB1", | |||
| "dF9wYXJ0aXRpb25fZ3JhcGhzGAUgASgIEi8KDWRlYnVnX29wdGlvbnMYBiAB", | |||
| "KAsyGC50ZW5zb3JmbG93LkRlYnVnT3B0aW9ucxIqCiJyZXBvcnRfdGVuc29y", | |||
| "X2FsbG9jYXRpb25zX3Vwb25fb29tGAcgASgIEjkKDGV4cGVyaW1lbnRhbBgI", | |||
| "IAEoCzIjLnRlbnNvcmZsb3cuUnVuT3B0aW9ucy5FeHBlcmltZW50YWwaSgoM", | |||
| "RXhwZXJpbWVudGFsEhwKFGNvbGxlY3RpdmVfZ3JhcGhfa2V5GAEgASgDEhwK", | |||
| "FHVzZV9ydW5faGFuZGxlcl9wb29sGAIgASgIIlIKClRyYWNlTGV2ZWwSDAoI", | |||
| "Tk9fVFJBQ0UQABISCg5TT0ZUV0FSRV9UUkFDRRABEhIKDkhBUkRXQVJFX1RS", | |||
| "QUNFEAISDgoKRlVMTF9UUkFDRRADSgQIBBAFIocDCgtSdW5NZXRhZGF0YRIp", | |||
| "CgpzdGVwX3N0YXRzGAEgASgLMhUudGVuc29yZmxvdy5TdGVwU3RhdHMSLAoK", | |||
| "Y29zdF9ncmFwaBgCIAEoCzIYLnRlbnNvcmZsb3cuQ29zdEdyYXBoRGVmEi4K", | |||
| "EHBhcnRpdGlvbl9ncmFwaHMYAyADKAsyFC50ZW5zb3JmbG93LkdyYXBoRGVm", | |||
| "Ej8KD2Z1bmN0aW9uX2dyYXBocxgEIAMoCzImLnRlbnNvcmZsb3cuUnVuTWV0", | |||
| "YWRhdGEuRnVuY3Rpb25HcmFwaHMarQEKDkZ1bmN0aW9uR3JhcGhzEi4KEHBh", | |||
| "cnRpdGlvbl9ncmFwaHMYASADKAsyFC50ZW5zb3JmbG93LkdyYXBoRGVmEjQK", | |||
| "FnByZV9vcHRpbWl6YXRpb25fZ3JhcGgYAiABKAsyFC50ZW5zb3JmbG93Lkdy", | |||
| "YXBoRGVmEjUKF3Bvc3Rfb3B0aW1pemF0aW9uX2dyYXBoGAMgASgLMhQudGVu", | |||
| "c29yZmxvdy5HcmFwaERlZiI6ChBUZW5zb3JDb25uZWN0aW9uEhMKC2Zyb21f", | |||
| "dGVuc29yGAEgASgJEhEKCXRvX3RlbnNvchgCIAEoCSKwAwoPQ2FsbGFibGVP", | |||
| "cHRpb25zEgwKBGZlZWQYASADKAkSDQoFZmV0Y2gYAiADKAkSDgoGdGFyZ2V0", | |||
| "GAMgAygJEisKC3J1bl9vcHRpb25zGAQgASgLMhYudGVuc29yZmxvdy5SdW5P", | |||
| "cHRpb25zEjcKEXRlbnNvcl9jb25uZWN0aW9uGAUgAygLMhwudGVuc29yZmxv", | |||
| "dy5UZW5zb3JDb25uZWN0aW9uEkIKDGZlZWRfZGV2aWNlcxgGIAMoCzIsLnRl", | |||
| "bnNvcmZsb3cuQ2FsbGFibGVPcHRpb25zLkZlZWREZXZpY2VzRW50cnkSRAoN", | |||
| "ZmV0Y2hfZGV2aWNlcxgHIAMoCzItLnRlbnNvcmZsb3cuQ2FsbGFibGVPcHRp", | |||
| "b25zLkZldGNoRGV2aWNlc0VudHJ5EhcKD2ZldGNoX3NraXBfc3luYxgIIAEo", | |||
| "CBoyChBGZWVkRGV2aWNlc0VudHJ5EgsKA2tleRgBIAEoCRINCgV2YWx1ZRgC", | |||
| "IAEoCToCOAEaMwoRRmV0Y2hEZXZpY2VzRW50cnkSCwoDa2V5GAEgASgJEg0K", | |||
| "BXZhbHVlGAIgASgJOgI4AUItChhvcmcudGVuc29yZmxvdy5mcmFtZXdvcmtC", | |||
| "DENvbmZpZ1Byb3Rvc1AB+AEBYgZwcm90bzM=")); | |||
| descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, | |||
| new pbr::FileDescriptor[] { global::Tensorflow.CostGraphReflection.Descriptor, global::Tensorflow.GraphReflection.Descriptor, global::Tensorflow.StepStatsReflection.Descriptor, global::Tensorflow.DebugReflection.Descriptor, global::Tensorflow.ClusterReflection.Descriptor, global::Tensorflow.RewriterConfigReflection.Descriptor, }, | |||
| new pbr::FileDescriptor[] { global::Tensorflow.CostGraphReflection.Descriptor, global::Tensorflow.GraphReflection.Descriptor, global::Tensorflow.StepStatsReflection.Descriptor, global::Tensorflow.ClusterReflection.Descriptor, global::Tensorflow.DebugReflection.Descriptor, global::Tensorflow.RewriterConfigReflection.Descriptor, }, | |||
| new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GPUOptions), global::Tensorflow.GPUOptions.Parser, new[]{ "PerProcessGpuMemoryFraction", "AllowGrowth", "AllocatorType", "DeferredDeletionBytes", "VisibleDeviceList", "PollingActiveDelayUsecs", "PollingInactiveDelayMsecs", "ForceGpuCompatible", "Experimental" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GPUOptions.Types.Experimental), global::Tensorflow.GPUOptions.Types.Experimental.Parser, new[]{ "VirtualDevices", "UseUnifiedMemory", "NumDevToDevCopyStreams", "CollectiveRingOrder" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GPUOptions.Types.Experimental.Types.VirtualDevices), global::Tensorflow.GPUOptions.Types.Experimental.Types.VirtualDevices.Parser, new[]{ "MemoryLimitMb" }, null, null, null)})}), | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GPUOptions), global::Tensorflow.GPUOptions.Parser, new[]{ "PerProcessGpuMemoryFraction", "AllowGrowth", "AllocatorType", "DeferredDeletionBytes", "VisibleDeviceList", "PollingActiveDelayUsecs", "PollingInactiveDelayMsecs", "ForceGpuCompatible", "Experimental" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GPUOptions.Types.Experimental), global::Tensorflow.GPUOptions.Types.Experimental.Parser, new[]{ "VirtualDevices", "UseUnifiedMemory", "NumDevToDevCopyStreams", "CollectiveRingOrder", "TimestampedAllocator", "KernelTrackerMaxInterval", "KernelTrackerMaxBytes", "KernelTrackerMaxPending" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GPUOptions.Types.Experimental.Types.VirtualDevices), global::Tensorflow.GPUOptions.Types.Experimental.Types.VirtualDevices.Parser, new[]{ "MemoryLimitMb" }, null, null, null)})}), | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.OptimizerOptions), global::Tensorflow.OptimizerOptions.Parser, new[]{ "DoCommonSubexpressionElimination", "DoConstantFolding", "MaxFoldedConstantInBytes", "DoFunctionInlining", "OptLevel", "GlobalJitLevel" }, null, new[]{ typeof(global::Tensorflow.OptimizerOptions.Types.Level), typeof(global::Tensorflow.OptimizerOptions.Types.GlobalJitLevel) }, null), | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GraphOptions), global::Tensorflow.GraphOptions.Parser, new[]{ "EnableRecvScheduling", "OptimizerOptions", "BuildCostModel", "BuildCostModelAfter", "InferShapes", "PlacePrunedGraph", "EnableBfloat16Sendrecv", "TimelineStep", "RewriteOptions" }, null, null, null), | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ThreadPoolOptionProto), global::Tensorflow.ThreadPoolOptionProto.Parser, new[]{ "NumThreads", "GlobalName" }, null, null, null), | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RPCOptions), global::Tensorflow.RPCOptions.Parser, new[]{ "UseRpcForInprocessMaster", "CompressionAlgorithm", "CompressionLevel" }, null, null, null), | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ConfigProto), global::Tensorflow.ConfigProto.Parser, new[]{ "DeviceCount", "IntraOpParallelismThreads", "InterOpParallelismThreads", "UsePerSessionThreads", "SessionInterOpThreadPool", "PlacementPeriod", "DeviceFilters", "GpuOptions", "AllowSoftPlacement", "LogDevicePlacement", "GraphOptions", "OperationTimeoutInMs", "RpcOptions", "ClusterDef", "IsolateSessionState", "Experimental" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ConfigProto.Types.Experimental), global::Tensorflow.ConfigProto.Types.Experimental.Parser, new[]{ "CollectiveGroupLeader", "ExecutorType", "RecvBufMaxChunk", "UseNumaAffinity" }, null, null, null)}), | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ConfigProto), global::Tensorflow.ConfigProto.Parser, new[]{ "DeviceCount", "IntraOpParallelismThreads", "InterOpParallelismThreads", "UsePerSessionThreads", "SessionInterOpThreadPool", "PlacementPeriod", "DeviceFilters", "GpuOptions", "AllowSoftPlacement", "LogDevicePlacement", "GraphOptions", "OperationTimeoutInMs", "RpcOptions", "ClusterDef", "IsolateSessionState", "Experimental" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ConfigProto.Types.Experimental), global::Tensorflow.ConfigProto.Types.Experimental.Parser, new[]{ "CollectiveGroupLeader", "ExecutorType", "RecvBufMaxChunk", "UseNumaAffinity", "CollectiveDeterministicSequentialExecution", "CollectiveNccl", "ShareSessionStateInClusterspecPropagation", "DisableThreadSpinning", "ShareClusterDevicesInSession" }, null, null, null)}), | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RunOptions), global::Tensorflow.RunOptions.Parser, new[]{ "TraceLevel", "TimeoutInMs", "InterOpThreadPool", "OutputPartitionGraphs", "DebugOptions", "ReportTensorAllocationsUponOom", "Experimental" }, null, new[]{ typeof(global::Tensorflow.RunOptions.Types.TraceLevel) }, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RunOptions.Types.Experimental), global::Tensorflow.RunOptions.Types.Experimental.Parser, new[]{ "CollectiveGraphKey", "UseRunHandlerPool" }, null, null, null)}), | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RunMetadata), global::Tensorflow.RunMetadata.Parser, new[]{ "StepStats", "CostGraph", "PartitionGraphs" }, null, null, null), | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RunMetadata), global::Tensorflow.RunMetadata.Parser, new[]{ "StepStats", "CostGraph", "PartitionGraphs", "FunctionGraphs" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RunMetadata.Types.FunctionGraphs), global::Tensorflow.RunMetadata.Types.FunctionGraphs.Parser, new[]{ "PartitionGraphs", "PreOptimizationGraph", "PostOptimizationGraph" }, null, null, null)}), | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TensorConnection), global::Tensorflow.TensorConnection.Parser, new[]{ "FromTensor", "ToTensor" }, null, null, null), | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CallableOptions), global::Tensorflow.CallableOptions.Parser, new[]{ "Feed", "Fetch", "Target", "RunOptions", "TensorConnection", "FeedDevices", "FetchDevices", "FetchSkipSync" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, null, }) | |||
| })); | |||
| @@ -605,6 +618,10 @@ namespace Tensorflow { | |||
| useUnifiedMemory_ = other.useUnifiedMemory_; | |||
| numDevToDevCopyStreams_ = other.numDevToDevCopyStreams_; | |||
| collectiveRingOrder_ = other.collectiveRingOrder_; | |||
| timestampedAllocator_ = other.timestampedAllocator_; | |||
| kernelTrackerMaxInterval_ = other.kernelTrackerMaxInterval_; | |||
| kernelTrackerMaxBytes_ = other.kernelTrackerMaxBytes_; | |||
| kernelTrackerMaxPending_ = other.kernelTrackerMaxPending_; | |||
| _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
| } | |||
| @@ -703,6 +720,77 @@ namespace Tensorflow { | |||
| } | |||
| } | |||
| /// <summary>Field number for the "timestamped_allocator" field.</summary> | |||
| public const int TimestampedAllocatorFieldNumber = 5; | |||
| private bool timestampedAllocator_; | |||
| /// <summary> | |||
| /// If true then extra work is done by GPUDevice and GPUBFCAllocator to | |||
| /// keep track of when GPU memory is freed and when kernels actually | |||
| /// complete so that we can know when a nominally free memory chunk | |||
| /// is really not subject to pending use. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public bool TimestampedAllocator { | |||
| get { return timestampedAllocator_; } | |||
| set { | |||
| timestampedAllocator_ = value; | |||
| } | |||
| } | |||
| /// <summary>Field number for the "kernel_tracker_max_interval" field.</summary> | |||
| public const int KernelTrackerMaxIntervalFieldNumber = 7; | |||
| private int kernelTrackerMaxInterval_; | |||
| /// <summary> | |||
| /// Parameters for GPUKernelTracker. By default no kernel tracking is done. | |||
| /// Note that timestamped_allocator is only effective if some tracking is | |||
| /// specified. | |||
| /// | |||
| /// If kernel_tracker_max_interval = n > 0, then a tracking event | |||
| /// is inserted after every n kernels without an event. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public int KernelTrackerMaxInterval { | |||
| get { return kernelTrackerMaxInterval_; } | |||
| set { | |||
| kernelTrackerMaxInterval_ = value; | |||
| } | |||
| } | |||
| /// <summary>Field number for the "kernel_tracker_max_bytes" field.</summary> | |||
| public const int KernelTrackerMaxBytesFieldNumber = 8; | |||
| private int kernelTrackerMaxBytes_; | |||
| /// <summary> | |||
| /// If kernel_tracker_max_bytes = n > 0, then a tracking event is | |||
| /// inserted after every series of kernels allocating a sum of | |||
| /// memory >= n. If one kernel allocates b * n bytes, then one | |||
| /// event will be inserted after it, but it will count as b against | |||
| /// the pending limit. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public int KernelTrackerMaxBytes { | |||
| get { return kernelTrackerMaxBytes_; } | |||
| set { | |||
| kernelTrackerMaxBytes_ = value; | |||
| } | |||
| } | |||
| /// <summary>Field number for the "kernel_tracker_max_pending" field.</summary> | |||
| public const int KernelTrackerMaxPendingFieldNumber = 9; | |||
| private int kernelTrackerMaxPending_; | |||
| /// <summary> | |||
| /// If kernel_tracker_max_pending > 0 then no more than this many | |||
| /// tracking events can be outstanding at a time. An attempt to | |||
| /// launch an additional kernel will stall until an event | |||
| /// completes. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public int KernelTrackerMaxPending { | |||
| get { return kernelTrackerMaxPending_; } | |||
| set { | |||
| kernelTrackerMaxPending_ = value; | |||
| } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as Experimental); | |||
| @@ -720,6 +808,10 @@ namespace Tensorflow { | |||
| if (UseUnifiedMemory != other.UseUnifiedMemory) return false; | |||
| if (NumDevToDevCopyStreams != other.NumDevToDevCopyStreams) return false; | |||
| if (CollectiveRingOrder != other.CollectiveRingOrder) return false; | |||
| if (TimestampedAllocator != other.TimestampedAllocator) return false; | |||
| if (KernelTrackerMaxInterval != other.KernelTrackerMaxInterval) return false; | |||
| if (KernelTrackerMaxBytes != other.KernelTrackerMaxBytes) return false; | |||
| if (KernelTrackerMaxPending != other.KernelTrackerMaxPending) return false; | |||
| return Equals(_unknownFields, other._unknownFields); | |||
| } | |||
| @@ -730,6 +822,10 @@ namespace Tensorflow { | |||
| if (UseUnifiedMemory != false) hash ^= UseUnifiedMemory.GetHashCode(); | |||
| if (NumDevToDevCopyStreams != 0) hash ^= NumDevToDevCopyStreams.GetHashCode(); | |||
| if (CollectiveRingOrder.Length != 0) hash ^= CollectiveRingOrder.GetHashCode(); | |||
| if (TimestampedAllocator != false) hash ^= TimestampedAllocator.GetHashCode(); | |||
| if (KernelTrackerMaxInterval != 0) hash ^= KernelTrackerMaxInterval.GetHashCode(); | |||
| if (KernelTrackerMaxBytes != 0) hash ^= KernelTrackerMaxBytes.GetHashCode(); | |||
| if (KernelTrackerMaxPending != 0) hash ^= KernelTrackerMaxPending.GetHashCode(); | |||
| if (_unknownFields != null) { | |||
| hash ^= _unknownFields.GetHashCode(); | |||
| } | |||
| @@ -756,6 +852,22 @@ namespace Tensorflow { | |||
| output.WriteRawTag(34); | |||
| output.WriteString(CollectiveRingOrder); | |||
| } | |||
| if (TimestampedAllocator != false) { | |||
| output.WriteRawTag(40); | |||
| output.WriteBool(TimestampedAllocator); | |||
| } | |||
| if (KernelTrackerMaxInterval != 0) { | |||
| output.WriteRawTag(56); | |||
| output.WriteInt32(KernelTrackerMaxInterval); | |||
| } | |||
| if (KernelTrackerMaxBytes != 0) { | |||
| output.WriteRawTag(64); | |||
| output.WriteInt32(KernelTrackerMaxBytes); | |||
| } | |||
| if (KernelTrackerMaxPending != 0) { | |||
| output.WriteRawTag(72); | |||
| output.WriteInt32(KernelTrackerMaxPending); | |||
| } | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| @@ -774,6 +886,18 @@ namespace Tensorflow { | |||
| if (CollectiveRingOrder.Length != 0) { | |||
| size += 1 + pb::CodedOutputStream.ComputeStringSize(CollectiveRingOrder); | |||
| } | |||
| if (TimestampedAllocator != false) { | |||
| size += 1 + 1; | |||
| } | |||
| if (KernelTrackerMaxInterval != 0) { | |||
| size += 1 + pb::CodedOutputStream.ComputeInt32Size(KernelTrackerMaxInterval); | |||
| } | |||
| if (KernelTrackerMaxBytes != 0) { | |||
| size += 1 + pb::CodedOutputStream.ComputeInt32Size(KernelTrackerMaxBytes); | |||
| } | |||
| if (KernelTrackerMaxPending != 0) { | |||
| size += 1 + pb::CodedOutputStream.ComputeInt32Size(KernelTrackerMaxPending); | |||
| } | |||
| if (_unknownFields != null) { | |||
| size += _unknownFields.CalculateSize(); | |||
| } | |||
| @@ -795,6 +919,18 @@ namespace Tensorflow { | |||
| if (other.CollectiveRingOrder.Length != 0) { | |||
| CollectiveRingOrder = other.CollectiveRingOrder; | |||
| } | |||
| if (other.TimestampedAllocator != false) { | |||
| TimestampedAllocator = other.TimestampedAllocator; | |||
| } | |||
| if (other.KernelTrackerMaxInterval != 0) { | |||
| KernelTrackerMaxInterval = other.KernelTrackerMaxInterval; | |||
| } | |||
| if (other.KernelTrackerMaxBytes != 0) { | |||
| KernelTrackerMaxBytes = other.KernelTrackerMaxBytes; | |||
| } | |||
| if (other.KernelTrackerMaxPending != 0) { | |||
| KernelTrackerMaxPending = other.KernelTrackerMaxPending; | |||
| } | |||
| _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||
| } | |||
| @@ -822,6 +958,22 @@ namespace Tensorflow { | |||
| CollectiveRingOrder = input.ReadString(); | |||
| break; | |||
| } | |||
| case 40: { | |||
| TimestampedAllocator = input.ReadBool(); | |||
| break; | |||
| } | |||
| case 56: { | |||
| KernelTrackerMaxInterval = input.ReadInt32(); | |||
| break; | |||
| } | |||
| case 64: { | |||
| KernelTrackerMaxBytes = input.ReadInt32(); | |||
| break; | |||
| } | |||
| case 72: { | |||
| KernelTrackerMaxPending = input.ReadInt32(); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
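| The four `GPUOptions.Experimental` fields added above expose TensorFlow's GPU kernel tracker. Below is a minimal sketch of setting them through the generated C# types; the threshold values are illustrative assumptions, not recommendations, and per the field comments `timestamped_allocator` only takes effect when one of the tracker limits is non-zero. | |||
| ```csharp | |||
| var config = new ConfigProto | |||
| { | |||
|     GpuOptions = new GPUOptions | |||
|     { | |||
|         Experimental = new GPUOptions.Types.Experimental | |||
|         { | |||
|             // Track when GPU memory is freed and kernels complete (field 5). | |||
|             TimestampedAllocator = true, | |||
|             // Insert a tracking event after every 8 kernels without one (field 7). | |||
|             KernelTrackerMaxInterval = 8, | |||
|             // ...or after kernels allocating >= 64 MB in total (field 8). | |||
|             KernelTrackerMaxBytes = 64 * 1024 * 1024, | |||
|             // Stall kernel launches beyond 4 outstanding events (field 9). | |||
|             KernelTrackerMaxPending = 4 | |||
|         } | |||
|     } | |||
| }; | |||
| ``` | |||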
| @@ -2189,6 +2341,7 @@ namespace Tensorflow { | |||
| /// inter_op_parallelism_threads available in each process. | |||
| /// | |||
| /// 0 means the system picks an appropriate number. | |||
| /// Negative means all operations are performed in the caller's thread. | |||
| /// | |||
| /// Note that the first Session created in the process sets the | |||
| /// number of threads for all future sessions unless use_per_session_threads is | |||
| @@ -2397,7 +2550,8 @@ namespace Tensorflow { | |||
| private bool isolateSessionState_; | |||
| /// <summary> | |||
| /// If true, any resources such as Variables used in the session will not be | |||
| /// shared with other sessions. | |||
| /// shared with other sessions. However, when clusterspec propagation is | |||
| /// enabled, this field is ignored and sessions are always isolated. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public bool IsolateSessionState { | |||
| @@ -2787,6 +2941,11 @@ namespace Tensorflow { | |||
| executorType_ = other.executorType_; | |||
| recvBufMaxChunk_ = other.recvBufMaxChunk_; | |||
| useNumaAffinity_ = other.useNumaAffinity_; | |||
| collectiveDeterministicSequentialExecution_ = other.collectiveDeterministicSequentialExecution_; | |||
| collectiveNccl_ = other.collectiveNccl_; | |||
| shareSessionStateInClusterspecPropagation_ = other.shareSessionStateInClusterspecPropagation_; | |||
| disableThreadSpinning_ = other.disableThreadSpinning_; | |||
| shareClusterDevicesInSession_ = other.shareClusterDevicesInSession_; | |||
| _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
| } | |||
| @@ -2856,6 +3015,103 @@ namespace Tensorflow { | |||
| } | |||
| } | |||
| /// <summary>Field number for the "collective_deterministic_sequential_execution" field.</summary> | |||
| public const int CollectiveDeterministicSequentialExecutionFieldNumber = 6; | |||
| private bool collectiveDeterministicSequentialExecution_; | |||
| /// <summary> | |||
| /// If true, make collective op execution order sequential and deterministic | |||
| /// for potentially concurrent collective instances. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public bool CollectiveDeterministicSequentialExecution { | |||
| get { return collectiveDeterministicSequentialExecution_; } | |||
| set { | |||
| collectiveDeterministicSequentialExecution_ = value; | |||
| } | |||
| } | |||
| /// <summary>Field number for the "collective_nccl" field.</summary> | |||
| public const int CollectiveNcclFieldNumber = 7; | |||
| private bool collectiveNccl_; | |||
| /// <summary> | |||
| /// If true, use NCCL for CollectiveOps. This feature is highly | |||
| /// experimental. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public bool CollectiveNccl { | |||
| get { return collectiveNccl_; } | |||
| set { | |||
| collectiveNccl_ = value; | |||
| } | |||
| } | |||
| /// <summary>Field number for the "share_session_state_in_clusterspec_propagation" field.</summary> | |||
| public const int ShareSessionStateInClusterspecPropagationFieldNumber = 8; | |||
| private bool shareSessionStateInClusterspecPropagation_; | |||
| /// <summary> | |||
| /// In the following, session state means the value of a variable, elements | |||
| /// in a hash table, or any other resource, accessible by worker sessions | |||
| /// held by a TF server. | |||
| /// | |||
| /// When ClusterSpec propagation is enabled, the value of | |||
| /// isolate_session_state is ignored when deciding whether to share session | |||
| /// states in a TF server (for backwards compatibility reasons). | |||
| /// - If share_session_state_in_clusterspec_propagation is true, the session | |||
| /// states are shared. | |||
| /// - If share_session_state_in_clusterspec_propagation is false, session | |||
| /// states are isolated. | |||
| /// | |||
| /// When clusterspec propagation is not used, the value of | |||
| /// share_session_state_in_clusterspec_propagation is ignored when deciding | |||
| /// whether to share session states in a TF server. | |||
| /// - If isolate_session_state is true, session states are isolated. | |||
| /// - If isolate_session_state is false, session states are shared. | |||
| /// | |||
| /// TODO(b/129330037): Add a single API that consistently treats | |||
| /// isolate_session_state and ClusterSpec propagation. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public bool ShareSessionStateInClusterspecPropagation { | |||
| get { return shareSessionStateInClusterspecPropagation_; } | |||
| set { | |||
| shareSessionStateInClusterspecPropagation_ = value; | |||
| } | |||
| } | |||
| /// <summary>Field number for the "disable_thread_spinning" field.</summary> | |||
| public const int DisableThreadSpinningFieldNumber = 9; | |||
| private bool disableThreadSpinning_; | |||
| /// <summary> | |||
| /// If using a direct session, disable spinning while waiting for work in | |||
| /// the thread pool. This may result in higher latency for completing ops, | |||
| /// but, in cases where there is a lot of spinning, may result in lower | |||
| /// CPU usage. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public bool DisableThreadSpinning { | |||
| get { return disableThreadSpinning_; } | |||
| set { | |||
| disableThreadSpinning_ = value; | |||
| } | |||
| } | |||
| /// <summary>Field number for the "share_cluster_devices_in_session" field.</summary> | |||
| public const int ShareClusterDevicesInSessionFieldNumber = 10; | |||
| private bool shareClusterDevicesInSession_; | |||
| /// <summary> | |||
| /// When true, WorkerSessions are created with device attributes from the | |||
| /// full cluster. | |||
| /// This is helpful when a worker wants to partition a graph | |||
| /// (for example during a PartitionedCallOp). | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public bool ShareClusterDevicesInSession { | |||
| get { return shareClusterDevicesInSession_; } | |||
| set { | |||
| shareClusterDevicesInSession_ = value; | |||
| } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as Experimental); | |||
| @@ -2873,6 +3129,11 @@ namespace Tensorflow { | |||
| if (ExecutorType != other.ExecutorType) return false; | |||
| if (RecvBufMaxChunk != other.RecvBufMaxChunk) return false; | |||
| if (UseNumaAffinity != other.UseNumaAffinity) return false; | |||
| if (CollectiveDeterministicSequentialExecution != other.CollectiveDeterministicSequentialExecution) return false; | |||
| if (CollectiveNccl != other.CollectiveNccl) return false; | |||
| if (ShareSessionStateInClusterspecPropagation != other.ShareSessionStateInClusterspecPropagation) return false; | |||
| if (DisableThreadSpinning != other.DisableThreadSpinning) return false; | |||
| if (ShareClusterDevicesInSession != other.ShareClusterDevicesInSession) return false; | |||
| return Equals(_unknownFields, other._unknownFields); | |||
| } | |||
| @@ -2883,6 +3144,11 @@ namespace Tensorflow { | |||
| if (ExecutorType.Length != 0) hash ^= ExecutorType.GetHashCode(); | |||
| if (RecvBufMaxChunk != 0) hash ^= RecvBufMaxChunk.GetHashCode(); | |||
| if (UseNumaAffinity != false) hash ^= UseNumaAffinity.GetHashCode(); | |||
| if (CollectiveDeterministicSequentialExecution != false) hash ^= CollectiveDeterministicSequentialExecution.GetHashCode(); | |||
| if (CollectiveNccl != false) hash ^= CollectiveNccl.GetHashCode(); | |||
| if (ShareSessionStateInClusterspecPropagation != false) hash ^= ShareSessionStateInClusterspecPropagation.GetHashCode(); | |||
| if (DisableThreadSpinning != false) hash ^= DisableThreadSpinning.GetHashCode(); | |||
| if (ShareClusterDevicesInSession != false) hash ^= ShareClusterDevicesInSession.GetHashCode(); | |||
| if (_unknownFields != null) { | |||
| hash ^= _unknownFields.GetHashCode(); | |||
| } | |||
| @@ -2912,6 +3178,26 @@ namespace Tensorflow { | |||
| output.WriteRawTag(40); | |||
| output.WriteBool(UseNumaAffinity); | |||
| } | |||
| if (CollectiveDeterministicSequentialExecution != false) { | |||
| output.WriteRawTag(48); | |||
| output.WriteBool(CollectiveDeterministicSequentialExecution); | |||
| } | |||
| if (CollectiveNccl != false) { | |||
| output.WriteRawTag(56); | |||
| output.WriteBool(CollectiveNccl); | |||
| } | |||
| if (ShareSessionStateInClusterspecPropagation != false) { | |||
| output.WriteRawTag(64); | |||
| output.WriteBool(ShareSessionStateInClusterspecPropagation); | |||
| } | |||
| if (DisableThreadSpinning != false) { | |||
| output.WriteRawTag(72); | |||
| output.WriteBool(DisableThreadSpinning); | |||
| } | |||
| if (ShareClusterDevicesInSession != false) { | |||
| output.WriteRawTag(80); | |||
| output.WriteBool(ShareClusterDevicesInSession); | |||
| } | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| @@ -2932,6 +3218,21 @@ namespace Tensorflow { | |||
| if (UseNumaAffinity != false) { | |||
| size += 1 + 1; | |||
| } | |||
| if (CollectiveDeterministicSequentialExecution != false) { | |||
| size += 1 + 1; | |||
| } | |||
| if (CollectiveNccl != false) { | |||
| size += 1 + 1; | |||
| } | |||
| if (ShareSessionStateInClusterspecPropagation != false) { | |||
| size += 1 + 1; | |||
| } | |||
| if (DisableThreadSpinning != false) { | |||
| size += 1 + 1; | |||
| } | |||
| if (ShareClusterDevicesInSession != false) { | |||
| size += 1 + 1; | |||
| } | |||
| if (_unknownFields != null) { | |||
| size += _unknownFields.CalculateSize(); | |||
| } | |||
| @@ -2955,6 +3256,21 @@ namespace Tensorflow { | |||
| if (other.UseNumaAffinity != false) { | |||
| UseNumaAffinity = other.UseNumaAffinity; | |||
| } | |||
| if (other.CollectiveDeterministicSequentialExecution != false) { | |||
| CollectiveDeterministicSequentialExecution = other.CollectiveDeterministicSequentialExecution; | |||
| } | |||
| if (other.CollectiveNccl != false) { | |||
| CollectiveNccl = other.CollectiveNccl; | |||
| } | |||
| if (other.ShareSessionStateInClusterspecPropagation != false) { | |||
| ShareSessionStateInClusterspecPropagation = other.ShareSessionStateInClusterspecPropagation; | |||
| } | |||
| if (other.DisableThreadSpinning != false) { | |||
| DisableThreadSpinning = other.DisableThreadSpinning; | |||
| } | |||
| if (other.ShareClusterDevicesInSession != false) { | |||
| ShareClusterDevicesInSession = other.ShareClusterDevicesInSession; | |||
| } | |||
| _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||
| } | |||
| @@ -2982,6 +3298,26 @@ namespace Tensorflow { | |||
| UseNumaAffinity = input.ReadBool(); | |||
| break; | |||
| } | |||
| case 48: { | |||
| CollectiveDeterministicSequentialExecution = input.ReadBool(); | |||
| break; | |||
| } | |||
| case 56: { | |||
| CollectiveNccl = input.ReadBool(); | |||
| break; | |||
| } | |||
| case 64: { | |||
| ShareSessionStateInClusterspecPropagation = input.ReadBool(); | |||
| break; | |||
| } | |||
| case 72: { | |||
| DisableThreadSpinning = input.ReadBool(); | |||
| break; | |||
| } | |||
| case 80: { | |||
| ShareClusterDevicesInSession = input.ReadBool(); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
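| A speculative sketch of populating the new `ConfigProto.Experimental` fields from C#; the combination shown is illustrative only (`collective_nccl` in particular is flagged as highly experimental above): | |||
| ```csharp | |||
| var config = new ConfigProto | |||
| { | |||
|     Experimental = new ConfigProto.Types.Experimental | |||
|     { | |||
|         // Sequential, deterministic ordering of concurrent collectives (field 6). | |||
|         CollectiveDeterministicSequentialExecution = true, | |||
|         // Route CollectiveOps through NCCL (field 7). | |||
|         CollectiveNccl = true, | |||
|         // Avoid thread-pool spinning in direct sessions (field 9). | |||
|         DisableThreadSpinning = true, | |||
|         // Give WorkerSessions device attributes from the full cluster (field 10). | |||
|         ShareClusterDevicesInSession = true | |||
|     } | |||
| }; | |||
| ``` | |||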
| @@ -3553,6 +3889,7 @@ namespace Tensorflow { | |||
| stepStats_ = other.stepStats_ != null ? other.stepStats_.Clone() : null; | |||
| costGraph_ = other.costGraph_ != null ? other.costGraph_.Clone() : null; | |||
| partitionGraphs_ = other.partitionGraphs_.Clone(); | |||
| functionGraphs_ = other.functionGraphs_.Clone(); | |||
| _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
| } | |||
| @@ -3604,6 +3941,28 @@ namespace Tensorflow { | |||
| get { return partitionGraphs_; } | |||
| } | |||
| /// <summary>Field number for the "function_graphs" field.</summary> | |||
| public const int FunctionGraphsFieldNumber = 4; | |||
| private static readonly pb::FieldCodec<global::Tensorflow.RunMetadata.Types.FunctionGraphs> _repeated_functionGraphs_codec | |||
| = pb::FieldCodec.ForMessage(34, global::Tensorflow.RunMetadata.Types.FunctionGraphs.Parser); | |||
| private readonly pbc::RepeatedField<global::Tensorflow.RunMetadata.Types.FunctionGraphs> functionGraphs_ = new pbc::RepeatedField<global::Tensorflow.RunMetadata.Types.FunctionGraphs>(); | |||
| /// <summary> | |||
| /// This is only populated for graphs that are run as functions in TensorFlow | |||
| /// V2. There will be an entry below for each function that is traced. | |||
| /// The main use cases of the post_optimization_graph and the partition_graphs | |||
| /// are to give the caller insight into the graphs that were actually run by the | |||
| /// runtime. Additional information (such as that in step_stats) will match | |||
| /// these graphs. | |||
| /// We also include the pre_optimization_graph since it is usually easier to | |||
| /// read, and is helpful in situations where the caller wants to get a high | |||
| /// level idea of what the built graph looks like (since the various graph | |||
| /// optimization passes might change the structure of the graph significantly). | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public pbc::RepeatedField<global::Tensorflow.RunMetadata.Types.FunctionGraphs> FunctionGraphs { | |||
| get { return functionGraphs_; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as RunMetadata); | |||
| @@ -3620,6 +3979,7 @@ namespace Tensorflow { | |||
| if (!object.Equals(StepStats, other.StepStats)) return false; | |||
| if (!object.Equals(CostGraph, other.CostGraph)) return false; | |||
| if(!partitionGraphs_.Equals(other.partitionGraphs_)) return false; | |||
| if(!functionGraphs_.Equals(other.functionGraphs_)) return false; | |||
| return Equals(_unknownFields, other._unknownFields); | |||
| } | |||
| @@ -3629,6 +3989,7 @@ namespace Tensorflow { | |||
| if (stepStats_ != null) hash ^= StepStats.GetHashCode(); | |||
| if (costGraph_ != null) hash ^= CostGraph.GetHashCode(); | |||
| hash ^= partitionGraphs_.GetHashCode(); | |||
| hash ^= functionGraphs_.GetHashCode(); | |||
| if (_unknownFields != null) { | |||
| hash ^= _unknownFields.GetHashCode(); | |||
| } | |||
| @@ -3651,6 +4012,7 @@ namespace Tensorflow { | |||
| output.WriteMessage(CostGraph); | |||
| } | |||
| partitionGraphs_.WriteTo(output, _repeated_partitionGraphs_codec); | |||
| functionGraphs_.WriteTo(output, _repeated_functionGraphs_codec); | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| @@ -3666,6 +4028,7 @@ namespace Tensorflow { | |||
| size += 1 + pb::CodedOutputStream.ComputeMessageSize(CostGraph); | |||
| } | |||
| size += partitionGraphs_.CalculateSize(_repeated_partitionGraphs_codec); | |||
| size += functionGraphs_.CalculateSize(_repeated_functionGraphs_codec); | |||
| if (_unknownFields != null) { | |||
| size += _unknownFields.CalculateSize(); | |||
| } | |||
| @@ -3690,6 +4053,7 @@ namespace Tensorflow { | |||
| CostGraph.MergeFrom(other.CostGraph); | |||
| } | |||
| partitionGraphs_.Add(other.partitionGraphs_); | |||
| functionGraphs_.Add(other.functionGraphs_); | |||
| _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||
| } | |||
| @@ -3719,9 +4083,212 @@ namespace Tensorflow { | |||
| partitionGraphs_.AddEntriesFrom(input, _repeated_partitionGraphs_codec); | |||
| break; | |||
| } | |||
| case 34: { | |||
| functionGraphs_.AddEntriesFrom(input, _repeated_functionGraphs_codec); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #region Nested types | |||
| /// <summary>Container for nested types declared in the RunMetadata message type.</summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public static partial class Types { | |||
| public sealed partial class FunctionGraphs : pb::IMessage<FunctionGraphs> { | |||
| private static readonly pb::MessageParser<FunctionGraphs> _parser = new pb::MessageParser<FunctionGraphs>(() => new FunctionGraphs()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public static pb::MessageParser<FunctionGraphs> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.RunMetadata.Descriptor.NestedTypes[0]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public FunctionGraphs() { | |||
| OnConstruction(); | |||
| } | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public FunctionGraphs(FunctionGraphs other) : this() { | |||
| partitionGraphs_ = other.partitionGraphs_.Clone(); | |||
| preOptimizationGraph_ = other.preOptimizationGraph_ != null ? other.preOptimizationGraph_.Clone() : null; | |||
| postOptimizationGraph_ = other.postOptimizationGraph_ != null ? other.postOptimizationGraph_.Clone() : null; | |||
| _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public FunctionGraphs Clone() { | |||
| return new FunctionGraphs(this); | |||
| } | |||
| /// <summary>Field number for the "partition_graphs" field.</summary> | |||
| public const int PartitionGraphsFieldNumber = 1; | |||
| private static readonly pb::FieldCodec<global::Tensorflow.GraphDef> _repeated_partitionGraphs_codec | |||
| = pb::FieldCodec.ForMessage(10, global::Tensorflow.GraphDef.Parser); | |||
| private readonly pbc::RepeatedField<global::Tensorflow.GraphDef> partitionGraphs_ = new pbc::RepeatedField<global::Tensorflow.GraphDef>(); | |||
| /// <summary> | |||
| /// TODO(nareshmodi): Include some sort of function/cache-key identifier? | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public pbc::RepeatedField<global::Tensorflow.GraphDef> PartitionGraphs { | |||
| get { return partitionGraphs_; } | |||
| } | |||
| /// <summary>Field number for the "pre_optimization_graph" field.</summary> | |||
| public const int PreOptimizationGraphFieldNumber = 2; | |||
| private global::Tensorflow.GraphDef preOptimizationGraph_; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public global::Tensorflow.GraphDef PreOptimizationGraph { | |||
| get { return preOptimizationGraph_; } | |||
| set { | |||
| preOptimizationGraph_ = value; | |||
| } | |||
| } | |||
| /// <summary>Field number for the "post_optimization_graph" field.</summary> | |||
| public const int PostOptimizationGraphFieldNumber = 3; | |||
| private global::Tensorflow.GraphDef postOptimizationGraph_; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public global::Tensorflow.GraphDef PostOptimizationGraph { | |||
| get { return postOptimizationGraph_; } | |||
| set { | |||
| postOptimizationGraph_ = value; | |||
| } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as FunctionGraphs); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public bool Equals(FunctionGraphs other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| } | |||
| if (ReferenceEquals(other, this)) { | |||
| return true; | |||
| } | |||
| if(!partitionGraphs_.Equals(other.partitionGraphs_)) return false; | |||
| if (!object.Equals(PreOptimizationGraph, other.PreOptimizationGraph)) return false; | |||
| if (!object.Equals(PostOptimizationGraph, other.PostOptimizationGraph)) return false; | |||
| return Equals(_unknownFields, other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| hash ^= partitionGraphs_.GetHashCode(); | |||
| if (preOptimizationGraph_ != null) hash ^= PreOptimizationGraph.GetHashCode(); | |||
| if (postOptimizationGraph_ != null) hash ^= PostOptimizationGraph.GetHashCode(); | |||
| if (_unknownFields != null) { | |||
| hash ^= _unknownFields.GetHashCode(); | |||
| } | |||
| return hash; | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| partitionGraphs_.WriteTo(output, _repeated_partitionGraphs_codec); | |||
| if (preOptimizationGraph_ != null) { | |||
| output.WriteRawTag(18); | |||
| output.WriteMessage(PreOptimizationGraph); | |||
| } | |||
| if (postOptimizationGraph_ != null) { | |||
| output.WriteRawTag(26); | |||
| output.WriteMessage(PostOptimizationGraph); | |||
| } | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| size += partitionGraphs_.CalculateSize(_repeated_partitionGraphs_codec); | |||
| if (preOptimizationGraph_ != null) { | |||
| size += 1 + pb::CodedOutputStream.ComputeMessageSize(PreOptimizationGraph); | |||
| } | |||
| if (postOptimizationGraph_ != null) { | |||
| size += 1 + pb::CodedOutputStream.ComputeMessageSize(PostOptimizationGraph); | |||
| } | |||
| if (_unknownFields != null) { | |||
| size += _unknownFields.CalculateSize(); | |||
| } | |||
| return size; | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public void MergeFrom(FunctionGraphs other) { | |||
| if (other == null) { | |||
| return; | |||
| } | |||
| partitionGraphs_.Add(other.partitionGraphs_); | |||
| if (other.preOptimizationGraph_ != null) { | |||
| if (preOptimizationGraph_ == null) { | |||
| preOptimizationGraph_ = new global::Tensorflow.GraphDef(); | |||
| } | |||
| PreOptimizationGraph.MergeFrom(other.PreOptimizationGraph); | |||
| } | |||
| if (other.postOptimizationGraph_ != null) { | |||
| if (postOptimizationGraph_ == null) { | |||
| postOptimizationGraph_ = new global::Tensorflow.GraphDef(); | |||
| } | |||
| PostOptimizationGraph.MergeFrom(other.PostOptimizationGraph); | |||
| } | |||
| _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||
| break; | |||
| case 10: { | |||
| partitionGraphs_.AddEntriesFrom(input, _repeated_partitionGraphs_codec); | |||
| break; | |||
| } | |||
| case 18: { | |||
| if (preOptimizationGraph_ == null) { | |||
| preOptimizationGraph_ = new global::Tensorflow.GraphDef(); | |||
| } | |||
| input.ReadMessage(preOptimizationGraph_); | |||
| break; | |||
| } | |||
| case 26: { | |||
| if (postOptimizationGraph_ == null) { | |||
| postOptimizationGraph_ = new global::Tensorflow.GraphDef(); | |||
| } | |||
| input.ReadMessage(postOptimizationGraph_); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endregion | |||
| } | |||
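| Since `function_graphs` is only populated for graphs run as functions in TensorFlow V2, the following is a rough sketch of inspecting the new field; `metadata` stands for a `RunMetadata` assumed to come from a traced run (e.g. with `trace_level` set to `FULL_TRACE`): | |||
| ```csharp | |||
| // Assumption: metadata is a RunMetadata returned from a traced run. | |||
| foreach (var fg in metadata.FunctionGraphs) | |||
| { | |||
|     var before = fg.PreOptimizationGraph;   // easiest graph to read | |||
|     var after = fg.PostOptimizationGraph;   // what the runtime actually ran | |||
|     Console.WriteLine($"partitions={fg.PartitionGraphs.Count}, " + | |||
|         $"nodes before/after={before?.Node.Count}/{after?.Node.Count}"); | |||
| } | |||
| ``` | |||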
| @@ -36,19 +36,20 @@ namespace Tensorflow | |||
| protected byte[] _target; | |||
| public Graph graph => _graph; | |||
| public BaseSession(string target = "", Graph g = null, SessionOptions opts = null, Status status = null) | |||
| public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null) | |||
| { | |||
| _graph = g ?? ops.get_default_graph(); | |||
| _graph.as_default(); | |||
| _target = Encoding.UTF8.GetBytes(target); | |||
| SessionOptions lopts = opts ?? new SessionOptions(); | |||
| lock (Locks.ProcessWide) | |||
| using (var opts = new SessionOptions(target, config)) | |||
| { | |||
| status = status ?? new Status(); | |||
| _handle = c_api.TF_NewSession(_graph, opts ?? lopts, status); | |||
| status.Check(true); | |||
| lock (Locks.ProcessWide) | |||
| { | |||
| status = status ?? new Status(); | |||
| _handle = c_api.TF_NewSession(_graph, opts, status); | |||
| status.Check(true); | |||
| } | |||
| } | |||
| } | |||
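| With this change the public surface takes a `ConfigProto` directly; the `SessionOptions` wrapper is created internally and disposed as soon as `TF_NewSession` returns. A minimal usage sketch (the settings are illustrative): | |||
| ```csharp | |||
| var config = new ConfigProto | |||
| { | |||
|     InterOpParallelismThreads = 2,  // 0 would let the system choose | |||
|     AllowSoftPlacement = true, | |||
|     LogDevicePlacement = false | |||
| }; | |||
| var graph = new Graph().as_default(); | |||
| using (var sess = new Session(graph, config)) | |||
| { | |||
|     // build and run ops against the configured session | |||
| } | |||
| ``` | |||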
| @@ -32,7 +32,7 @@ namespace Tensorflow | |||
| _handle = handle; | |||
| } | |||
| public Session(Graph g, SessionOptions opts = null, Status s = null) : base("", g, opts, s) | |||
| public Session(Graph g, ConfigProto config = null, Status s = null) : base("", g, config, s) | |||
| { } | |||
| public Session as_default() | |||
| @@ -20,11 +20,14 @@ using System.Runtime.InteropServices; | |||
| namespace Tensorflow | |||
| { | |||
| public class SessionOptions : DisposableObject | |||
| internal class SessionOptions : DisposableObject | |||
| { | |||
| public SessionOptions() | |||
| public SessionOptions(string target = "", ConfigProto config = null) | |||
| { | |||
| _handle = c_api.TF_NewSessionOptions(); | |||
| c_api.TF_SetTarget(_handle, target); | |||
| if (config != null) | |||
| SetConfig(config); | |||
| } | |||
| public SessionOptions(IntPtr handle) | |||
| @@ -35,10 +38,10 @@ namespace Tensorflow | |||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||
| => c_api.TF_DeleteSessionOptions(handle); | |||
| public void SetConfig(ConfigProto config) | |||
| private void SetConfig(ConfigProto config) | |||
| { | |||
| var bytes = config.ToByteArray(); //TODO! we can use WriteTo | |||
| var proto = Marshal.AllocHGlobal(bytes.Length); //TODO! potential memory leak | |||
| var bytes = config.ToByteArray(); | |||
| var proto = Marshal.AllocHGlobal(bytes.Length); | |||
| Marshal.Copy(bytes, 0, proto, bytes.Length); | |||
| using (var status = new Status()) | |||
| @@ -1,10 +0,0 @@ | |||
| using System.Runtime.InteropServices; | |||
| namespace Tensorflow | |||
| { | |||
| [StructLayout(LayoutKind.Sequential)] | |||
| public struct TF_SessionOptions | |||
| { | |||
| public SessionOptions options; | |||
| } | |||
| } | |||
| @@ -28,9 +28,9 @@ namespace Tensorflow | |||
| { | |||
| private Func<List<NDArray>, object> _contraction_fn; | |||
| public _ElementFetchMapper(object[] fetches, Func<List<NDArray>, object> contraction_fn) | |||
| public _ElementFetchMapper(object[] fetches, Func<List<NDArray>, object> contraction_fn, Graph graph = null) | |||
| { | |||
| var g = ops.get_default_graph(); | |||
| var g = graph ?? ops.get_default_graph(); | |||
| foreach(var fetch in fetches) | |||
| { | |||
| @@ -34,7 +34,7 @@ namespace Tensorflow | |||
| public _FetchHandler(Graph graph, object fetches, Dictionary<object, object> feeds = null, Action feed_handles = null) | |||
| { | |||
| _fetch_mapper = _FetchMapper.for_fetch(fetches); | |||
| _fetch_mapper = _FetchMapper.for_fetch(fetches, graph: graph); | |||
| foreach(var fetch in _fetch_mapper.unique_fetches()) | |||
| { | |||
| switch (fetch) | |||
| @@ -25,7 +25,7 @@ namespace Tensorflow | |||
| { | |||
| protected List<ITensorOrOperation> _unique_fetches = new List<ITensorOrOperation>(); | |||
| protected List<int[]> _value_indices = new List<int[]>(); | |||
| public static _FetchMapper for_fetch(object fetch) | |||
| public static _FetchMapper for_fetch(object fetch, Graph graph = null) | |||
| { | |||
| var fetches = fetch.GetType().IsArray ? (object[])fetch : new object[] { fetch }; | |||
| @@ -34,7 +34,7 @@ namespace Tensorflow | |||
| if (fetch.GetType().IsArray) | |||
| return new _ListFetchMapper(fetches); | |||
| else | |||
| return new _ElementFetchMapper(fetches, (List<NDArray> fetched_vals) => fetched_vals[0]); | |||
| return new _ElementFetchMapper(fetches, (List<NDArray> fetched_vals) => fetched_vals[0], graph: graph); | |||
| } | |||
| public virtual NDArray[] build_results(List<NDArray> values) | |||
| @@ -116,6 +116,9 @@ namespace Tensorflow | |||
| /// <param name="proto_len">size_t</param> | |||
| /// <param name="status">TF_Status*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern unsafe void TF_SetConfig(IntPtr options, IntPtr proto, ulong proto_len, IntPtr status); | |||
| public static extern void TF_SetConfig(IntPtr options, IntPtr proto, ulong proto_len, IntPtr status); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_SetTarget(IntPtr options, string target); | |||
| } | |||
| } | |||
| @@ -5,7 +5,7 @@ | |||
| <AssemblyName>TensorFlow.NET</AssemblyName> | |||
| <RootNamespace>Tensorflow</RootNamespace> | |||
| <TargetTensorFlow>1.14.0</TargetTensorFlow> | |||
| <Version>0.11.8</Version> | |||
| <Version>0.12.0</Version> | |||
| <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | |||
| <Company>SciSharp STACK</Company> | |||
| <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | |||
| @@ -16,25 +16,15 @@ | |||
| <PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&v=4</PackageIconUrl> | |||
| <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags> | |||
| <Description>Google's TensorFlow full binding in .NET Standard. | |||
| Docs: https://tensorflownet.readthedocs.io</Description> | |||
| <AssemblyVersion>0.11.8.0</AssemblyVersion> | |||
| <PackageReleaseNotes>Changes since v0.10.0: | |||
| 1. Upgrade NumSharp to v0.20.3. | |||
| 2. Add DisposableObject class to manage object lifetime. | |||
| 3. Add tf.no_op, tf.nn.in_top_k, tf.GraphKeys and tf.trainable_variables. | |||
| 4. Change tensorflow to non-static class in order to execute some initialization process. | |||
| 5. Overload session.run(), make syntax simpler. | |||
| 6. Add Local Response Normalization. | |||
| 7. Add tf.image related APIs. | |||
| 8. Add tf.random_normal, tf.constant, tf.pad, tf.shape, tf.image.resize_nearest_neighbor. | |||
| 9. MultiThread is safe. | |||
| 10. Support n-dim indexing for tensor. | |||
| 11. Add RegisterNoGradients | |||
| 12. Add CumsumGrad, BroadcastToGrad. | |||
| 13. Return VariableV1 instead of RefVariable. | |||
| 14. Add Tensor overload to GradientDescentOptimizer.</PackageReleaseNotes> | |||
| Building, training and inferring deep learning models. | |||
| https://tensorflownet.readthedocs.io</Description> | |||
| <AssemblyVersion>0.12.0.0</AssemblyVersion> | |||
| <PackageReleaseNotes>Changes since v0.11.0: | |||
| 1. Add ICanBeFlattened for nest.flatten2. | |||
| 2. Complete the WhileContext. | |||
| 3. Add tf.nn.rnn_cell.BasicRNNCell and tf.nn.dynamic_rnn.</PackageReleaseNotes> | |||
| <LangVersion>7.3</LangVersion> | |||
| <FileVersion>0.11.8.0</FileVersion> | |||
| <FileVersion>0.12.0.0</FileVersion> | |||
| <PackageLicenseFile>LICENSE</PackageLicenseFile> | |||
| <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | |||
| <SignAssembly>true</SignAssembly> | |||
| @@ -43,7 +33,7 @@ Docs: https://tensorflownet.readthedocs.io</Description> | |||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | |||
| <AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||
| <DefineConstants>TRACE;DEBUG;SERIALIZABLE</DefineConstants> | |||
| <DefineConstants>TRACE;DEBUG;SERIALIZABLE_</DefineConstants> | |||
| </PropertyGroup> | |||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'"> | |||
| @@ -65,8 +55,7 @@ Docs: https://tensorflownet.readthedocs.io</Description> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| <PackageReference Include="Google.Protobuf" Version="3.5.1" /> | |||
| <PackageReference Include="Newtonsoft.Json" Version="12.0.2" /> | |||
| <PackageReference Include="Google.Protobuf" Version="3.10.0" /> | |||
| <PackageReference Include="NumSharp" Version="0.20.4" /> | |||
| </ItemGroup> | |||
| @@ -25,7 +25,9 @@ using System.Text; | |||
| using NumSharp.Backends; | |||
| using NumSharp.Backends.Unmanaged; | |||
| using static Tensorflow.c_api; | |||
| #if SERIALIZABLE | |||
| using Newtonsoft.Json; | |||
| #endif | |||
| namespace Tensorflow | |||
| { | |||
| @@ -0,0 +1,15 @@ | |||
| using NumSharp; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class Tensor | |||
| { | |||
| public object[] Flatten() | |||
| { | |||
| return new Tensor[] { this }; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,15 @@ | |||
| using NumSharp; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class Tensor | |||
| { | |||
| public Tensor Pack(object[] sequences) | |||
| { | |||
| return sequences[0] as Tensor; | |||
| } | |||
| } | |||
| } | |||
| @@ -28,7 +28,9 @@ using NumSharp.Backends; | |||
| using NumSharp.Backends.Unmanaged; | |||
| using NumSharp.Utilities; | |||
| using Tensorflow.Framework; | |||
| #if SERIALIZABLE | |||
| using Newtonsoft.Json; | |||
| #endif | |||
| namespace Tensorflow | |||
| { | |||
| @@ -37,7 +39,12 @@ namespace Tensorflow | |||
| /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. | |||
| /// </summary> | |||
| [SuppressMessage("ReSharper", "ConvertToAutoProperty")] | |||
| public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike | |||
| public partial class Tensor : DisposableObject, | |||
| ITensorOrOperation, | |||
| _TensorLike, | |||
| ITensorOrTensorArray, | |||
| IPackable<Tensor>, | |||
| ICanBeFlattened | |||
| { | |||
| private readonly int _id; | |||
| private readonly Operation _op; | |||
| @@ -95,7 +102,7 @@ namespace Tensorflow | |||
| [JsonIgnore] | |||
| #endif | |||
| public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize; | |||
| private IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); | |||
| public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); | |||
| public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); | |||
| #if SERIALIZABLE | |||
| [JsonIgnore] | |||
| @@ -176,7 +183,7 @@ namespace Tensorflow | |||
| /// </summary> | |||
| public void set_shape(TensorShape shape) | |||
| { | |||
| this.shape = shape.rank > 0 ? shape.dims : null; | |||
| this.shape = shape.rank >= 0 ? shape.dims : null; | |||
| } | |||
| /// <summary> | |||
| @@ -17,8 +17,9 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Operations; | |||
| namespace Tensorflow.Operations | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| /// TensorArray is designed to hide an underlying implementation object | |||
| @@ -29,9 +30,9 @@ namespace Tensorflow.Operations | |||
| /// `while_loop` and `map_fn`. It supports gradient back-propagation via special | |||
| /// "flow" control flow dependencies. | |||
| /// </summary> | |||
| public class TensorArray | |||
| public class TensorArray : ITensorOrTensorArray | |||
| { | |||
| _GraphTensorArray _implementation; | |||
| internal _GraphTensorArray _implementation; | |||
| public TF_DataType dtype => _implementation._dtype; | |||
| public Tensor handle => _implementation._handle; | |||
| @@ -39,7 +40,7 @@ namespace Tensorflow.Operations | |||
| public TensorArray(TF_DataType dtype, Tensor size = default, bool? clear_after_read = null, bool? dynamic_size = null, | |||
| string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | |||
| bool infer_shape = true, TensorShape[] element_shape = null, | |||
| bool infer_shape = true, TensorShape element_shape = null, | |||
| bool colocate_with_first_write_call = true, string name = null) | |||
| { | |||
| _implementation = new _GraphTensorArray(dtype, | |||
| @@ -57,5 +58,14 @@ namespace Tensorflow.Operations | |||
| public TensorArray unstack(Tensor value, string name = null) | |||
| => _implementation.unstack(value, name: name); | |||
| public Tensor read(Tensor index, string name = null) | |||
| => _implementation.read(index, name: name); | |||
| public TensorArray write(Tensor index, Tensor value, string name = null) | |||
| => _implementation.write(index, value, name: name); | |||
| public Tensor stack(string name = null) | |||
| => _implementation.stack(name: name); | |||
| } | |||
| } | |||
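| The `read`/`write`/`stack` wrappers added here make the graph-mode `TensorArray` usable end to end. A rough sketch, assuming `using static Tensorflow.Binding;` for `tf` and illustrative values: | |||
| ```csharp | |||
| var ta = new TensorArray(TF_DataType.TF_FLOAT, size: tf.constant(2)); | |||
| ta = ta.write(tf.constant(0), tf.constant(1f));   // element 0 | |||
| ta = ta.write(tf.constant(1), tf.constant(2f));   // element 1 | |||
| var first = ta.read(tf.constant(0));              // Tensor: element 0 | |||
| var all = ta.stack();                             // Tensor: all elements stacked | |||
| ``` | |||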
| @@ -1,10 +1,12 @@ | |||
| using Newtonsoft.Json; | |||
| using NumSharp; | |||
| using NumSharp; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics.CodeAnalysis; | |||
| using System.Linq; | |||
| using System.Runtime.CompilerServices; | |||
| #if SERIALIZABLE | |||
| using Newtonsoft.Json; | |||
| #endif | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| @@ -123,6 +125,9 @@ namespace Tensorflow | |||
| { | |||
| get | |||
| { | |||
| if (!slice.Stop.HasValue) | |||
| slice.Stop = dims.Length - slice.Start + 1; | |||
| if (slice.Start.HasValue == false || slice.Length.HasValue == false) | |||
| throw new ArgumentException("Slice must have Start and Length."); | |||
| @@ -33,6 +33,9 @@ namespace Tensorflow | |||
| public static TF_DataType float32 = TF_DataType.TF_FLOAT; // TF_FLOAT is a 32-bit float | |||
| public static TF_DataType float16 = TF_DataType.TF_HALF; | |||
| public static TF_DataType float64 = TF_DataType.TF_DOUBLE; | |||
| public static TF_DataType complex = TF_DataType.TF_COMPLEX; | |||
| public static TF_DataType complex64 = TF_DataType.TF_COMPLEX64; | |||
| public static TF_DataType complex128 = TF_DataType.TF_COMPLEX128; | |||
| public static TF_DataType variant = TF_DataType.TF_VARIANT; | |||
| public static TF_DataType resource = TF_DataType.TF_RESOURCE; | |||
| @@ -335,5 +335,10 @@ namespace Tensorflow | |||
| return shape; | |||
| } | |||
| public static Tensor shape_tensor(int[] shape) | |||
| { | |||
| return ops.convert_to_tensor(shape, dtype: TF_DataType.TF_INT32, name: "shape"); | |||
| } | |||
| } | |||
| } | |||
| @@ -35,22 +35,29 @@ namespace Tensorflow.Train | |||
| /// for changing these values across different invocations of optimizer | |||
| /// functions. | |||
| /// </remarks> | |||
| private bool _useTensor; | |||
| public GradientDescentOptimizer(float learning_rate, bool use_locking = false, string name = "GradientDescent") | |||
| : base(learning_rate, use_locking, name) | |||
| { | |||
| _lr = learning_rate; | |||
| _useTensor = false; | |||
| } | |||
| public GradientDescentOptimizer(Tensor learning_rate, bool use_locking = false, string name = "GradientDescent") | |||
| : base(learning_rate, use_locking, name) | |||
| { | |||
| _lr_t = learning_rate; | |||
| _useTensor = true; | |||
| } | |||
| public override void _prepare() | |||
| { | |||
| var lr = _call_if_callable(_lr); | |||
| _lr_t = ops.convert_to_tensor(lr, name: "learning_rate"); | |||
| if (!_useTensor) | |||
| { | |||
| var lr = _call_if_callable(_lr); | |||
| _lr_t = ops.convert_to_tensor(lr, name: "learning_rate"); | |||
| } | |||
| } | |||
| } | |||
| } | |||
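| The new `Tensor` overload lets a learning rate computed in the graph feed the optimizer directly; a sketch (the schedule below is a stand-in, not part of this diff): | |||
| ```csharp | |||
| // Graph-computed learning rate: the _useTensor path skips convert_to_tensor. | |||
| var lr = tf.constant(0.1f) * tf.constant(0.99f); // stand-in for a real decay schedule | |||
| var optimizer = new GradientDescentOptimizer(lr); | |||
| ``` | |||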
| @@ -19,6 +19,7 @@ using System.Collections; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using NumSharp; | |||
| using Tensorflow.Operations; | |||
| namespace Tensorflow.Util | |||
| { | |||
| @@ -221,9 +222,14 @@ namespace Tensorflow.Util | |||
| return list; | |||
| } | |||
| public static object[] flatten2(ICanBeFlattened structure) | |||
| => structure.Flatten(); | |||
| public static T[] flatten2<T>(T[] structure) | |||
| => structure; | |||
| private static void _flatten_recursive<T>(T obj, List<T> list) | |||
| { | |||
| switch(obj) | |||
| { | |||
| case IDictionary dict: | |||
| @@ -395,6 +401,10 @@ namespace Tensorflow.Util | |||
| private static int len(IEnumerable<object> x) => x.Count(); | |||
| public static T pack_sequence_as2<T>(T structure, object[] flat_sequence, bool expand_composites = false) | |||
| where T : IPackable<T> | |||
| => structure.Pack(flat_sequence); | |||
| /// <summary> | |||
| /// Returns a given flattened sequence packed into a given structure. | |||
| /// If `structure` is a scalar, `flat_sequence` must be a single-element list; | |||
| @@ -418,7 +428,7 @@ namespace Tensorflow.Util | |||
| /// <returns> `flat_sequence` converted to have the same recursive structure as | |||
| /// `structure`. | |||
| /// </returns> | |||
| public static object pack_sequence_as(object structure, IEnumerable<object> flat_sequence) | |||
| public static object pack_sequence_as(object structure, IEnumerable<object> flat_sequence, bool expand_composites = false) | |||
| { | |||
| List<object> flat = null; | |||
| if (flat_sequence is List<object>) | |||
| @@ -516,6 +526,14 @@ namespace Tensorflow.Util | |||
| return pack_sequence_as(structure, mapped_flat_structure) as Tensor; | |||
| } | |||
| public static Tensor map_structure2<T>(Func<T, Tensor> func, T structure) | |||
| { | |||
| var flat_structure = flatten(structure); | |||
| var mapped_flat_structure = flat_structure.Select(func).ToList(); | |||
| return pack_sequence_as(structure, mapped_flat_structure) as Tensor; | |||
| } | |||
| /// <summary> | |||
| /// Same as map_structure, but with only one structure (no combining of multiple structures) | |||
| /// </summary> | |||
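| The `*2` overloads replace reflective flattening with explicit interfaces; judging from the call sites above, `Flatten()` yields `object[]` and `Pack` rebuilds the structure from such an array. A hypothetical structure wired up for both: | |||
| ```csharp | |||
| // Hypothetical two-tensor structure usable with flatten2 / pack_sequence_as2. | |||
| public class TensorPair : ICanBeFlattened, IPackable<TensorPair> | |||
| { | |||
|     public Tensor First; | |||
|     public Tensor Second; | |||
|     public object[] Flatten() => new object[] { First, Second }; | |||
|     public TensorPair Pack(object[] sequence) | |||
|         => new TensorPair { First = (Tensor)sequence[0], Second = (Tensor)sequence[1] }; | |||
| } | |||
| ``` | |||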
| @@ -133,66 +133,69 @@ namespace Tensorflow | |||
| if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES)) | |||
| collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES); | |||
| ops.init_scope(); | |||
| var values = init_from_fn ? new object[0] : new object[] { initial_value }; | |||
| tf_with(ops.name_scope(name, "Variable", values), scope => | |||
| tf_with(ops.init_scope2(), delegate | |||
| { | |||
| name = scope; | |||
| if (init_from_fn) | |||
| var values = init_from_fn ? new object[0] : new object[] { initial_value }; | |||
| tf_with(ops.name_scope(name, "Variable", values), scope => | |||
| { | |||
| // Use attr_scope and device(None) to simulate the behavior of | |||
| // colocate_with when the variable we want to colocate with doesn't | |||
| // yet exist. | |||
| string true_name = ops.name_from_scope_name(name); | |||
| var attr = new AttrValue | |||
| name = scope; | |||
| if (init_from_fn) | |||
| { | |||
| List = new AttrValue.Types.ListValue() | |||
| }; | |||
| attr.List.S.Add(ByteString.CopyFromUtf8($"loc:{true_name}")); | |||
| tf_with(ops.name_scope("Initializer"), scope2 => | |||
| // Use attr_scope and device(None) to simulate the behavior of | |||
| // colocate_with when the variable we want to colocate with doesn't | |||
| // yet exist. | |||
| string true_name = ops.name_from_scope_name(name); | |||
| var attr = new AttrValue | |||
| { | |||
| List = new AttrValue.Types.ListValue() | |||
| }; | |||
| attr.List.S.Add(ByteString.CopyFromUtf8($"loc:{true_name}")); | |||
| tf_with(ops.name_scope("Initializer"), scope2 => | |||
| { | |||
| _initial_value = (initial_value as Func<Tensor>)(); | |||
| _initial_value = ops.convert_to_tensor(_initial_value, name: "initial_value", dtype: dtype); | |||
| }); | |||
| _variable = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name); | |||
| } | |||
| // Or get the initial value from a Tensor or Python object. | |||
| else | |||
| { | |||
| _initial_value = (initial_value as Func<Tensor>)(); | |||
| _initial_value = ops.convert_to_tensor(_initial_value, name: "initial_value", dtype: dtype); | |||
| }); | |||
| _variable = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name); | |||
| } | |||
| // Or get the initial value from a Tensor or Python object. | |||
| else | |||
| { | |||
| _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value", dtype: dtype); | |||
| _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value", dtype: dtype); | |||
| var shape = _initial_value.shape; | |||
| dtype = _initial_value.dtype; | |||
| _variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), scope); | |||
| } | |||
| var shape = _initial_value.shape; | |||
| dtype = _initial_value.dtype; | |||
| _variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), scope); | |||
| } | |||
| // Manually overrides the variable's shape with the initial value's. | |||
| if (validate_shape) | |||
| { | |||
| var initial_value_shape = _initial_value.TensorShape; | |||
| if (!initial_value_shape.is_fully_defined()) | |||
| throw new ValueError($"initial_value must have a shape specified: {_initial_value}"); | |||
| } | |||
| // Manually overrides the variable's shape with the initial value's. | |||
| if (validate_shape) | |||
| { | |||
| var initial_value_shape = _initial_value.TensorShape; | |||
| if (!initial_value_shape.is_fully_defined()) | |||
| throw new ValueError($"initial_value must have a shape specified: {_initial_value}"); | |||
| } | |||
| // If 'initial_value' makes use of other variables, make sure we don't | |||
| // have an issue if these other variables aren't initialized first by | |||
| // using their initialized_value() method. | |||
| var _initial_value2 = _try_guard_against_uninitialized_dependencies(name, _initial_value); | |||
| // If 'initial_value' makes use of other variables, make sure we don't | |||
| // have an issue if these other variables aren't initialized first by | |||
| // using their initialized_value() method. | |||
| var _initial_value2 = _try_guard_against_uninitialized_dependencies(name, _initial_value); | |||
| _initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op; | |||
| _initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op; | |||
| if (!String.IsNullOrEmpty(caching_device)) | |||
| { | |||
| if (!String.IsNullOrEmpty(caching_device)) | |||
| { | |||
| } | |||
| else | |||
| { | |||
| ops.colocate_with(_initializer_op); | |||
| } | |||
| else | |||
| { | |||
| ops.colocate_with(_initializer_op); | |||
| _snapshot = gen_array_ops.identity(_variable, name: "read"); | |||
| } | |||
| _snapshot = gen_array_ops.identity(_variable, name: "read"); | |||
| } | |||
| ops.add_to_collections(collections, this as VariableV1); | |||
| ops.add_to_collections(collections, this as VariableV1); | |||
| }); | |||
| }); | |||
| } | |||
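| Since `init_scope2` (see the hunk further down) currently just clears control dependencies, wrapping the whole constructor body in it means a variable created inside a `control_dependencies` block no longer gates its initializer on those inputs. A sketch of the assumed intent (`some_op` is a placeholder): | |||
| ```csharp | |||
| // Assumed intent: the initializer is built outside the caller's control-dependency context. | |||
| tf_with(tf.control_dependencies(new ITensorOrOperation[] { some_op }), delegate | |||
| { | |||
|     var v = tf.Variable(1.0f, name: "v"); // init op not gated on some_op | |||
| }); | |||
| ``` | |||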
| @@ -186,12 +186,7 @@ namespace Tensorflow | |||
| /// operations constructed within the context. | |||
| /// </returns> | |||
| public static _ControlDependenciesController control_dependencies(object[] control_inputs) | |||
| { | |||
| return get_default_graph().control_dependencies(control_inputs); | |||
| } | |||
| public static _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) | |||
| => control_dependencies(control_inputs == null ? null : control_inputs.OfType<object>().ToArray()); | |||
| => get_default_graph().control_dependencies(control_inputs); | |||
| /// <summary> | |||
| /// Creates a TF_Operation. | |||
| @@ -212,9 +207,9 @@ namespace Tensorflow | |||
| { | |||
| var op_desc = graph.NewOperation(node_def.Op, node_def.Name); | |||
| //TODO: Implement TF_SetDevice | |||
| //if node_def.device: | |||
| // c_api.TF_SetDevice(op_desc, compat.as_str(node_def.device)) | |||
| if (!string.IsNullOrEmpty(node_def.Device)) | |||
| c_api.TF_SetDevice(op_desc, node_def.Device); | |||
| // Add inputs | |||
| foreach (var op_input in inputs) | |||
| { | |||
| @@ -310,6 +305,22 @@ namespace Tensorflow | |||
| }); | |||
| } | |||
| public static IObjectLife init_scope2() | |||
| { | |||
| // Retrieve the active name scope: entering an `init_scope` preserves | |||
| // the name scope of the current context. | |||
| var default_graph = get_default_graph(); | |||
| var scope = default_graph.get_name_scope(); | |||
| // Names that end with trailing slashes are treated by `name_scope` as absolute. | |||
| if (!String.IsNullOrEmpty(scope) && !scope.EndsWith("/")) | |||
| scope += "/"; | |||
| // inner_device_stack = default_graph._device_function_stack | |||
| // var outer_context = default_graph.as_default; | |||
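| // Device-stack and outer-context handling are not ported yet; clearing | |||
| // control dependencies below only approximates entering an init scope. | |||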
| return ops.control_dependencies(null); | |||
| } | |||
| private static int uid_number = 0; | |||
| /// <summary> | |||
| @@ -508,6 +519,8 @@ namespace Tensorflow | |||
| return null; | |||
| case TensorShape ts: | |||
| return constant_op.constant(ts.dims, dtype: dtype, name: name); | |||
| case int[] dims: | |||
| return constant_op.constant(dims, dtype: dtype, name: name); | |||
| case object[] objects: | |||
| return array_ops._autopacking_conversion_function(objects, dtype: dtype, name: name); | |||
| default: | |||
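| With the new `int[]` case, plain shape arrays convert directly instead of falling through to the default handler: | |||
| ```csharp | |||
| var t = ops.convert_to_tensor(new[] { 1, 2, 3 }, dtype: TF_DataType.TF_INT32); | |||
| ``` | |||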
| @@ -45,7 +45,10 @@ namespace Tensorflow | |||
| public void __enter__() | |||
| { | |||
| _name = _name ?? _default_name; | |||
| if (_name.EndsWith("basic_r_n_n_cell")) | |||
| { | |||
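| // Intentionally empty; presumably a leftover breakpoint anchor for debugging RNN-cell name scoping. | |||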
| } | |||
| Graph g = null; | |||
| if (_values is List<Tensor> vList) | |||
| @@ -93,14 +93,14 @@ namespace Tensorflow | |||
| return new Session().as_default(); | |||
| } | |||
| public Session Session(Graph graph, SessionOptions opts = null) | |||
| public Session Session(Graph graph, ConfigProto config = null) | |||
| { | |||
| return new Session(graph, opts: opts).as_default(); | |||
| return new Session(graph, config: config).as_default(); | |||
| } | |||
| public Session Session(SessionOptions opts) | |||
| public Session Session(ConfigProto config) | |||
| { | |||
| return new Session(null, opts).as_default(); | |||
| return new Session(null, config).as_default(); | |||
| } | |||
| public void __init__() | |||
| @@ -25,9 +25,8 @@ namespace TensorFlowNET.UnitTest | |||
| { | |||
| lock (Locks.ProcessWide) | |||
| { | |||
| var opts = new SessionOptions(); | |||
| opts.SetConfig(new ConfigProto {InterOpParallelismThreads = 4}); | |||
| session_ = new Session(graph, opts, s); | |||
| var config = new ConfigProto {InterOpParallelismThreads = 4}; | |||
| session_ = new Session(graph, config, s); | |||
| } | |||
| } | |||
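| With `SessionOptions` replaced by `ConfigProto` across these overloads, configuration is written directly in proto form; a sketch: | |||
| ```csharp | |||
| var config = new ConfigProto { InterOpParallelismThreads = 4 }; | |||
| using (var sess = tf.Session(config)) | |||
| { | |||
|     // run graph ops here | |||
| } | |||
| ``` | |||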
| @@ -18,10 +18,10 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||
| var i = constant_op.constant(0, name: "i"); | |||
| var c = new Func<Tensor, Tensor>(x => tf.less(x, 10, name: "c")); | |||
| var b = new Func<Tensor, Tensor>(x => tf.add(x, 1, name: "c")); | |||
| var r = control_flow_ops.while_loop(c, b, new[] { i }); | |||
| var r = control_flow_ops.while_loop(c, b, i); | |||
| } | |||
| private void _testWhileContextHelper(int? maximum_iterations = null) | |||
| private void _testWhileContextHelper(int maximum_iterations) | |||
| { | |||
| // TODO: implement missing code dependencies | |||
| using (var sess = this.cached_session()) | |||
| @@ -30,7 +30,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||
| var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, 10, name: "c")); | |||
| var b = new Func<Tensor, Tensor>(x => gen_math_ops.add(x, 1, name: "c")); | |||
| control_flow_ops.while_loop( | |||
| c, b, new[] { i }, maximum_iterations: maximum_iterations); | |||
| c, b, i, maximum_iterations: tf.constant(maximum_iterations)); | |||
| foreach (Operation op in sess.graph.get_operations()) | |||
| { | |||
| var control_flow_context = op._get_control_flow_context(); | |||
| @@ -42,13 +42,6 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||
| } | |||
| } | |||
| [Ignore("TODO")] | |||
| [TestMethod] | |||
| public void testWhileContext() | |||
| { | |||
| _testWhileContextHelper(); | |||
| } | |||
| [Ignore("TODO")] | |||
| [TestMethod] | |||
| public void testWhileContextWithMaximumIterations() | |||