Browse Source

Merge branch 'master' into v0.20-tensorflow2.0

tags/v0.20
Oceania2018 6 years ago
parent
commit
f1a80aac1c
80 changed files with 3095 additions and 1030 deletions
  1. +8
    -1
      README.md
  2. +2
    -2
      src/TensorFlowNET.Core/APIs/tf.control_flow.cs
  3. +4
    -4
      src/TensorFlowNET.Core/APIs/tf.layers.cs
  4. +2
    -2
      src/TensorFlowNET.Core/APIs/tf.nn.cs
  5. +24
    -0
      src/TensorFlowNET.Core/Binding.Util.cs
  6. +32
    -0
      src/TensorFlowNET.Core/Device/c_api.device.cs
  7. +71
    -67
      src/TensorFlowNET.Core/Gradients/control_flow_grad.cs
  8. +142
    -50
      src/TensorFlowNET.Core/Gradients/gradients_util.cs
  9. +1
    -0
      src/TensorFlowNET.Core/Graphs/Graph.Control.cs
  10. +18
    -13
      src/TensorFlowNET.Core/Graphs/Graph.cs
  11. +4
    -1
      src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs
  12. +11
    -0
      src/TensorFlowNET.Core/Interfaces/IFlatten.cs
  13. +0
    -0
      src/TensorFlowNET.Core/Interfaces/IObjectLife.cs
  14. +11
    -0
      src/TensorFlowNET.Core/Interfaces/IPackable.cs
  15. +0
    -0
      src/TensorFlowNET.Core/Interfaces/ITensorOrOperation.cs
  16. +7
    -6
      src/TensorFlowNET.Core/Interfaces/ITensorOrTensorArray.cs
  17. +2
    -2
      src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
  18. +3
    -3
      src/TensorFlowNET.Core/Keras/Layers/Conv.cs
  19. +3
    -3
      src/TensorFlowNET.Core/Keras/Layers/Dense.cs
  20. +2
    -2
      src/TensorFlowNET.Core/Keras/Layers/Embedding.cs
  21. +18
    -9
      src/TensorFlowNET.Core/Keras/Layers/Layer.cs
  22. +2
    -2
      src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs
  23. +15
    -5
      src/TensorFlowNET.Core/Layers/Layer.cs
  24. +33
    -2
      src/TensorFlowNET.Core/Operations/BasicRNNCell.cs
  25. +73
    -25
      src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
  26. +197
    -162
      src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs
  27. +240
    -317
      src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs
  28. +43
    -0
      src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs
  29. +36
    -0
      src/TensorFlowNET.Core/Operations/ControlFlows/MergeOutput.cs
  30. +445
    -32
      src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
  31. +1
    -0
      src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs
  32. +2
    -2
      src/TensorFlowNET.Core/Operations/LayerRNNCell.cs
  33. +49
    -0
      src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs
  34. +21
    -1
      src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
  35. +76
    -15
      src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
  36. +31
    -2
      src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs
  37. +14
    -0
      src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
  38. +12
    -5
      src/TensorFlowNET.Core/Operations/Operation.Control.cs
  39. +8
    -6
      src/TensorFlowNET.Core/Operations/Operation.Input.cs
  40. +13
    -8
      src/TensorFlowNET.Core/Operations/Operation.Instance.cs
  41. +3
    -1
      src/TensorFlowNET.Core/Operations/Operation.Output.cs
  42. +10
    -8
      src/TensorFlowNET.Core/Operations/Operation.cs
  43. +2
    -2
      src/TensorFlowNET.Core/Operations/RNNCell.cs
  44. +78
    -7
      src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs
  45. +126
    -34
      src/TensorFlowNET.Core/Operations/control_flow_ops.cs
  46. +133
    -0
      src/TensorFlowNET.Core/Operations/control_flow_util.py.cs
  47. +1
    -1
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  48. +15
    -4
      src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs
  49. +100
    -5
      src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs
  50. +4
    -2
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  51. +19
    -0
      src/TensorFlowNET.Core/Operations/math_ops.cs
  52. +8
    -0
      src/TensorFlowNET.Core/Operations/nn_ops.cs
  53. +1
    -1
      src/TensorFlowNET.Core/Operations/random_ops.py.cs
  54. +52
    -0
      src/TensorFlowNET.Core/Operations/tensor_array_ops.cs
  55. +645
    -78
      src/TensorFlowNET.Core/Protobuf/Config.cs
  56. +8
    -7
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  57. +1
    -1
      src/TensorFlowNET.Core/Sessions/Session.cs
  58. +8
    -5
      src/TensorFlowNET.Core/Sessions/SessionOptions.cs
  59. +0
    -10
      src/TensorFlowNET.Core/Sessions/TF_SessionOptions.cs
  60. +2
    -2
      src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs
  61. +1
    -1
      src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
  62. +2
    -2
      src/TensorFlowNET.Core/Sessions/_FetchMapper.cs
  63. +4
    -1
      src/TensorFlowNET.Core/Sessions/c_api.session.cs
  64. +11
    -22
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  65. +2
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  66. +15
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.Flatten.cs
  67. +15
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.Pack.cs
  68. +10
    -3
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  69. +14
    -4
      src/TensorFlowNET.Core/Tensors/TensorArray.cs
  70. +7
    -2
      src/TensorFlowNET.Core/Tensors/TensorShape.cs
  71. +3
    -0
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  72. +5
    -0
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  73. +9
    -2
      src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs
  74. +20
    -2
      src/TensorFlowNET.Core/Util/nest.py.cs
  75. +51
    -48
      src/TensorFlowNET.Core/Variables/RefVariable.cs
  76. +22
    -9
      src/TensorFlowNET.Core/ops.cs
  77. +3
    -0
      src/TensorFlowNET.Core/ops.name_scope.cs
  78. +4
    -4
      src/TensorFlowNET.Core/tensorflow.cs
  79. +2
    -3
      test/TensorFlowNET.UnitTest/CSession.cs
  80. +3
    -10
      test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs

+ 8
- 1
README.md View File

@@ -9,7 +9,7 @@
[![Documentation Status](https://readthedocs.org/projects/tensorflownet/badge/?version=latest)](https://tensorflownet.readthedocs.io/en/latest/?badge=latest) [![Documentation Status](https://readthedocs.org/projects/tensorflownet/badge/?version=latest)](https://tensorflownet.readthedocs.io/en/latest/?badge=latest)
[![Badge](https://img.shields.io/badge/link-996.icu-red.svg)](https://996.icu/#/en_US) [![Badge](https://img.shields.io/badge/link-996.icu-red.svg)](https://996.icu/#/en_US)


TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp). <a href="http://scisharpstack.org"><img src="https://github.com/SciSharp/SciSharp/blob/master/art/scisharp_badge.png" width="200" height="200" align="right" /></a>
TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp).




![tensors_flowing](docs/assets/tensors_flowing.gif) ![tensors_flowing](docs/assets/tensors_flowing.gif)
@@ -26,6 +26,13 @@ In comparison to other projects, like for instance TensorFlowSharp which only pr


### How to use ### How to use


| TensorFlow | tf 1.13 | tf 1.14 | tf 1.15 | tf 2.0 |
| ----------- | ------- | ------- | ------- | ------ |
| tf.net 0.12 | | x | | |
| tf.net 0.11 | x | x | | |
| tf.net 0.10 | x | x | | |
| tf.net 0.9 | x | | | |

Install TF.NET and TensorFlow binary through NuGet. Install TF.NET and TensorFlow binary through NuGet.
```sh ```sh
### install tensorflow C# binding ### install tensorflow C# binding


+ 2
- 2
src/TensorFlowNET.Core/APIs/tf.control_flow.cs View File

@@ -37,7 +37,7 @@ namespace Tensorflow
public Operation group<T>(T[] inputs, string name = null) where T : ITensorOrOperation public Operation group<T>(T[] inputs, string name = null) where T : ITensorOrOperation
=> control_flow_ops.group(inputs, name: name); => control_flow_ops.group(inputs, name: name);


public Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars,
/*public Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars,
TensorShape shape_invariants = null, TensorShape shape_invariants = null,
int parallel_iterations = 10, int parallel_iterations = 10,
bool back_prop = true, bool back_prop = true,
@@ -52,7 +52,7 @@ namespace Tensorflow
swap_memory: swap_memory, swap_memory: swap_memory,
name: name, name: name,
maximum_iterations: maximum_iterations, maximum_iterations: maximum_iterations,
return_same_structure: return_same_structure);
return_same_structure: return_same_structure);*/


public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs)
=> ops.control_dependencies(control_inputs); => ops.control_dependencies(control_inputs);


+ 4
- 4
src/TensorFlowNET.Core/APIs/tf.layers.cs View File

@@ -63,7 +63,7 @@ namespace Tensorflow
trainable: trainable, trainable: trainable,
name: name); name: name);


return layer.apply(inputs);
return layer.apply(inputs).Item1;
} }


/// <summary> /// <summary>
@@ -117,7 +117,7 @@ namespace Tensorflow
trainable: trainable, trainable: trainable,
name: name); name: name);


return layer.apply(inputs, training: training);
return layer.apply(inputs, training: training).Item1;
} }


/// <summary> /// <summary>
@@ -143,7 +143,7 @@ namespace Tensorflow
data_format: data_format, data_format: data_format,
name: name); name: name);


return layer.apply(inputs);
return layer.apply(inputs).Item1;
} }


/// <summary> /// <summary>
@@ -179,7 +179,7 @@ namespace Tensorflow
kernel_initializer: kernel_initializer, kernel_initializer: kernel_initializer,
trainable: trainable); trainable: trainable);


return layer.apply(inputs);
return layer.apply(inputs).Item1;
} }


/// <summary> /// <summary>


+ 2
- 2
src/TensorFlowNET.Core/APIs/tf.nn.cs View File

@@ -76,7 +76,7 @@ namespace Tensorflow
/// <param name="swap_memory"></param> /// <param name="swap_memory"></param>
/// <param name="time_major"></param> /// <param name="time_major"></param>
/// <returns>A pair (outputs, state)</returns> /// <returns>A pair (outputs, state)</returns>
public (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs,
public (Tensor, Tensor) dynamic_rnn(RnnCell cell, Tensor inputs,
Tensor sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid, Tensor sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid,
int? parallel_iterations = null, bool swap_memory = false, bool time_major = false) int? parallel_iterations = null, bool swap_memory = false, bool time_major = false)
=> rnn.dynamic_rnn(cell, inputs, sequence_length: sequence_length, dtype: dtype, => rnn.dynamic_rnn(cell, inputs, sequence_length: sequence_length, dtype: dtype,
@@ -134,7 +134,7 @@ namespace Tensorflow
=> nn_ops.max_pool(value, ksize, strides, padding, data_format: data_format, name: name); => nn_ops.max_pool(value, ksize, strides, padding, data_format: data_format, name: name);


public Tensor in_top_k(Tensor predictions, Tensor targets, int k, string name = "InTopK") public Tensor in_top_k(Tensor predictions, Tensor targets, int k, string name = "InTopK")
=> gen_ops.in_top_k(predictions, targets, k, name);
=> nn_ops.in_top_k(predictions, targets, k, name);


public Tensor[] top_k(Tensor input, int k = 1, bool sorted = true, string name = null) public Tensor[] top_k(Tensor input, int k = 1, bool sorted = true, string name = null)
=> gen_nn_ops.top_kv2(input, k: k, sorted: sorted, name: name); => gen_nn_ops.top_kv2(input, k: k, sorted: sorted, name: name);


+ 24
- 0
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -30,6 +30,20 @@ namespace Tensorflow
/// </summary> /// </summary>
public static partial class Binding public static partial class Binding
{ {
public static T2 get<T1, T2>(this Dictionary<T1, T2> dict, T1 key)
=> key == null ?
default(T2) :
(dict.ContainsKey(key) ? dict[key] : default(T2));

public static void add<T>(this IList<T> list, T element)
=> list.Add(element);

public static void append<T>(this IList<T> list, T element)
=> list.Add(element);

public static void extend<T>(this List<T> list, IEnumerable<T> elements)
=> list.AddRange(elements);

private static string _tostring(object obj) private static string _tostring(object obj)
{ {
switch (obj) switch (obj)
@@ -81,6 +95,9 @@ namespace Tensorflow
throw new NotImplementedException("len() not implemented for type: " + a.GetType()); throw new NotImplementedException("len() not implemented for type: " + a.GetType());
} }


public static T[] list<T>(IEnumerable<T> list)
=> list.ToArray();

public static IEnumerable<int> range(int end) public static IEnumerable<int> range(int end)
{ {
return Enumerable.Range(0, end); return Enumerable.Range(0, end);
@@ -165,6 +182,12 @@ namespace Tensorflow
yield return (t1[i], t2[i]); yield return (t1[i], t2[i]);
} }


public static IEnumerable<(T1, T2, T3)> zip<T1, T2, T3>(IList<T1> t1, IList<T2> t2, IList<T3> t3)
{
for (int i = 0; i < t1.Count; i++)
yield return (t1[i], t2[i], t3[i]);
}

public static IEnumerable<(T1, T2)> zip<T1, T2>(NDArray t1, NDArray t2) public static IEnumerable<(T1, T2)> zip<T1, T2>(NDArray t1, NDArray t2)
where T1: unmanaged where T1: unmanaged
where T2: unmanaged where T2: unmanaged
@@ -203,6 +226,7 @@ namespace Tensorflow
yield return (i, values[i]); yield return (i, values[i]);
} }


[DebuggerStepThrough]
public static Dictionary<string, object> ConvertToDict(object dyn) public static Dictionary<string, object> ConvertToDict(object dyn)
{ {
var dictionary = new Dictionary<string, object>(); var dictionary = new Dictionary<string, object>();


+ 32
- 0
src/TensorFlowNET.Core/Device/c_api.device.cs View File

@@ -0,0 +1,32 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Runtime.InteropServices;

namespace Tensorflow
{
public partial class c_api
{
/// <summary>
/// Specify the device for `desc`. Defaults to empty, meaning unconstrained.
/// </summary>
/// <param name="desc"></param>
/// <param name="device"></param>
[DllImport(TensorFlowLibName)]
public static extern void TF_SetDevice(IntPtr desc, string device);
}
}

+ 71
- 67
src/TensorFlowNET.Core/Gradients/control_flow_grad.cs View File

@@ -45,7 +45,24 @@ namespace Tensorflow.Gradients
switch (op_ctxt) switch (op_ctxt)
{ {
case WhileContext cwhile: case WhileContext cwhile:
throw new NotImplementedException("_SwitchGrad WhileContext");
{
var merge_grad = grad_ctxt.grad_state.switch_map.get(op);
if (merge_grad != null)
{
if (grads[1] != null)
control_flow_ops._AddNextAndBackEdge(merge_grad, grads[1],
enforce_shape_invariant: false);
return new Tensor[] { null, null };
}
else if (grads[0] != null)
{
merge_grad = merge(new[] { grads[0], grads[0] }, name: "b_switch")[0];
grad_ctxt.grad_state.switch_map[op] = merge_grad;
return new Tensor[] { merge_grad, null };
}
else
return new Tensor[] { null, null };
}
case CondContext ccond: case CondContext ccond:
{ {
var zero_grad = grads[1 - op_ctxt.branch]; var zero_grad = grads[1 - op_ctxt.branch];
@@ -74,7 +91,7 @@ namespace Tensorflow.Gradients
/// <param name="inputs"></param> /// <param name="inputs"></param>
/// <param name="name"></param> /// <param name="name"></param>
/// <returns></returns> /// <returns></returns>
internal static Tensor[] merge(Tensor[] inputs, string name = null)
internal static MergeOutput merge(Tensor[] inputs, string name = null)
{ {
return tf_with(ops.name_scope(name, "Merge", inputs), scope => return tf_with(ops.name_scope(name, "Merge", inputs), scope =>
{ {
@@ -146,7 +163,7 @@ namespace Tensorflow.Gradients
} }


[RegisterGradient("RefMerge")] [RegisterGradient("RefMerge")]
public Tensor[] _RefMergeGrad(Operation op, Tensor[] grads)
public static Tensor[] _RefMergeGrad(Operation op, Tensor[] grads)
{ {
return _MergeGrad(op, grads); return _MergeGrad(op, grads);
} }
@@ -155,43 +172,32 @@ namespace Tensorflow.Gradients
/// Gradients for an exit op are calculated using an Enter op. /// Gradients for an exit op are calculated using an Enter op.
/// </summary> /// </summary>
[RegisterGradient("Exit")] [RegisterGradient("Exit")]
public Tensor[] _ExitGrad(Operation op, Tensor[] grads)
public static Tensor[] _ExitGrad(Operation op, Tensor[] grads)
{ {
throw new NotImplementedException("_ExitGrad");
// graph = ops.get_default_graph()
//# pylint: disable=protected-access
// op_ctxt = op._get_control_flow_context()
// grad_ctxt = graph._get_control_flow_context()
// # pylint: enable=protected-access
// if not grad_ctxt.back_prop:
// # The flag `back_prop` is set by users to suppress gradient
// # computation for this loop. If the attribute `back_prop` is false,
// # no gradient computation.
// return None
var grad = grads[0];
var graph = ops.get_default_graph();
var op_ctxt = op._get_control_flow_context();
var grad_ctxt = graph._get_control_flow_context() as WhileContext;
// The flag `back_prop` is set by users to suppress gradient
// computation for this loop. If the attribute `back_prop` is false,
// no gradient computation.
if (!grad_ctxt.back_prop)
return null;


// if op_ctxt.grad_state:
// raise TypeError("Second-order gradient for while loops not supported.")
if (op_ctxt.grad_state != null)
throw new TypeError("Second-order gradient for while loops not supported.");


// if isinstance(grad, ops.Tensor) :
// grad_ctxt.AddName(grad.name)
// else:
// if not isinstance(grad, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
// raise TypeError("Type %s not supported" % type(grad))
// grad_ctxt.AddName(grad.values.name)
// grad_ctxt.AddName(grad.indices.name)
// dense_shape = grad.dense_shape
// if dense_shape is not None:
// grad_ctxt.AddName(dense_shape.name)
// grad_ctxt.Enter()
// # pylint: disable=protected-access
// result = control_flow_ops._Enter(
// grad, grad_ctxt.name, is_constant=False,
// parallel_iterations=grad_ctxt.parallel_iterations,
// name="b_exit")
// # pylint: enable=protected-access
// grad_ctxt.loop_enters.append(result)
// grad_ctxt.Exit()
// return result
grad_ctxt.AddName(grad.name);

grad_ctxt.Enter();
var result = control_flow_ops._Enter(
grad, grad_ctxt.name, is_constant: false,
parallel_iterations: grad_ctxt.parallel_iterations,
name: "b_exit");

grad_ctxt.loop_enters.append(result);
grad_ctxt.Exit();
return new[] { result };
} }


/// <summary> /// <summary>
@@ -200,15 +206,15 @@ namespace Tensorflow.Gradients
/// Note that the backprop next_iteration is added in switch grad. /// Note that the backprop next_iteration is added in switch grad.
/// </summary> /// </summary>
[RegisterGradient("NextIteration")] [RegisterGradient("NextIteration")]
public Tensor[] _NextIterationGrad(object _, Tensor[] grad)
public static Tensor[] _NextIterationGrad(Operation op, Tensor[] grads)
{ {
return grad;
return grads;
} }


[RegisterGradient("RefNextIteration")] [RegisterGradient("RefNextIteration")]
public Tensor[] _RefNextIterationGrad(object _, Tensor[] grad)
public static Tensor[] _RefNextIterationGrad(Operation op, Tensor[] grads)
{ {
return grad;
return grads;
} }


/// <summary> /// <summary>
@@ -218,33 +224,31 @@ namespace Tensorflow.Gradients
/// For loop invariants, we need to add an accumulator loop. /// For loop invariants, we need to add an accumulator loop.
/// </summary> /// </summary>
[RegisterGradient("Enter")] [RegisterGradient("Enter")]
public Tensor[] _EnterGrad(Tensor op, Tensor[] grad)
public static Tensor[] _EnterGrad(Operation op, Tensor[] grads)
{ {
throw new NotImplementedException("_EnterGrad");
// graph = ops.get_default_graph()
//# pylint: disable=protected-access
// grad_ctxt = graph._get_control_flow_context()
// # pylint: enable=protected-access
// if not grad_ctxt.back_prop:
// # Skip gradient computation, if the attribute `back_prop` is false.
// return grad
// if grad_ctxt.grad_state is None:
// # Pass the gradient through if we are not in a gradient while context.
// return grad
// if op.get_attr("is_constant"):
// # Add a gradient accumulator for each loop invariant.
// if isinstance(grad, ops.Tensor) :
// result = grad_ctxt.AddBackpropAccumulator(op, grad)
// elif isinstance(grad, ops.IndexedSlices) :
// result = grad_ctxt.AddBackpropIndexedSlicesAccumulator(op, grad)
// else:
// # TODO(yuanbyu, lukasr): Add support for SparseTensor.
// raise TypeError("Type %s not supported" % type(grad))
// else:
// result = exit(grad)
// grad_ctxt.loop_exits.append(result)
// grad_ctxt.ExitResult([result])
// return result
Tensor result = null;
var grad = grads[0];
var graph = ops.get_default_graph();
var grad_ctxt = graph._get_control_flow_context() as WhileContext;
if (!grad_ctxt.back_prop)
// Skip gradient computation, if the attribute `back_prop` is false.
return grads;
if (grad_ctxt.grad_state == null)
// Pass the gradient through if we are not in a gradient while context.
return grads;
if (op.get_attr<bool>("is_constant"))
{
// Add a gradient accumulator for each loop invariant.
result = grad_ctxt.AddBackpropAccumulator(op, grad);
}
else
{
result = control_flow_ops.exit(grad);
grad_ctxt.loop_exits.append(result);
grad_ctxt.ExitResult(new[] { result });
}

return new Tensor[] { result };
} }






+ 142
- 50
src/TensorFlowNET.Core/Gradients/gradients_util.cs View File

@@ -17,6 +17,7 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using Tensorflow.Operations.ControlFlows;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
@@ -54,6 +55,9 @@ namespace Tensorflow
* is more than one. * is more than one.
**/ **/
var grads = new Dictionary<string, List<List<Tensor>>>(); var grads = new Dictionary<string, List<List<Tensor>>>();
Operation[] reachable_to_ops = null;
ControlFlowState loop_state = null;
Dictionary<string, int> pending_count = null;


tf_with(ops.name_scope(name, "gradients", tf_with(ops.name_scope(name, "gradients",
values: ys.Concat(xs).Concat(stop_gradients).Concat(grad_ys)), scope => values: ys.Concat(xs).Concat(stop_gradients).Concat(grad_ys)), scope =>
@@ -80,8 +84,9 @@ namespace Tensorflow
var to_ops = ys.Select(x => x.op).ToList(); var to_ops = ys.Select(x => x.op).ToList();
var from_ops = xs.Select(x => x.op).ToList(); var from_ops = xs.Select(x => x.op).ToList();
var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList(); var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList();
var (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs);
(reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs);


// Add the initial gradients for the ys.
foreach (var (y, grad_y) in zip(ys, grad_ys)) foreach (var (y, grad_y) in zip(ys, grad_ys))
_SetGrad(grads, y, grad_y); _SetGrad(grads, y, grad_y);


@@ -103,6 +108,16 @@ namespace Tensorflow
} }
} }


if(loop_state != null)
{
var loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set);
foreach(var y in loop_exits)
{
//if(IsTrainable(y))
throw new NotImplementedException("");
}
}

var stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs); var stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs);
while (queue.Count > 0) while (queue.Count > 0)
{ {
@@ -110,45 +125,48 @@ namespace Tensorflow
var op = queue.Dequeue(); var op = queue.Dequeue();


_maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops); _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
//if (loop_state != null)
//loop_state.EnterGradWhileContext(op, before: true);
var out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, aggregation_method);

Tensor[] in_grads = null;
var is_partitioned_call = _IsPartitionedCall(op);
var is_func_call = false;
var has_out_grads = out_grads.Exists(x => x != null);
if (has_out_grads && !stop_ops.Contains(op))
{ {
// A grad_fn must be defined, either as a function or as None
// for ops that do not have gradients.
if (loop_state != null)
loop_state.EnterGradWhileContext(op, before: true);
var out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, aggregation_method);
if (loop_state != null)
loop_state.ExitGradWhileContext(op, before: true);


Tensor[] in_grads = null;
Func<Operation, Tensor[], Tensor[]> grad_fn = null; Func<Operation, Tensor[], Tensor[]> grad_fn = null;
try
{
grad_fn = ops.get_gradient_function(op);
}
catch (LookupError)
var is_partitioned_call = _IsPartitionedCall(op);
var is_func_call = false;
var has_out_grads = out_grads.Exists(x => x != null);
if (has_out_grads && !stop_ops.Contains(op))
{ {
if (is_func_call)
// A grad_fn must be defined, either as a function or as None
// for ops that do not have gradients.
try
{ {
if (is_partitioned_call)
grad_fn = ops.get_gradient_function(op);
}
catch (LookupError)
{
if (is_func_call)
{ {
if (is_partitioned_call)
{

}
else
{


}
} }
else else
{ {

throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})");
} }
} }
else
{
throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})");
}
} }


// if (loop_state)
//loop_state.EnterGradWhileContext(op, before: false);
if (loop_state != null)
loop_state.EnterGradWhileContext(op, before: false);


if ((is_func_call || grad_fn != null) && has_out_grads) if ((is_func_call || grad_fn != null) && has_out_grads)
{ {
@@ -164,7 +182,7 @@ namespace Tensorflow
// will use SymbolicGradient get a zero gradient. Gradient // will use SymbolicGradient get a zero gradient. Gradient
// functions should ignore the gradient for other outputs. // functions should ignore the gradient for other outputs.
if (loop_state != null) if (loop_state != null)
;
out_grads[i] = new List<Tensor> { loop_state.ZerosLike(op, i) };
else else
out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) }; out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) };
} }
@@ -198,33 +216,34 @@ namespace Tensorflow
// just propagate a list of None backwards. // just propagate a list of None backwards.
in_grads = new Tensor[_NonEagerInputs(op, xs).Count()]; in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
} }
}
else
{
in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
}


var inputs = _NonEagerInputs(op, xs).ToList();
foreach (var (t_in, in_grad) in zip(inputs, in_grads))
{
if (in_grad != null)
var inputs = _NonEagerInputs(op, xs).ToList();
foreach (var (t_in, in_grad) in zip(inputs, in_grads))
{ {
if (!(in_grad is null) &&
in_grad.Tag == null && // maybe a IndexedSlice
t_in.dtype != TF_DataType.TF_RESOURCE)
if (in_grad != null)
{ {
in_grad.set_shape(t_in.TensorShape);
}
if (!(in_grad is null) &&
in_grad.Tag == null && // maybe a IndexedSlice
t_in.dtype != TF_DataType.TF_RESOURCE)
{
in_grad.set_shape(t_in.TensorShape);
}


_SetGrad(grads, t_in, in_grad);
_SetGrad(grads, t_in, in_grad);
}
} }
}


if (loop_state != null)
loop_state.ExitGradWhileContext(op, before: false);
}
// Update pending count for the inputs of op and enqueue ready ops. // Update pending count for the inputs of op and enqueue ready ops.
_UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state, xs); _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state, xs);
} }
}); });


if (loop_state != null)
loop_state.PostProcessing();
return xs.Select(x => _GetGrad(grads, x)).ToArray(); return xs.Select(x => _GetGrad(grads, x)).ToArray();
} }


@@ -275,7 +294,7 @@ namespace Tensorflow
/// <param name="colocate_gradients_with_ops"></param> /// <param name="colocate_gradients_with_ops"></param>
/// <param name="func_graphs"></param> /// <param name="func_graphs"></param>
/// <param name="xs"></param> /// <param name="xs"></param>
private static (Operation[], Dictionary<string, int>, object) _PendingCount(List<Operation> to_ops, List<Operation> from_ops, bool colocate_gradients_with_ops, List<object> func_graphs, Tensor[] xs)
private static (Operation[], Dictionary<string, int>, ControlFlowState) _PendingCount(List<Operation> to_ops, List<Operation> from_ops, bool colocate_gradients_with_ops, List<object> func_graphs, Tensor[] xs)
{ {
// Mark reachable ops from from_ops. // Mark reachable ops from from_ops.
var reached_ops = new List<Operation>(); var reached_ops = new List<Operation>();
@@ -308,6 +327,7 @@ namespace Tensorflow
// 'loop_state' is None if there are no while loops. // 'loop_state' is None if there are no while loops.
var loop_state = control_flow_ops.MaybeCreateControlFlowState(between_op_list, between_ops, colocate_gradients_with_ops); var loop_state = control_flow_ops.MaybeCreateControlFlowState(between_op_list, between_ops, colocate_gradients_with_ops);


// Initialize pending count for between ops.
var pending_count = new Dictionary<string, int>(); var pending_count = new Dictionary<string, int>();
foreach (var op in between_op_list) foreach (var op in between_op_list)
{ {
@@ -342,7 +362,11 @@ namespace Tensorflow
grads[op.name] = op_grads; grads[op.name] = op_grads;
} }
var t_grads = op_grads[t.value_index]; var t_grads = op_grads[t.value_index];
t_grads.Add(grad);
if (t_grads.Count > 0 &&
control_flow_util.IsLoopSwitch(op))
op_grads[t.value_index][0] = grad;
else
t_grads.Add(grad);
} }


private static IEnumerable<Tensor> _NonEagerInputs(Operation op, Tensor[] xs) private static IEnumerable<Tensor> _NonEagerInputs(Operation op, Tensor[] xs)
@@ -351,7 +375,8 @@ namespace Tensorflow
yield return op.inputs[i]; yield return op.inputs[i];
} }


private static List<List<Tensor>> _AggregatedGrads(Dictionary<string, List<List<Tensor>>> grads, Operation op, string gradient_uid, object loop_state, int aggregation_method = 0)
private static List<List<Tensor>> _AggregatedGrads(Dictionary<string, List<List<Tensor>>> grads, Operation op, string gradient_uid,
ControlFlowState loop_state, int aggregation_method = 0)
{ {
var out_grads = _GetGrads(grads, op); var out_grads = _GetGrads(grads, op);


@@ -359,7 +384,10 @@ namespace Tensorflow
{ {
if (loop_state != null) if (loop_state != null)
{ {

if (out_grads.Count > 1 &&
out_grads[1].Count > 0 &&
control_flow_util.IsLoopSwitch(op))
continue;
} }


// Aggregate multiple gradients, and convert [] to None. // Aggregate multiple gradients, and convert [] to None.
@@ -550,7 +578,7 @@ namespace Tensorflow
Operation op, Operation op,
Queue<Operation> queue, Queue<Operation> queue,
Dictionary<string, int> pending_count, Dictionary<string, int> pending_count,
object loop_state,
ControlFlowState loop_state,
Tensor[] xs) Tensor[] xs)
{ {
foreach (var x in _NonEagerInputs(op, xs)) foreach (var x in _NonEagerInputs(op, xs))
@@ -564,14 +592,49 @@ namespace Tensorflow


if (loop_state != null && !ready) if (loop_state != null && !ready)
{ {
ready = pending_count[x.op.name] > 0 && control_flow_util.IsLoopSwitch(x.op);
} }


if (ready) if (ready)
{ {
// if x is an exit without real gradient, defer processing them.
if (control_flow_util.IsLoopExit(x.op)) if (control_flow_util.IsLoopExit(x.op))
{ {

var grad_state = loop_state.GetGradState(x.op, before: false);
grad_state.deferred_exits.append(x);
grad_state.pending_exits_count -= 1;
// We now have all the exits so process them.
if (grad_state.pending_exits_count == 0)
{
var has_not_none_grad = false;
foreach(var y in grad_state.deferred_exits)
{
if (_HasAnyNotNoneGrads(grads, y.op))
{
has_not_none_grad = true;
queue.Enqueue(y.op);
}
else
grad_state.unused_exits.append(y);
}
if (has_not_none_grad)
{
// For an unused exit, if it has trainable outputs, backprop
// a zero gradient. Otherwise, just ignore it.
foreach (var y in grad_state.unused_exits)
{
if (IsTrainable(y))
_SetGrad(grads, y, loop_state.ZerosLikeForExit(y));
queue.Enqueue(y.op);
}
}
else
{
// All exits are "unused" so use None as gradient.
foreach (var y in grad_state.unused_exits)
queue.Enqueue(y.op);
}
}
} }
else else
{ {
@@ -581,6 +644,32 @@ namespace Tensorflow
} }
} }


private static bool IsTrainable(Tensor tensor)
{
var dtype = tensor.dtype.as_base_dtype();
return new TF_DataType[] { dtypes.float16, dtypes.float32, dtypes.float64,
dtypes.complex64, dtypes.complex128,
dtypes.resource, dtypes.variant}.Contains(dtype);
}

/// <summary>
/// Return true if op has real gradient.
/// </summary>
/// <param name="grads"></param>
/// <param name="op"></param>
/// <returns></returns>
private static bool _HasAnyNotNoneGrads(Dictionary<string, List<List<Tensor>>> grads, Operation op)
{
var out_grads = _GetGrads(grads, op);
foreach(var out_grad in out_grads)
{
if (out_grad.Exists(g => g != null))
return true;
}
return false;
}


private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func<Operation, Tensor[], Tensor[]> grad_fn) private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func<Operation, Tensor[], Tensor[]> grad_fn)
{ {
scope = scope.EndsWith("/") ? scope.Substring(0, scope.Length - 1) : scope; scope = scope.EndsWith("/") ? scope.Substring(0, scope.Length - 1) : scope;
@@ -589,6 +678,9 @@ namespace Tensorflow


private static void _VerifyGeneratedGradients(Tensor[] grads, Operation op) private static void _VerifyGeneratedGradients(Tensor[] grads, Operation op)
{ {
if (op.type == "While" || op.type == "StatelessWhile")
return;

if (grads.Count() != op.inputs._inputs.Count()) if (grads.Count() != op.inputs._inputs.Count())
throw new ValueError($"Num gradients {grads.Length} generated for op {op.node_def} do not match num " + throw new ValueError($"Num gradients {grads.Length} generated for op {op.node_def} do not match num " +
$"inputs {op.inputs._inputs.Count()}"); $"inputs {op.inputs._inputs.Count()}");


+ 1
- 0
src/TensorFlowNET.Core/Graphs/Graph.Control.cs View File

@@ -18,6 +18,7 @@ using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
using System.Linq; using System.Linq;
using Tensorflow.Operations; using Tensorflow.Operations;
using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
{ {


+ 18
- 13
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -75,7 +75,10 @@ namespace Tensorflow
/// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices.
/// </summary> /// </summary>
/// <remarks>https://www.tensorflow.org/guide/graphs <br></br>https://www.tensorflow.org/api_docs/python/tf/Graph</remarks> /// <remarks>https://www.tensorflow.org/guide/graphs <br></br>https://www.tensorflow.org/api_docs/python/tf/Graph</remarks>
public partial class Graph : DisposableObject//, IEnumerable<Operation>
public partial class Graph : DisposableObject,
#if !SERIALIZABLE
IEnumerable<Operation>
#endif
{ {
private Dictionary<int, ITensorOrOperation> _nodes_by_id; private Dictionary<int, ITensorOrOperation> _nodes_by_id;
public Dictionary<string, ITensorOrOperation> _nodes_by_name; public Dictionary<string, ITensorOrOperation> _nodes_by_name;
@@ -259,15 +262,11 @@ namespace Tensorflow


if (string.IsNullOrEmpty(name)) if (string.IsNullOrEmpty(name))
name = op_type; name = op_type;

// If a names ends with a '/' it is a "name scope" and we use it as-is, // If a names ends with a '/' it is a "name scope" and we use it as-is,
// after removing the trailing '/'. // after removing the trailing '/'.
name = name.EndsWith("/") ? ops.name_from_scope_name(name) : unique_name(name); name = name.EndsWith("/") ? ops.name_from_scope_name(name) : unique_name(name);
var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs);
if (name.Contains("define_loss/bigger_box_loss/mul_13"))
{
}


var input_ops = inputs.Select(x => x.op).ToArray(); var input_ops = inputs.Select(x => x.op).ToArray();
var control_inputs = _control_dependencies_for_inputs(input_ops); var control_inputs = _control_dependencies_for_inputs(input_ops);
@@ -374,7 +373,11 @@ namespace Tensorflow
/// <returns>A string to be passed to `create_op()` that will be used /// <returns>A string to be passed to `create_op()` that will be used
/// to name the operation being created.</returns> /// to name the operation being created.</returns>
public string unique_name(string name, bool mark_as_used = true) public string unique_name(string name, bool mark_as_used = true)
{
{
if (name.EndsWith("basic_r_n_n_cell"))
{
}
if (!String.IsNullOrEmpty(_name_stack)) if (!String.IsNullOrEmpty(_name_stack))
name = _name_stack + "/" + name; name = _name_stack + "/" + name;
// For the sake of checking for names in use, we treat names as case // For the sake of checking for names in use, we treat names as case
@@ -402,7 +405,7 @@ namespace Tensorflow


// Return the new name with the original capitalization of the given name. // Return the new name with the original capitalization of the given name.
name = $"{name}_{i-1}"; name = $"{name}_{i-1}";
}
}
return name; return name;
} }


@@ -524,17 +527,19 @@ namespace Tensorflow
} }
return debugString;*/ return debugString;*/
}

/*private IEnumerable<Operation> GetEnumerable()
}
#if !SERIALIZABLE
private IEnumerable<Operation> GetEnumerable()
=> c_api_util.tf_operations(this); => c_api_util.tf_operations(this);


IEnumerator<Operation> IEnumerable<Operation>.GetEnumerator() IEnumerator<Operation> IEnumerable<Operation>.GetEnumerator()
=> GetEnumerable().GetEnumerator(); => GetEnumerable().GetEnumerator();
IEnumerator IEnumerable.GetEnumerator() IEnumerator IEnumerable.GetEnumerator()
=> throw new NotImplementedException();*/

=> throw new NotImplementedException();
#endif
public static implicit operator IntPtr(Graph graph) public static implicit operator IntPtr(Graph graph)
{ {
return graph._handle; return graph._handle;


+ 4
- 1
src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs View File

@@ -16,6 +16,7 @@


using System.Collections.Generic; using System.Collections.Generic;
using Tensorflow.Operations; using Tensorflow.Operations;
using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
{ {
@@ -68,7 +69,9 @@ namespace Tensorflow
_new_stack = false; _new_stack = false;
} }


_seen_nodes = new List<ITensorOrOperation>();
_seen_nodes = new List<ITensorOrOperation>();
_old_stack = null;
_old_control_flow_context = null;
} }


public void add_op(ITensorOrOperation op) public void add_op(ITensorOrOperation op)


+ 11
- 0
src/TensorFlowNET.Core/Interfaces/IFlatten.cs View File

@@ -0,0 +1,11 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public interface ICanBeFlattened
{
object[] Flatten();
}
}

src/TensorFlowNET.Core/IObjectLife.cs → src/TensorFlowNET.Core/Interfaces/IObjectLife.cs View File


+ 11
- 0
src/TensorFlowNET.Core/Interfaces/IPackable.cs View File

@@ -0,0 +1,11 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public interface IPackable<T>
{
T Pack(object[] sequences);
}
}

src/TensorFlowNET.Core/ITensorOrOperation.cs → src/TensorFlowNET.Core/Interfaces/ITensorOrOperation.cs View File


src/TensorFlowNET.Core/Sessions/TF_DeprecatedSession.cs → src/TensorFlowNET.Core/Interfaces/ITensorOrTensorArray.cs View File

@@ -14,13 +14,14 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System.Runtime.InteropServices;

namespace Tensorflow.Sessions
namespace Tensorflow
{ {
[StructLayout(LayoutKind.Sequential)]
public struct TF_DeprecatedSession
/// <summary>
/// in order to limit function return value
/// is Tensor or TensorArray
/// </summary>
public interface ITensorOrTensorArray
{ {
Session session;
} }
} }

+ 2
- 2
src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs View File

@@ -139,14 +139,14 @@ namespace Tensorflow.Keras.Layers
built = true; built = true;
} }


protected override Tensor call(Tensor inputs, Tensor training = null)
protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
{ {
Tensor outputs = null; Tensor outputs = null;


if (fused) if (fused)
{ {
outputs = _fused_batch_norm(inputs, training: training); outputs = _fused_batch_norm(inputs, training: training);
return outputs;
return new[] { outputs, outputs };
} }


throw new NotImplementedException("BatchNormalization call"); throw new NotImplementedException("BatchNormalization call");


+ 3
- 3
src/TensorFlowNET.Core/Keras/Layers/Conv.cs View File

@@ -108,7 +108,7 @@ namespace Tensorflow.Keras.Layers
built = true; built = true;
} }


protected override Tensor call(Tensor inputs, Tensor training = null)
protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
{ {
var outputs = _convolution_op.__call__(inputs, kernel); var outputs = _convolution_op.__call__(inputs, kernel);
if (use_bias) if (use_bias)
@@ -124,9 +124,9 @@ namespace Tensorflow.Keras.Layers
} }


if (activation != null) if (activation != null)
return activation.Activate(outputs);
outputs = activation.Activate(outputs);


return outputs;
return new[] { outputs, outputs };
} }
} }
} }

+ 3
- 3
src/TensorFlowNET.Core/Keras/Layers/Dense.cs View File

@@ -72,7 +72,7 @@ namespace Tensorflow.Keras.Layers
built = true; built = true;
} }


protected override Tensor call(Tensor inputs, Tensor training = null)
protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
{ {
Tensor outputs = null; Tensor outputs = null;
var rank = inputs.rank; var rank = inputs.rank;
@@ -88,9 +88,9 @@ namespace Tensorflow.Keras.Layers
if (use_bias) if (use_bias)
outputs = tf.nn.bias_add(outputs, bias); outputs = tf.nn.bias_add(outputs, bias);
if (activation != null) if (activation != null)
return activation.Activate(outputs);
outputs = activation.Activate(outputs);


return outputs;
return new[] { outputs, outputs };
} }
} }
} }

+ 2
- 2
src/TensorFlowNET.Core/Keras/Layers/Embedding.cs View File

@@ -50,14 +50,14 @@ namespace Tensorflow.Keras.Layers
built = true; built = true;
} }


protected override Tensor call(Tensor inputs, Tensor training = null)
protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
{ {
var dtype = inputs.dtype; var dtype = inputs.dtype;
if (dtype != tf.int32 && dtype != tf.int64) if (dtype != tf.int32 && dtype != tf.int64)
inputs = math_ops.cast(inputs, tf.int32); inputs = math_ops.cast(inputs, tf.int32);


var @out = embedding_ops.embedding_lookup(embeddings, inputs); var @out = embedding_ops.embedding_lookup(embeddings, inputs);
return @out;
return new[] { @out, @out };
} }
} }
} }

+ 18
- 9
src/TensorFlowNET.Core/Keras/Layers/Layer.cs View File

@@ -52,6 +52,7 @@ namespace Tensorflow.Keras.Layers
protected InputSpec input_spec; protected InputSpec input_spec;
protected bool supports_masking; protected bool supports_masking;
protected List<VariableV1> _trainable_weights; protected List<VariableV1> _trainable_weights;
protected List<VariableV1> _non_trainable_weights;
private string _name; private string _name;
public string name => _name; public string name => _name;
protected string _base_name; protected string _base_name;
@@ -84,6 +85,7 @@ namespace Tensorflow.Keras.Layers


_init_set_name(name); _init_set_name(name);
_trainable_weights = new List<VariableV1>(); _trainable_weights = new List<VariableV1>();
_non_trainable_weights = new List<VariableV1>();
_compute_previous_mask = false; _compute_previous_mask = false;
_updates = new List<Operation>(); _updates = new List<Operation>();


@@ -101,13 +103,14 @@ namespace Tensorflow.Keras.Layers
_inbound_nodes = new List<Node>(); _inbound_nodes = new List<Node>();
} }


public Tensor __call__(Tensor[] inputs,
public Tensor[] __call__(Tensor[] inputs,
Tensor training = null, Tensor training = null,
Tensor state = null,
VariableScope scope = null) VariableScope scope = null)
{ {
var input_list = inputs; var input_list = inputs;
var input = inputs[0]; var input = inputs[0];
Tensor outputs = null;
Tensor[] outputs = null;


// We will attempt to build a TF graph if & only if all inputs are symbolic. // We will attempt to build a TF graph if & only if all inputs are symbolic.
// This is always the case in graph mode. It can also be the case in eager // This is always the case in graph mode. It can also be the case in eager
@@ -139,7 +142,10 @@ namespace Tensorflow.Keras.Layers
// overridden). // overridden).
_maybe_build(inputs[0]); _maybe_build(inputs[0]);


outputs = call(inputs[0], training: training);
outputs = call(inputs[0],
training: training,
state: state);

(input, outputs) = _set_connectivity_metadata_(input, outputs); (input, outputs) = _set_connectivity_metadata_(input, outputs);
_handle_activity_regularization(inputs[0], outputs); _handle_activity_regularization(inputs[0], outputs);
_set_mask_metadata(inputs[0], outputs, null); _set_mask_metadata(inputs[0], outputs, null);
@@ -149,13 +155,13 @@ namespace Tensorflow.Keras.Layers
return outputs; return outputs;
} }


private (Tensor, Tensor) _set_connectivity_metadata_(Tensor inputs, Tensor outputs)
private (Tensor, Tensor[]) _set_connectivity_metadata_(Tensor inputs, Tensor[] outputs)
{ {
//_add_inbound_node(input_tensors: inputs, output_tensors: outputs); //_add_inbound_node(input_tensors: inputs, output_tensors: outputs);
return (inputs, outputs); return (inputs, outputs);
} }


private void _handle_activity_regularization(Tensor inputs, Tensor outputs)
private void _handle_activity_regularization(Tensor inputs, Tensor[] outputs)
{ {
//if(_activity_regularizer != null) //if(_activity_regularizer != null)
{ {
@@ -163,7 +169,7 @@ namespace Tensorflow.Keras.Layers
} }
} }


private void _set_mask_metadata(Tensor inputs, Tensor outputs, Tensor previous_mask)
private void _set_mask_metadata(Tensor inputs, Tensor[] outputs, Tensor previous_mask)
{ {


} }
@@ -173,9 +179,9 @@ namespace Tensorflow.Keras.Layers
return null; return null;
} }


protected virtual Tensor call(Tensor inputs, Tensor training = null)
protected virtual Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
{ {
return inputs;
throw new NotImplementedException("");
} }


protected virtual string _name_scope() protected virtual string _name_scope()
@@ -233,7 +239,10 @@ namespace Tensorflow.Keras.Layers
initializer: initializer, initializer: initializer,
trainable: trainable.Value); trainable: trainable.Value);
//backend.track_variable(variable); //backend.track_variable(variable);
_trainable_weights.Add(variable);
if (trainable == true)
_trainable_weights.Add(variable);
else
_non_trainable_weights.Add(variable);


return variable; return variable;
} }


+ 2
- 2
src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs View File

@@ -43,7 +43,7 @@ namespace Tensorflow.Keras.Layers
this.input_spec = new InputSpec(ndim: 4); this.input_spec = new InputSpec(ndim: 4);
} }


protected override Tensor call(Tensor inputs, Tensor training = null)
protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
{ {
int[] pool_shape; int[] pool_shape;
if (data_format == "channels_last") if (data_format == "channels_last")
@@ -64,7 +64,7 @@ namespace Tensorflow.Keras.Layers
padding: padding.ToUpper(), padding: padding.ToUpper(),
data_format: conv_utils.convert_data_format(data_format, 4)); data_format: conv_utils.convert_data_format(data_format, 4));


return outputs;
return new[] { outputs, outputs };
} }
} }
} }

+ 15
- 5
src/TensorFlowNET.Core/Layers/Layer.cs View File

@@ -43,17 +43,20 @@ namespace Tensorflow.Layers


// Avoid an incorrect lint error // Avoid an incorrect lint error
_trainable_weights = new List<VariableV1>(); _trainable_weights = new List<VariableV1>();
_non_trainable_weights = new List<VariableV1>();
this.built = false; this.built = false;
_keras_style = false; _keras_style = false;
} }


public virtual Tensor apply(Tensor inputs, Tensor training = null)
public virtual (Tensor, Tensor) apply(Tensor inputs, Tensor training = null)
{ {
return __call__(inputs, training: training);
var results = __call__(inputs, training: training);
return (results[0], results[1]);
} }


public Tensor __call__(Tensor inputs,
public Tensor[] __call__(Tensor inputs,
Tensor training = null, Tensor training = null,
Tensor state = null,
VariableScope scope = null) VariableScope scope = null)
{ {
_set_scope(scope); _set_scope(scope);
@@ -71,12 +74,14 @@ namespace Tensorflow.Layers
auxiliary_name_scope: false); auxiliary_name_scope: false);
} }


Tensor outputs = null;
Tensor[] outputs = null;
tf_with(scope_context_manager, scope2 => tf_with(scope_context_manager, scope2 =>
{ {
_current_scope = scope2; _current_scope = scope2;
// Actually call layer // Actually call layer
outputs = base.__call__(new Tensor[] { inputs }, training: training);
outputs = base.__call__(new Tensor[] { inputs },
training: training,
state: state);
}); });




@@ -121,6 +126,11 @@ namespace Tensorflow.Layers
Graph init_graph = null; Graph init_graph = null;
VariableV1[] existing_variables = null; VariableV1[] existing_variables = null;


if (synchronization == VariableSynchronization.OnRead)
trainable = false;
else if (!trainable.HasValue)
trainable = true;

if (default_graph.building_function) if (default_graph.building_function)
{ {
throw new NotImplementedException("add_weight"); throw new NotImplementedException("add_weight");


+ 33
- 2
src/TensorFlowNET.Core/Operations/BasicRNNCell.cs View File

@@ -16,18 +16,23 @@


using System; using System;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
{ {
public class BasicRNNCell : LayerRNNCell
public class BasicRnnCell : LayerRnnCell
{ {
int _num_units; int _num_units;
Func<Tensor, string, Tensor> _activation; Func<Tensor, string, Tensor> _activation;


public override int state_size => _num_units; public override int state_size => _num_units;
public override int output_size => _num_units; public override int output_size => _num_units;
public VariableV1 _kernel;
string _WEIGHTS_VARIABLE_NAME = "kernel";
public VariableV1 _bias;
string _BIAS_VARIABLE_NAME = "bias";


public BasicRNNCell(int num_units,
public BasicRnnCell(int num_units,
Func<Tensor, string, Tensor> activation = null, Func<Tensor, string, Tensor> activation = null,
bool? reuse = null, bool? reuse = null,
string name = null, string name = null,
@@ -44,5 +49,31 @@ namespace Tensorflow
else else
_activation = activation; _activation = activation;
} }

protected override void build(TensorShape inputs_shape)
{
var input_depth = inputs_shape.dims[inputs_shape.ndim - 1];

_kernel = add_weight(
_WEIGHTS_VARIABLE_NAME,
shape: new[] { input_depth + _num_units, _num_units });

_bias = add_weight(
_BIAS_VARIABLE_NAME,
shape: new[] { _num_units },
initializer: tf.zeros_initializer);

built = true;
}

protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
{
// Most basic RNN: output = new_state = act(W * input + U * state + B).
var concat = array_ops.concat(new[] { inputs, state }, 1);
var gate_inputs = math_ops.matmul(concat, _kernel as RefVariable);
gate_inputs = nn_ops.bias_add(gate_inputs, _bias as RefVariable);
var output = _activation(gate_inputs, null);
return new[] { output, output };
}
} }
} }

+ 73
- 25
src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs View File

@@ -19,6 +19,8 @@ using System.Collections.Generic;
using System.Linq; using System.Linq;
using Tensorflow.Operations.ControlFlows; using Tensorflow.Operations.ControlFlows;
using static Tensorflow.ControlFlowContextDef; using static Tensorflow.ControlFlowContextDef;
using static Tensorflow.Binding;
using util = Tensorflow.control_flow_util;


namespace Tensorflow.Operations namespace Tensorflow.Operations
{ {
@@ -72,6 +74,7 @@ namespace Tensorflow.Operations
public ControlFlowContext() public ControlFlowContext()
{ {
_context_stack = new Stack<ControlFlowContext>(); _context_stack = new Stack<ControlFlowContext>();
_external_values = new Dictionary<string, ITensorOrOperation>();
} }


public string name { get => _name; } public string name { get => _name; }
@@ -134,27 +137,6 @@ namespace Tensorflow.Operations
graph._set_control_flow_context(this); graph._set_control_flow_context(this);
} }


protected virtual Tensor _Enter(Tensor data, string frame_name,
bool is_constant = false,
int parallel_iterations = 10,
bool use_ref = true,
bool use_input_shape = true,
string name = null)
{
Tensor result;
data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref: true);
if (data.dtype.is_ref_dtype() && use_ref)
throw new NotImplementedException("_Enter");
else
result = gen_control_flow_ops.enter(
data, frame_name, is_constant, parallel_iterations, name: name);

if (use_input_shape)
result.set_shape(data.TensorShape);

return result;
}

/// <summary> /// <summary>
/// Exit this control flow context. /// Exit this control flow context.
/// </summary> /// </summary>
@@ -165,10 +147,18 @@ namespace Tensorflow.Operations
graph._set_control_flow_context(last_context); graph._set_control_flow_context(last_context);
} }


public void ExitResult(Tensor[] result)
{
if(_outer_context != null)
{
throw new NotImplementedException("ExitResult");
}
}

/// <summary> /// <summary>
/// Add `op` to the current context. /// Add `op` to the current context.
/// </summary> /// </summary>
public void AddOp(Operation op)
public virtual void AddOp(Operation op)
{ {
_AddOpInternal(op); _AddOpInternal(op);
} }
@@ -180,12 +170,22 @@ namespace Tensorflow.Operations


public virtual bool back_prop => throw new NotImplementedException("abstract method"); public virtual bool back_prop => throw new NotImplementedException("abstract method");


/// <summary>
/// Add `val` to the current context and its outer context recursively.
/// </summary>
/// <param name="val"></param>
/// <returns></returns>
public virtual Tensor AddValue(Tensor val) public virtual Tensor AddValue(Tensor val)
{ {
// to be overridden // to be overridden
return null; return null;
} }


public void AddName(string name)
{
_values.Add(name);
}

/// <summary> /// <summary>
/// Notifies a scope about an operator added to an inner scope. /// Notifies a scope about an operator added to an inner scope.
/// </summary> /// </summary>
@@ -203,7 +203,20 @@ namespace Tensorflow.Operations
/// </summary> /// </summary>
protected virtual void _AddOpInternal(Operation op) protected virtual void _AddOpInternal(Operation op)
{ {
if(op == null)
{
throw new NotImplementedException("");
}
else
{
foreach(var index in range(len(op.inputs)))
{
var x = op.inputs[index];
var real_x = AddValue(x);
if (real_x != x)
op._update_input(index, real_x);
}
}
} }


protected bool OpInContext(Operation op) protected bool OpInContext(Operation op)
@@ -230,9 +243,36 @@ namespace Tensorflow.Operations
throw new NotImplementedException("_IsInOuterContext"); throw new NotImplementedException("_IsInOuterContext");
} }


protected virtual void _RemoveExternalControlEdges(Operation op)
/// <summary>
/// Remove any external control dependency on this op.
/// </summary>
/// <param name="op"></param>
protected virtual (Operation[], Operation[]) _RemoveExternalControlEdges(Operation op)
{ {
var internal_control_inputs = op.control_inputs;
var while_ctxt = GetWhileContext();

var internal_control_inputs = new List<Operation>();
// A control input of `op` is internal if it is in the same while
// loop context as the enclosing while loop context of self.
if (while_ctxt == null)
{
internal_control_inputs = op.control_inputs.ToList();
}
else
{
foreach(Operation x in op.control_inputs)
{
var ctxt = util.GetOutputContext(x);
if (ctxt != null && ctxt.GetWhileContext() == while_ctxt)
internal_control_inputs.append(x);
}
}

var external_control_inputs = new List<Operation>();
if (len(internal_control_inputs) != len(op.control_inputs))
throw new NotImplementedException("");

return (internal_control_inputs.ToArray(), external_control_inputs.ToArray());
} }


/// <summary> /// <summary>
@@ -264,6 +304,14 @@ namespace Tensorflow.Operations
throw new NotImplementedException($"Unknown ControlFlowContextDef field: {context_def.CtxtCase}"); throw new NotImplementedException($"Unknown ControlFlowContextDef field: {context_def.CtxtCase}");
} }


public virtual bool IsWhileContext()
{
throw new NotImplementedException("IsWhileContext");
}

public virtual bool IsCondContext()
=> false;

public object to_proto() public object to_proto()
{ {
throw new NotImplementedException(); throw new NotImplementedException();


+ 197
- 162
src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs View File

@@ -14,6 +14,12 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/
using System;
using System.Linq;
using System.Collections.Generic;
using util = Tensorflow.control_flow_util;
using static Tensorflow.Binding;
namespace Tensorflow.Operations.ControlFlows namespace Tensorflow.Operations.ControlFlows
{ {
/// <summary> /// <summary>
@@ -21,6 +27,7 @@ namespace Tensorflow.Operations.ControlFlows
/// </summary> /// </summary>
public class ControlFlowState public class ControlFlowState
{ {
Dictionary<ControlFlowContext, GradLoopState> _map;
//class ControlFlowState(object): //class ControlFlowState(object):
// """Maintain the mapping from the loops to their grad states.""" // """Maintain the mapping from the loops to their grad states."""
@@ -40,57 +47,74 @@ namespace Tensorflow.Operations.ControlFlows
// return self._map.get(forward_ctxt) // return self._map.get(forward_ctxt)
// return None // return None
// def ProcessUnusedLoopExits(self, pending_count, to_ops_set):
// """Process all the "unused" loop exits.
// The "unused" exits of the loops are added to `unused_exits`. An exit is
// unused if its pending_count is 0. If there is an exit with real gradient,
// all these deferred exits will enter the backprop loop with zero gradient.
// Otherwise, they will enter the backprop loop with None. As an example,
// people often write:
// ```python
// v1, _ = tf.while_loop(p, b, [x1, x2])
// result = gradients(v1, x1)
// ```
// The exit node for x2 is not included by the betweenness analysis. But we
// need to backprop x2 if x2 is involved in computing v1.
// Args:
// pending_count: The number of backprop inputs for every op.
// to_ops_set: The set of ops for ys in gradients(ys, xs)
// Returns:
// The set of unused loop exits that we know at this point we need
// to backprop.
// """
// loop_exits = []
// for grad_state in self._map.values():
// for y in grad_state.forward_loop_exits:
// if pending_count[y.op] == 0:
// grad_state.pending_exits_count -= 1
// if y.op not in to_ops_set:
// grad_state.unused_exits.append(y)
// if grad_state.pending_exits_count == 0:
// loop_exits.extend(grad_state.unused_exits)
// # Need to include Enters in backprop for higher-order gradients.
// for y in grad_state.forward_context.loop_enters:
// if pending_count[y.op] == 0:
// pending_count[y.op] = 1
// return loop_exits
// def EnterGradWhileContext(self, op, before):
// """Enter the WhileContext for gradient computation."""
// grad_state = self.GetGradState(op, before)
// if grad_state:
// grad_state.grad_context.Enter()
// def ExitGradWhileContext(self, op, before):
// """Exit the WhileContext for gradient computation."""
// grad_state = self.GetGradState(op, before)
// if grad_state:
// grad_state.grad_context.Exit()
public ControlFlowState()
{
_map = new Dictionary<ControlFlowContext, GradLoopState>();
}
/// <summary>
/// Return the grad state for this op if it's in a forward loop context.
/// </summary>
/// <param name="op"></param>
/// <param name="before"></param>
/// <returns></returns>
public GradLoopState GetGradState(Operation op, bool before)
{
ControlFlowContext forward_ctxt = null;
if (before && util.IsLoopExit(op))
{
forward_ctxt = op._get_control_flow_context();
forward_ctxt = forward_ctxt.outer_context;
if (forward_ctxt != null)
forward_ctxt = forward_ctxt.GetWhileContext();
}
else
forward_ctxt = util.GetWhileContext(op);
if (forward_ctxt != null)
return _map.get(forward_ctxt);
return null;
}
public Tensor[] ProcessUnusedLoopExits(Dictionary<string, int> pending_count, List<Operation> to_ops_set)
{
var loop_exits = new List<Tensor>();
foreach(var grad_state in _map.Values)
{
foreach(var y in grad_state.forward_loop_exits)
{
if(!pending_count.ContainsKey(y.op.name))
{
grad_state.pending_exits_count -= 1;
if (!to_ops_set.Contains(y.op))
grad_state.unused_exits.append(y);
if (grad_state.pending_exits_count == 0)
loop_exits.extend(grad_state.unused_exits);
}
}
foreach(var y in grad_state.forward_context.loop_enters)
{
if (!pending_count.ContainsKey(y.op.name))
pending_count[y.op.name] = 1;
}
}
return loop_exits.ToArray();
}
public void EnterGradWhileContext(Operation op, bool before)
{
var grad_state = GetGradState(op, before);
if (grad_state != null)
grad_state.grad_context.Enter();
}
public void ExitGradWhileContext(Operation op, bool before)
{
var grad_state = GetGradState(op, before);
if (grad_state != null)
grad_state.grad_context.Exit();
}
// def AddWhileContext(self, op, between_op_list, between_ops): // def AddWhileContext(self, op, between_op_list, between_ops):
// """Add the grad state for the while loop that op belongs to. // """Add the grad state for the while loop that op belongs to.
@@ -118,6 +142,32 @@ namespace Tensorflow.Operations.ControlFlows
// if loop_exit.op not in between_ops: // if loop_exit.op not in between_ops:
// between_ops.add(loop_exit.op) // between_ops.add(loop_exit.op)
// between_op_list.append(loop_exit.op) // between_op_list.append(loop_exit.op)
public void AddWhileContext(Operation op, List<Operation> between_op_list, List<Operation> between_ops)
{
var forward_ctxt = op.GetWhileContext();
var grad_state = _map.ContainsKey(forward_ctxt) ? _map[forward_ctxt] : null;
if(grad_state == null)
{
GradLoopState outer_grad_state = null;
var outer_forward_ctxt = forward_ctxt.outer_context;
if (outer_forward_ctxt != null)
outer_forward_ctxt = outer_forward_ctxt.GetWhileContext();
if (outer_forward_ctxt != null)
outer_grad_state = _map[outer_forward_ctxt];
grad_state = new GradLoopState(forward_ctxt, outer_grad_state);
_map[forward_ctxt] = grad_state;
// We need to include all exits of a loop for backprop.
foreach (var loop_exit in grad_state.forward_loop_exits)
{
if(!between_ops.Contains(loop_exit.op))
{
between_ops.add(loop_exit.op);
between_op_list.append(loop_exit.op);
}
}
}
}
// def ZerosLikeForExit(self, val): // def ZerosLikeForExit(self, val):
// """Create zeros_like gradient for a loop exit. // """Create zeros_like gradient for a loop exit.
@@ -174,116 +224,101 @@ namespace Tensorflow.Operations.ControlFlows
// result = array_ops.zeros_like(val, optimize=False) // result = array_ops.zeros_like(val, optimize=False)
// return result // return result
// def ZerosLike(self, op, index):
// """Create zeros_like for the specified output of an op.
// If op is in a while loop that is part of gradients(), this method
// must be called in its grad loop context.
// Args:
// op: A tensorflow operation.
// index: the index for a specific output of the op.
// Returns:
// A zero tensor of the same shape of op.outputs[index].
// """
// if util.IsLoopSwitch(op):
// return None
// if op.graph._building_function: # pylint: disable=protected-access
// # The optimization here is tricky to apply to functions
// return array_ops.zeros_like(op.outputs[index])
// dead_branch = util.IsSwitch(op)
// forward_ctxt = _GetWhileContext(op)
// grad_state = self._map.get(forward_ctxt)
// if grad_state is None:
// # op is not in a while loop that is part of gradients().
// return ZerosLikeOutsideLoop(op, index)
// op_ctxt = op._get_control_flow_context()
// val = ops.convert_to_tensor(op.outputs[index], name="tensor")
// shape = val.get_shape()
// if shape.is_fully_defined():
// # If the shape is known statically, just create a zero tensor with
// # the right shape in the grad loop context.
// result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype)
// if dead_branch:
// # op is a cond switch. Guard the zero tensor with a switch.
// pred = grad_state.history_map.get(op_ctxt.pred.name)
// branch = op_ctxt.branch
// result = _SwitchRefOrTensor(result, pred)[1 - branch]
// else:
// # Unknown shape so keep a history of the shape at runtime.
// if dead_branch:
// # Need to add a special switch to guard the value.
// pred = op_ctxt.pred
// branch = op_ctxt.branch
// op_ctxt.outer_context.Enter()
// val = _SwitchRefOrTensor(op.inputs[0], pred)[1 - branch]
// zeros_shape = array_ops.shape_internal(val, optimize=False)
// op_ctxt.outer_context.Exit()
// val.op._set_control_flow_context(op_ctxt)
// zeros_shape.op._set_control_flow_context(op_ctxt)
// else:
// op_ctxt.Enter()
// zeros_shape = array_ops.shape_internal(val, optimize=False)
// op_ctxt.Exit()
// # Add forward accumulator for shape.
// grad_state.grad_context.Exit()
// history_zeros_shape = grad_state.AddForwardAccumulator(
// zeros_shape, dead_branch=dead_branch)
// grad_state.grad_context.Enter()
// # Create a zero tensor with the right shape.
// shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape,
// zeros_shape, dead_branch)
// result = array_ops.zeros(shape, val.dtype)
// return result
// def PostProcessing(self):
// """Perform postprocessing at the end of gradients().
// We have created the gradient graph at this point. So this function
// can be used to perform any postprocessing on the gradient graph.
// We currently perform the following postprocessing:
// 1. Patch the gradient graph if the output of a loop variable
// doesn't depend on its input.
// """
// for _, grad_state in self._map.items():
// for _, b_merge in grad_state.switch_map.items():
// if b_merge.op.inputs[0] == b_merge.op.inputs[1]:
// # The value of this loop variable at iteration i+1 doesn't
// # depend on its value at iteration i. So use zeros as the
// # gradients for all iterations > 0.
// dtype = b_merge.op.inputs[0].dtype
// shape = b_merge.op.inputs[0].get_shape()
// # pylint: disable=protected-access
// if shape.is_fully_defined():
// grad_state.grad_context.Enter()
// # Create a zeros and use it for iterations > 0.
// grad_val = constant_op.constant(0, dtype=dtype, shape=shape)
// next_grad_val = _NextIteration(grad_val)
// grad_state.grad_context.Exit()
// else:
// # Create a zeros in the outer grad context.
// outer_grad_ctxt = grad_state.grad_context.outer_context
// if outer_grad_ctxt:
// outer_grad_ctxt.Enter()
// enter_grad_op = b_merge.op.inputs[0].op
// enter_grad = enter_grad_op.inputs[0]
// grad_shape = array_ops.shape_internal(enter_grad, optimize=False)
// grad_val = array_ops.zeros(grad_shape)
// if outer_grad_ctxt:
// outer_grad_ctxt.Exit()
// # Use the zeros for iterations > 0.
// grad_state.grad_context.Enter()
// next_grad_val = _NextIteration(grad_val)
// grad_state.grad_context.Exit()
// b_merge.op._update_input(1, next_grad_val)
// # pylint: enable=protected-access
/// <summary>
/// Create zeros_like for the specified output of an op.
/// If op is in a while loop that is part of gradients(), this must be
/// called in its grad loop context.
/// </summary>
/// <param name="op">A tensorflow operation.</param>
/// <param name="index">The index for a specific output of the op.</param>
/// <returns>A zero tensor with the same shape as op.outputs[index], or null for a loop switch.</returns>
public Tensor ZerosLike(Operation op, int index)
{
    // Loop-switch gradients are produced elsewhere.
    if (util.IsLoopSwitch(op))
        return null;
    // The shape optimization below does not apply inside a function graph.
    if (op.graph.building_function)
        return array_ops.zeros_like(op.outputs[index]);

    var is_dead_branch = util.IsSwitch(op);
    var while_ctxt = util.GetWhileContext(op);
    var loop_state = _map.get(while_ctxt);
    if (loop_state == null)
        // op is not in a while loop that is part of gradients().
        return ZerosLikeOutsideLoop(op, index);

    throw new NotImplementedException("ZerosLike");
}
/// <summary>
/// Create zeros_like for the specified output of an op that is not in a
/// while loop participating in gradients().
/// </summary>
/// <param name="op">The operation producing the value.</param>
/// <param name="index">The output index of `op` to mirror.</param>
/// <returns>A zero tensor shaped like op.outputs[index].</returns>
public Tensor ZerosLikeOutsideLoop(Operation op, int index)
{
    var val = op.outputs[index];

    // Switch ops need dead-branch guarding; not ported yet.
    if (util.IsSwitch(op))
        throw new NotImplementedException("ZerosLikeOutsideLoop");
    // Resource tensors need their variable shape; not ported yet.
    if (val.dtype == dtypes.resource)
        throw new NotImplementedException("ZerosLikeOutsideLoop");
    /*return array_ops.zeros(
        gen_resource_variable_ops.variable_shape(val),
        dtype: default_gradient.get_zeros_dtype(val));*/
    return array_ops.zeros_like(val, optimize: false);
}
/// <summary>
/// Create zeros_like gradient for a loop exit.
/// </summary>
/// <param name="val">The loop-exit tensor to mirror with zeros.</param>
/// <returns>A zero tensor with the same shape as `val`.</returns>
public Tensor ZerosLikeForExit(Tensor val)
{
    var exit_shape = val.TensorShape;
    var forward_ctxt = val.op._get_control_flow_context();
    var outer_forward_ctxt = forward_ctxt.outer_context;
    if (outer_forward_ctxt != null)
        outer_forward_ctxt = outer_forward_ctxt.GetWhileContext();

    GradLoopState outer_grad_state = null;
    if (outer_forward_ctxt != null)
        outer_grad_state = _map.get(outer_forward_ctxt);

    // A non-null outer grad state means this exit belongs to a nested loop.
    if (outer_grad_state != null)
        throw new NotImplementedException("ZerosLikeForExit");

    // Statically known shape -> constant zeros in the right shape;
    // otherwise fall back to a runtime zeros_like.
    return exit_shape.is_fully_defined()
        ? array_ops.zeros(exit_shape.dims, val.dtype)
        : array_ops.zeros_like(val, optimize: false);
}
/// <summary>
/// Perform postprocessing at the end of gradients().
/// The gradient graph is fully built at this point; this patches it when
/// the output of a loop variable does not depend on its input.
/// </summary>
public void PostProcessing()
{
foreach(var grad_state in _map.Values)
{
foreach(var b_merge in grad_state.switch_map.Values)
{
// Identical merge inputs mean the value at iteration i+1 does not
// depend on the value at iteration i.
if(b_merge.op.inputs[0] == b_merge.op.inputs[1])
{
Tensor next_grad_val = null;
// The value of this loop variable at iteration i+1 doesn't
// depend on its value at iteration i. So use zeros as the
// gradients for all iterations > 0.
var dtype = b_merge.op.inputs[0].dtype;
var shape = b_merge.op.inputs[0].TensorShape;
if (shape.is_fully_defined())
{
grad_state.grad_context.Enter();
// Create a zeros and use it for iterations > 0.
var grad_val = constant_op.constant(0, dtype: dtype, shape: shape);
next_grad_val = control_flow_ops._NextIteration(grad_val);
grad_state.grad_context.Exit();
}
else
{
// The Python original creates the zeros in the outer grad context
// via shape_internal; that path is not ported yet.
throw new NotImplementedException("PostProcessing shape is not fully defined.");
}
// Rewire the merge's second input to use zeros for iterations > 0.
b_merge.op._update_input(1, next_grad_val);
}
}
}
}
} }
} }

+ 240
- 317
src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs View File

@@ -16,41 +16,18 @@
using System; using System;
using System.Collections; using System.Collections;
using System.Collections.Generic;
using System.Linq;
using static Tensorflow.Binding;
using util = Tensorflow.control_flow_util;
namespace Tensorflow.Operations.ControlFlows namespace Tensorflow.Operations.ControlFlows
{ {
/// <summary>
/// The state used for constructing the gradient graph for a while loop.
/// </summary>
public class GradLoopState public class GradLoopState
{ {
//class GradLoopState(object):
// """The state used for constructing the gradient graph for a while loop.
// We create a GradLoopState for each while loop in forward and its
// corresponding while loop in backprop. This gives us access to both
// the forward and the backprop WhileContexts.
// During the construction of gradient graph, any time when we detect
// a forward value that is needed for backprop, we create a history
// accumulator and add it to `history_map`. Any time when we backprop
// a loop switch op (in _SwitchGrad), we add the grad merge op in
// `switch_map`.
// """
// def __init__(self, forward_ctxt, outer_grad_state):
// # The grad loop state for the outer while loop.
// self._outer_grad_state = None
// # The while loop context for forward.
// self._forward_context = None
// # The loop counter added by AddForwardLoopCounter. It is the value
// # of the loop counter for the next iteration.
// self._forward_index = None
// # A sync op for forward.
// self._forward_sync = None
// # The while loop context for backprop.
private WhileContext _grad_context = null; private WhileContext _grad_context = null;
public WhileContext grad_context => _grad_context; public WhileContext grad_context => _grad_context;
@@ -65,156 +42,112 @@ namespace Tensorflow.Operations.ControlFlows
// # Information needed by backprop. // # Information needed by backprop.
private Hashtable _history_map = new Hashtable(); private Hashtable _history_map = new Hashtable();
public Hashtable history_map => _history_map; public Hashtable history_map => _history_map;
private Hashtable _switch_map = new Hashtable();
public Hashtable switch_map => _switch_map;
// self._unused_exits = []
// self._deferred_exits = []
// self._forward_loop_exits = list(forward_ctxt.loop_exits)
// self._pending_exits_count = len(forward_ctxt.loop_exits)
// self._outer_grad_state = outer_grad_state
// if outer_grad_state:
// outer_forward_ctxt = outer_grad_state.forward_context
// else:
// if not hasattr(forward_ctxt, "outer_context"):
// raise ValueError("Failed to call gradients on a while loop without"
// "properly serializing graph via MetaGraphDef")
// outer_forward_ctxt = forward_ctxt.outer_context
// # Add the forward loop counter.
// with forward_ctxt._graph.as_default(): # pylint: disable=protected-access
// if outer_forward_ctxt:
// outer_forward_ctxt.Enter()
// cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state)
// if outer_forward_ctxt:
// outer_forward_ctxt.Exit()
// self._forward_context = forward_ctxt
// self._forward_index = forward_index
// # Add the backprop WhileContext, and the backprop loop counter.
// if outer_grad_state:
// # This is a nested loop. Remember the iteration counts for each
// # execution of this inner loop.
// outer_forward_ctxt.AddName(cnt.name)
// history_cnt = outer_grad_state.AddForwardAccumulator(cnt)
// outer_grad_ctxt = outer_grad_state.grad_context
// outer_grad_ctxt.Enter()
// self._grad_context = WhileContext(
// maximum_iterations=forward_ctxt.maximum_iterations,
// parallel_iterations=forward_ctxt.parallel_iterations,
// back_prop=forward_ctxt.back_prop,
// swap_memory=forward_ctxt.swap_memory,
// name=forward_ctxt.name,
// grad_state=self)
// real_cnt = outer_grad_state.AddBackpropAccumulatedValue(history_cnt, cnt)
// self._grad_index = self._grad_context.AddBackpropLoopCounter(
// real_cnt, outer_grad_state)
// outer_grad_ctxt.Exit()
// else:
// if outer_forward_ctxt:
// outer_forward_ctxt.Enter()
// self._grad_context = WhileContext(
// maximum_iterations=forward_ctxt.maximum_iterations,
// parallel_iterations=forward_ctxt.parallel_iterations,
// back_prop=forward_ctxt.back_prop,
// swap_memory=forward_ctxt.swap_memory,
// name=forward_ctxt.name,
// grad_state=self)
// self._grad_index = self._grad_context.AddBackpropLoopCounter(
// cnt, outer_grad_state)
// if outer_forward_ctxt:
// outer_forward_ctxt.Exit()
// @property
// def outer_grad_state(self):
// """The grad loop state for outer loop."""
// return self._outer_grad_state
Dictionary<Operation, Tensor> _switch_map = new Dictionary<Operation, Tensor>();
public Dictionary<Operation, Tensor> switch_map => _switch_map;
// @property
// def forward_context(self):
// """The while loop context for forward."""
// return self._forward_context
// @property
// def forward_index(self):
// """The loop index of forward loop."""
// return self._forward_index
// @property
// def forward_sync(self):
// """A control trigger node for synchronization in the forward loop.
// One main use is to keep the push ops of a stack executed in the
// iteration order.
// """
// if self._forward_sync is None:
// with ops.control_dependencies(None):
// self._forward_sync = control_trigger(name="f_sync")
// self._forward_sync._set_control_flow_context(self._forward_context)
// self._forward_index.op._add_control_input(self._forward_sync)
// return self._forward_sync
// @property
// def grad_context(self):
// """The corresponding WhileContext for gradient."""
// return self._grad_context
// @property
// def grad_index(self):
// """The loop index of backprop loop."""
// return self._grad_index
// @property
// def grad_sync(self):
// """A control trigger node for synchronization in the grad loop.
/// <summary>
/// The while loop context for forward.
/// </summary>
WhileContext _forward_context;
public WhileContext forward_context => _forward_context;
// One main use is to keep the pop ops of a stack executed in the
// iteration order.
// """
// if self._grad_sync is None:
// with ops.control_dependencies(None):
// self._grad_sync = control_trigger(name="b_sync")
// self._grad_sync._set_control_flow_context(self._grad_context)
// self._grad_index.op._add_control_input(self._grad_sync)
// if self._grad_context.outer_context:
// self._grad_context.outer_context.AddInnerOp(self._grad_sync)
// return self._grad_sync
/// <summary>
/// The grad loop state for the outer while loop.
/// </summary>
GradLoopState _outer_grad_state;
public GradLoopState outer_grad_state => _outer_grad_state;
// @property
// def history_map(self):
// """The map that records all the tensors needed for backprop."""
// return self._history_map
Tensor _forward_index;
public Tensor forward_index => _forward_index;
Tensor _grad_index;
// @property
// def switch_map(self):
// """The map that records all the Switch ops for the while loop."""
// return self._switch_map
Tensor[] _forward_loop_exits;
/// <summary>
/// The list of exits of the forward loop.
/// </summary>
public Tensor[] forward_loop_exits => _forward_loop_exits;
// @property
// def unused_exits(self):
// """The list of "unused" exits."""
// return self._unused_exits
List<Tensor> _deferred_exits;
public List<Tensor> deferred_exits => _deferred_exits;
// @property
// def deferred_exits(self):
// """The list of "deferred" exits."""
// return self._deferred_exits
List<Tensor> _unused_exits;
public List<Tensor> unused_exits => _unused_exits;
// @property
// def forward_loop_exits(self):
// """The list of exits of the forward loop."""
// return self._forward_loop_exits
/// <summary>
/// The number of exits we expect to see but haven't.
/// </summary>
public int pending_exits_count { get; set; }
// @property
// def pending_exits_count(self):
// """The number of exits we expect to see but haven't."""
// return self._pending_exits_count
// Lazily-created control trigger for the backprop loop (see grad_sync).
Operation _grad_sync;
/// <summary>
/// A control trigger node for synchronization in the grad loop.
/// One main use is to keep the pop ops of a stack executed in the
/// iteration order. Created lazily on first access.
/// </summary>
public Operation grad_sync
{
get
{
if(_grad_sync == null)
{
// Create the trigger outside any ambient control dependencies so it
// is not accidentally ordered after unrelated ops.
tf_with(ops.control_dependencies(null), delegate
{
_grad_sync = gen_control_flow_ops.control_trigger(name: "b_sync");
});
_grad_sync._set_control_flow_context(_grad_context);
// Order the trigger before the backprop loop counter.
_grad_index.op._add_control_input(_grad_sync);
if (_grad_context.outer_context != null)
_grad_context.outer_context.AddInnerOp(_grad_sync);
}
return _grad_sync;
}
}
// @pending_exits_count.setter
// def pending_exits_count(self, cnt):
// """Set the pending count to cnt."""
// self._pending_exits_count = cnt
/// <summary>
/// Build the grad loop state for a forward while loop: records the
/// forward context/exits, adds the forward loop counter, then creates the
/// backprop WhileContext and its loop counter.
/// </summary>
/// <param name="forward_ctxt">The forward while loop context.</param>
/// <param name="outer_grad_state_">Grad state of the enclosing loop, or null at the outermost level.</param>
public GradLoopState(WhileContext forward_ctxt, GradLoopState outer_grad_state_)
{
// Information needed by backprop.
_unused_exits = new List<Tensor>();
_deferred_exits = new List<Tensor>();
_forward_loop_exits = list(forward_ctxt.loop_exits);
pending_exits_count = len(forward_ctxt.loop_exits);
_outer_grad_state = outer_grad_state_;
ControlFlowContext outer_forward_ctxt = null;
if (outer_grad_state_ != null)
outer_forward_ctxt = outer_grad_state_.forward_context;
// Add the forward loop counter.
// with forward_ctxt._graph.as_default():
// NOTE(review): the Python original enters the forward graph here; this
// port relies on the current default graph — TODO confirm equivalence.
Tensor cnt, forward_index;
{
// Counter must be created in the outer context, if any.
if (outer_forward_ctxt != null)
outer_forward_ctxt.Enter();
(cnt, forward_index) = forward_ctxt.AddForwardLoopCounter(outer_grad_state);
if (outer_forward_ctxt != null)
outer_forward_ctxt.Exit();
}
_forward_context = forward_ctxt;
_forward_index = forward_index;
// Add the backprop WhileContext, and the backprop loop counter.
if (outer_grad_state != null)
{
// This is a nested loop. Remember the iteration counts for each
// execution of this inner loop.
// (Nested-loop path of the Python original is not ported yet.)
throw new NotImplementedException("GradLoopState");
}
else
{
// Outermost loop: mirror the forward loop's attributes on the
// backprop WhileContext.
if (outer_forward_ctxt != null)
outer_forward_ctxt.Enter();
_grad_context = new WhileContext(
maximum_iterations: forward_ctxt.maximum_iterations,
parallel_iterations: forward_ctxt.parallel_iterations,
back_prop: forward_ctxt.back_prop,
swap_memory: forward_ctxt.swap_memory,
name: forward_ctxt.name,
grad_state: this);
_grad_index = _grad_context.AddBackpropLoopCounter(cnt, outer_grad_state);
if (outer_forward_ctxt != null)
outer_forward_ctxt.Exit();
}
}
/// <summary> /// <summary>
/// Add an accumulator for each forward tensor that is needed in backprop. /// Add an accumulator for each forward tensor that is needed in backprop.
@@ -242,63 +175,52 @@ namespace Tensorflow.Operations.ControlFlows
/// <returns>The stack that contains the accumulated history of the tensor.</returns> /// <returns>The stack that contains the accumulated history of the tensor.</returns>
// NOTE(review): the duplicated signature/braces on the next two lines and
// the stray `throw` after them appear to be artifacts of the rendered diff
// (old and new columns merged) — confirm against the repository file.
public Tensor AddForwardAccumulator(Tensor value, bool dead_branch = false) public Tensor AddForwardAccumulator(Tensor value, bool dead_branch = false)
{ {
throw new NotImplementedException("AddForwardAccumulator");
// Reference: the Python original this port follows.
// # curr_ctxt is the context that tf.gradients was called in.
// with self._forward_index.graph.as_default():
// curr_ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access
// with ops.control_dependencies(None):
// if curr_ctxt:
// curr_ctxt.Enter()
// with ops.colocate_with(value):
// # We only need to pass maximum_iterations to the stack if
// # we're inside an XLA context.
// if not util.IsInXLAContext(value.op):
// max_size = constant_op.constant(-1, dtypes.int32)
// else:
// max_size = GetMaxSizeFromNestedMaximumIterations(
// value, self.forward_context)
// acc = gen_data_flow_ops.stack_v2(
// max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc")
// if curr_ctxt:
// curr_ctxt.Exit()
// # Make acc available in the forward context.
// enter_acc = self.forward_context.AddValue(acc)
// # Add the stack_push op in the context of value.op.
// swap_enabled = self.forward_context.swap_memory
// value_ctxt = util.GetOutputContext(value.op)
// if value_ctxt == self.forward_context:
// # value is not nested in the forward context.
// self.forward_context.Enter()
// push = gen_data_flow_ops.stack_push_v2(
// enter_acc, value, swap_memory=swap_enabled)
// self.forward_context.Exit()
// # Protect stack push and order it before forward_index.
// self.forward_index.op._add_control_input(push.op)
// else:
// # value is in a cond context within the forward context.
// if not isinstance(value_ctxt, CondContext):
// raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt)
// if dead_branch:
// # The special case for creating a zero tensor for a dead
// # branch of a switch. See ControlFlowState.ZerosLike().
// value_ctxt.outer_context.Enter()
// push = gen_data_flow_ops.stack_push_v2(
// enter_acc, value, swap_memory=swap_enabled)
// value_ctxt.outer_context.Exit()
// push.op._set_control_flow_context(value_ctxt)
// else:
// value_ctxt.Enter()
// push = gen_data_flow_ops.stack_push_v2(
// enter_acc, value, swap_memory=swap_enabled)
// value_ctxt.Exit()
// # Protect stack push and order it before forward_sync.
// self.forward_sync._add_control_input(push.op)
// # Order stack push after the successor of forward_index
// add_op = self.forward_index.op.inputs[0].op
// push.op._add_control_input(add_op)
// return acc
// NOTE(review): the return values of as_default() and colocate_with() are
// discarded here, unlike the Python `with` blocks above — TODO confirm
// they take effect without a tf_with/using scope.
_forward_index.graph.as_default();
{
// curr_ctxt is the context that tf.gradients was called in.
var curr_ctxt = ops.get_default_graph()._get_control_flow_context();
return tf_with(ops.control_dependencies(null), delegate
{
Tensor acc = null;
Tensor push = null;
if (curr_ctxt != null)
curr_ctxt.Enter();
ops.colocate_with(value);
{
// We only need to pass maximum_iterations to the stack if
// we're inside an XLA context.
// (XLA branch of the original is not ported; -1 means unbounded.)
var max_size = constant_op.constant(-1, dtypes.int32);
acc = gen_data_flow_ops.stack_v2(
max_size: max_size, elem_type: value.dtype.as_base_dtype(), name: "f_acc");
}
if (curr_ctxt != null)
curr_ctxt.Exit();
// Make acc available in the forward context.
var enter_acc = forward_context.AddValue(acc);
// Add the stack_push op in the context of value.op.
var swap_enabled = forward_context.swap_memory;
var value_ctxt = util.GetOutputContext(value.op);
if(value_ctxt == forward_context)
{
// value is not nested in the forward context.
forward_context.Enter();
push = gen_data_flow_ops.stack_push_v2(enter_acc, value, swap_memory: swap_enabled);
forward_context.Exit();
// Protect stack push and order it before forward_index.
forward_index.op._add_control_input(push.op);
}
else
{
// value lives in a cond context within the forward context;
// that path (see the Python reference above) is not ported yet.
throw new NotImplementedException("AddForwardAccumulator");
}
// Order stack push after the successor of forward_index
var add_op = forward_index.op.inputs[0].op;
push.op._add_control_input(add_op);
return acc;
});
}
} }
// """Add the getter for an accumulated value in the grad context. // """Add the getter for an accumulated value in the grad context.
@@ -315,98 +237,99 @@ namespace Tensorflow.Operations.ControlFlows
// Returns: // Returns:
// The current value (the top of the stack). // The current value (the top of the stack).
// """ // """
// NOTE(review): the duplicated signature/braces and the bare `throw` below
// appear to be artifacts of the rendered diff (old/new columns merged) —
// confirm against the repository file.
public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bool dead_branch= false) public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bool dead_branch= false)
{ {
throw new NotImplementedException();
// Reference: the Python original this port follows.
// history_ctxt = history_value.op._get_control_flow_context()
// # Find the cond context that controls history_value if any.
// cond_ctxt = None
// value_ctxt = value.op._get_control_flow_context()
// while value_ctxt and value_ctxt != history_ctxt:
// if isinstance(value_ctxt, CondContext):
// cond_ctxt = value_ctxt
// break
// value_ctxt = value_ctxt.outer_context
// with ops.control_dependencies(None):
// self.grad_context.Enter()
// if cond_ctxt:
// # Guard stack pop with a switch if it is controlled by a cond.
// grad_state = self
// pred = None
// while pred is None and grad_state:
// pred = grad_state.history_map.get(cond_ctxt.pred.name)
// grad_state = grad_state.outer_grad_state
// if pred is None:
// pred = cond_ctxt.pred
// branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch
// history_value = _SwitchRefOrTensor(history_value, pred)[branch]
// pop = gen_data_flow_ops.stack_pop_v2(history_value,
// value.dtype.base_dtype)
// pop.set_shape(value.get_shape())
// self.grad_context.Exit()
// parallel_iterations = self.grad_context.parallel_iterations
// if parallel_iterations > 1:
// # All pops are ordered after pivot_for_body and before grad_sync.
// self.grad_sync._add_control_input(pop.op)
// return pop
var history_ctxt = history_value.op._get_control_flow_context();
// Find the cond context that controls history_value if any.
CondContext cond_ctxt = null;
Tensor pop = null;
var value_ctxt = value.op._get_control_flow_context();
// NOTE(review): the Python reference above breaks at the first (innermost)
// CondContext; this loop keeps walking and so keeps the outermost one —
// TODO confirm intent (currently harmless: cond_ctxt != null only throws).
while(value_ctxt != null && value_ctxt != history_ctxt)
{
if (value_ctxt is CondContext cc)
cond_ctxt = cc;
value_ctxt = value_ctxt.outer_context;
}
tf_with(ops.control_dependencies(null), delegate
{
grad_context.Enter();
if(cond_ctxt != null)
{
// Guarding the pop with a switch (cond-controlled history) is not ported.
throw new NotImplementedException("AddBackpropAccumulatedValue");
}
pop = gen_data_flow_ops.stack_pop_v2(history_value, value.dtype.as_base_dtype());
pop.set_shape(value.TensorShape);
grad_context.Exit();
});
var parallel_iterations = grad_context.parallel_iterations;
if (parallel_iterations > 1)
// All pops are ordered after pivot_for_body and before grad_sync.
grad_sync._add_control_input(pop.op);
return pop;
} }
// def GetRealValue(self, value):
// """Get the real value of `value`.
// If backprop "uses" a value produced by forward inference, an accumulator
// is added in the forward loop to accumulate its values. We use the
// accumulated value. This method must be called in the grad loop context.
// `value` must be in forward and needed for backprop.
// Args:
// value: A tensor to be captured.
// Returns:
// The same tensor obtained from the saved history.
// """
// assert value.op.type not in ["Variable", "VariableV2"]
// real_value = self._history_map.get(value.name)
// if real_value is None:
// cur_value = value
// cur_grad_state = self
// while True:
// enter_op = util.GetLoopConstantEnter(cur_value)
// if enter_op:
// # Special case: cur_value comes from a constant Enter node.
// cur_value = enter_op.inputs[0]
// cur_grad_state = cur_grad_state.outer_grad_state
// if cur_grad_state is None:
// # We are now outside all nested loops for this gradient(),
// # so `value` is a loop invariant and there is no need to
// # save the history of value. Just make cur_value to enter
// # the right control flow context.
// real_value = self._grad_context.AddValue(cur_value)
// break
// elif constant_op.is_constant(cur_value):
// # If the value to be forwarded is a constant, clone the constant in
// # the gradient loop rather than using a stack.
// # TODO(phawkins): consider hoisting the constant out of the loop
// # instead.
// real_value = constant_op.constant(
// tensor_util.constant_value(cur_value), dtype=cur_value.dtype)
// break
// else:
// # Record the history of this value in forward_ctxt.
// self._grad_context.Exit()
// history_value = cur_grad_state.AddForwardAccumulator(cur_value)
// self._grad_context.Enter()
// break
// if real_value is None:
// # Add the stack pop op in the grad context.
// real_value = cur_grad_state.AddBackpropAccumulatedValue(
// history_value, cur_value)
// if cur_grad_state != self:
// real_value = self._grad_context.AddValue(real_value)
// self._history_map[value.name] = real_value
// return real_value
/// <summary>
/// Get the real value of `value`.
/// If backprop "uses" a value produced by forward inference, an accumulator
/// is added in the forward loop to accumulate its values. We use the
/// accumulated value. This method must be called in the grad loop context.
/// `value` must be in forward and needed for backprop.
/// </summary>
/// <param name="value">A tensor to be captured.</param>
/// <returns>The same tensor obtained from the saved history.</returns>
public Tensor GetRealValue(Tensor value)
{
    // Fix: consult the history cache first. The previous code initialized
    // real_value to null unconditionally, so the `_history_map` entries
    // written at the bottom were never reused and a fresh accumulator was
    // built on every call (the Python original does
    // `real_value = self._history_map.get(value.name)`).
    var real_value = _history_map[value.name] as Tensor;
    if (real_value == null)
    {
        var cur_value = value;
        var cur_grad_state = this;
        Tensor history_value = null;
        while (true)
        {
            var enter_op = util.GetLoopConstantEnter(cur_value);
            if (enter_op != null)
            {
                // Special case: cur_value comes from a constant Enter node.
                cur_value = enter_op.inputs[0];
                cur_grad_state = cur_grad_state.outer_grad_state;
                if (cur_grad_state == null)
                {
                    // We are now outside all nested loops for this gradient(),
                    // so `value` is a loop invariant and there is no need to
                    // save the history of value. Just make cur_value to enter
                    // the right control flow context.
                    real_value = _grad_context.AddValue(cur_value);
                    break;
                }
            }
            else if (constant_op.is_constant(cur_value))
            {
                // If the value to be forwarded is a constant, clone the
                // constant in the gradient loop rather than using a stack.
                real_value = constant_op.constant(
                    tensor_util.constant_value(cur_value), dtype: cur_value.dtype);
                break;
            }
            else
            {
                // Record the history of this value in forward_ctxt.
                _grad_context.Exit();
                history_value = cur_grad_state.AddForwardAccumulator(cur_value);
                _grad_context.Enter();
                break;
            }
        }
        if (real_value == null)
        {
            // Add the stack pop op in the grad context.
            real_value = cur_grad_state.AddBackpropAccumulatedValue(history_value, cur_value);
            if (cur_grad_state != this)
                real_value = _grad_context.AddValue(real_value);
        }
        // Cache so subsequent captures of the same value reuse the history.
        _history_map[value.name] = real_value;
    }
    return real_value;
}
} }
} }

+ 43
- 0
src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs View File

@@ -0,0 +1,43 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Tensorflow.Operations
{
/// <summary>
/// A (loop counter, body item) pair threaded through while_loop, able to
/// flatten itself to a tensor list and to repack from one.
/// </summary>
internal class LoopVar<TItem> : ICanBeFlattened, IPackable<LoopVar<TItem>>
{
    // The loop counter tensor.
    public Tensor Counter { get; set; }
    // The user-supplied loop state.
    public TItem Item { get; set; }

    public LoopVar(Tensor counter, TItem item)
    {
        Counter = counter;
        Item = item;
    }

    /// <summary>
    /// Flatten to [Counter, ...Item elements] (or [Counter, Item] when the
    /// item is not itself flattenable).
    /// </summary>
    public object[] Flatten()
    {
        var elements = new List<object> { Counter };
        // `is` pattern instead of typeof-interface check: also avoids a
        // NullReferenceException when Item is null.
        if (Item is ICanBeFlattened flattenable)
            elements.AddRange(flattenable.Flatten());
        else
            elements.Add(Item);
        return elements.ToArray();
    }

    /// <summary>
    /// Rebuild a LoopVar from a flat sequence produced by Flatten.
    /// </summary>
    public LoopVar<TItem> Pack(object[] sequences)
    {
        var counter = sequences[0] as Tensor;
        var item = default(TItem);
        if (Item is IPackable<TItem> packable)
            item = packable.Pack(sequences.Skip(1).ToArray());
        // Fix: the original left `item` as default(TItem) for non-packable
        // items, silently dropping the value that Flatten() emitted.
        else if (sequences.Length > 1 && sequences[1] is TItem plain)
            item = plain;
        return new LoopVar<TItem>(counter, item);
    }

    public static implicit operator (Tensor, TItem)(LoopVar<TItem> loopVar)
        => (loopVar.Counter, loopVar.Item);
}
}

+ 36
- 0
src/TensorFlowNET.Core/Operations/ControlFlows/MergeOutput.cs View File

@@ -0,0 +1,36 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Operations
{
/// <summary>
/// The pair of tensors returned by a merge op: the merged output and the
/// index of the input that produced it.
/// </summary>
public class MergeOutput
{
    // values[0] from the merge op.
    Tensor output;
    // values[1] from the merge op.
    Tensor value_index;

    public MergeOutput(Tensor[] values)
    {
        output = values[0];
        value_index = values[1];
    }

    /// <summary>
    /// Tuple-style access: [0] = output, [1] = value_index, anything else = null.
    /// </summary>
    public Tensor this[int idx]
    {
        get
        {
            if (idx == 0)
                return output;
            if (idx == 1)
                return value_index;
            return null;
        }
    }

    // Implicitly decays to the merged output tensor.
    public static implicit operator Tensor(MergeOutput merge)
        => merge.output;
}
}

+ 445
- 32
src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs View File

@@ -32,17 +32,22 @@ namespace Tensorflow.Operations
bool _back_prop=true; bool _back_prop=true;
GradLoopState _grad_state =null; GradLoopState _grad_state =null;
Tensor _maximum_iterations; Tensor _maximum_iterations;
public Tensor maximum_iterations => _maximum_iterations;
int _parallel_iterations; int _parallel_iterations;
public int parallel_iterations => _parallel_iterations;
bool _swap_memory; bool _swap_memory;
public bool swap_memory => _swap_memory;
Tensor _pivot_for_pred; Tensor _pivot_for_pred;
Tensor _pivot_for_body; Tensor _pivot_for_body;
List<Tensor> _loop_exits; List<Tensor> _loop_exits;
public List<Tensor> loop_exits => _loop_exits;
List<Tensor> _loop_enters; List<Tensor> _loop_enters;
public List<Tensor> loop_enters => _loop_enters;
Graph _graph; Graph _graph;
public override GradLoopState grad_state => _grad_state; public override GradLoopState grad_state => _grad_state;
public override bool back_prop => _back_prop; public override bool back_prop => _back_prop;


public WhileContext(int? maximum_iterations = null,
public WhileContext(Tensor maximum_iterations = null,
int parallel_iterations = 10, int parallel_iterations = 10,
bool back_prop = true, bool back_prop = true,
bool swap_memory = false, bool swap_memory = false,
@@ -64,13 +69,15 @@ namespace Tensorflow.Operations
_grad_state = grad_state; _grad_state = grad_state;
} }


private void _init_from_args(int? maximum_iterations,
private void _init_from_args(Tensor maximum_iterations,
int parallel_iterations, int parallel_iterations,
bool back_prop, bool back_prop,
bool swap_memory, bool swap_memory,
string name) string name)
{ {
_name = ops.get_default_graph().unique_name(name); _name = ops.get_default_graph().unique_name(name);
_maximum_iterations = maximum_iterations;
_parallel_iterations = parallel_iterations;
_back_prop = back_prop; _back_prop = back_prop;
_swap_memory = swap_memory; _swap_memory = swap_memory;
_loop_exits = new List<Tensor>(); _loop_exits = new List<Tensor>();
@@ -107,37 +114,75 @@ namespace Tensorflow.Operations
/// <summary> /// <summary>
/// Add the loop termination condition and body to the graph. /// Add the loop termination condition and body to the graph.
/// </summary> /// </summary>
public Tensor[] BuildLoop(Func<Tensor, Tensor> pred,
Func<Tensor, Tensor> body,
Tensor[] loop_vars,
TensorShape shape_invariants,
internal LoopVar<TItem> BuildLoop<TItem>(Func<LoopVar<TItem>, Tensor> pred,
Func<LoopVar<TItem>, LoopVar<TItem>> body,
LoopVar<TItem> loop_vars,
TensorShape[] shape_invariants,
bool return_same_structure) bool return_same_structure)
{ {
// Keep original_loop_vars to identify which are TensorArrays // Keep original_loop_vars to identify which are TensorArrays
var original_loop_vars = loop_vars; var original_loop_vars = loop_vars;
// Convert TensorArrays to their flow variables // Convert TensorArrays to their flow variables
var loop_vars_tensors = nest.flatten2(loop_vars)
.Select(x => _convert_tensorarray_to_flow(x))
.ToArray();

if (shape_invariants == null)
shape_invariants = loop_vars_tensors
.Select(x => _get_shape_invariant(x as Tensor))
.ToArray();

Enter(); Enter();
var(original_body_result, exit_vars) = _BuildLoop( var(original_body_result, exit_vars) = _BuildLoop(
pred, body, original_loop_vars, loop_vars, shape_invariants);
pred, body, original_loop_vars, loop_vars_tensors, shape_invariants);
Exit(); Exit();


var flat_result = original_body_result;
var flat_result = nest.flatten2(original_body_result)
.Select(x => x as ITensorOrTensorArray)
.ToArray();


var exit_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_result, exit_vars); var exit_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_result, exit_vars);
var packed_exit_vars = nest.pack_sequence_as(
var packed_exit_vars = nest.pack_sequence_as2(
structure: original_body_result, structure: original_body_result,
flat_sequence: exit_vars_with_tensor_arrays); flat_sequence: exit_vars_with_tensor_arrays);


return packed_exit_vars as Tensor[];
return packed_exit_vars;
} }


private (Tensor[], Tensor[]) _BuildLoop(Func<Tensor, Tensor> pred,
Func<Tensor, Tensor> body,
Tensor[] original_loop_vars,
/// <summary>
/// Maps a loop variable to the plain Tensor used by the while-loop plumbing:
/// a TensorArray is replaced by its flow tensor, a Tensor passes through unchanged.
/// </summary>
private Tensor _convert_tensorarray_to_flow(object tensor_or_tensor_array)
{
    switch (tensor_or_tensor_array)
    {
        case TensorArray ta:
            return ta.flow;
        case Tensor t:
            return t;
        default:
            throw new NotImplementedException("_convert_tensorarray_to_flow");
    }
}

/// <summary>
/// Returns the shape invariant for a loop variable. The optional
/// <paramref name="shape"/> override is currently ignored.
/// </summary>
private TensorShape _get_shape_invariant(Tensor var, int[] shape = null)
    => var.TensorShape;

/// <summary>
/// Add the loop termination condition and body to the graph.
/// </summary>
/// <typeparam name="TItem"></typeparam>
/// <param name="pred"></param>
/// <param name="body"></param>
/// <param name="original_loop_vars"></param>
/// <param name="loop_vars"></param>
/// <param name="shape_invariants"></param>
/// <returns></returns>
private (LoopVar<TItem>, Tensor[]) _BuildLoop<TItem>(Func<LoopVar<TItem>, Tensor> pred,
Func<LoopVar<TItem>, LoopVar<TItem>> body,
LoopVar<TItem> original_loop_vars,
Tensor[] loop_vars, Tensor[] loop_vars,
TensorShape shape_invariants)
TensorShape[] shape_invariants)
{ {
var flat_loop_vars = original_loop_vars;
var flat_loop_vars = nest.flatten2(original_loop_vars)
.Select(x => (ITensorOrTensorArray)x)
.ToArray();


// Let the context know the loop variables so the loop variables // Let the context know the loop variables so the loop variables
// would be added in the outer contexts properly. // would be added in the outer contexts properly.
@@ -146,14 +191,14 @@ namespace Tensorflow.Operations
Tensor[] enter_vars = null; Tensor[] enter_vars = null;
tf_with(ops.control_dependencies(null), delegate tf_with(ops.control_dependencies(null), delegate
{ {
enter_vars = real_vars.Select(x => _Enter(x,
enter_vars = real_vars.Select(x => control_flow_ops._Enter(x,
_name, _name,
is_constant: false, is_constant: false,
parallel_iterations: _parallel_iterations, parallel_iterations: _parallel_iterations,
use_input_shape: shape_invariants == null)) use_input_shape: shape_invariants == null))
.ToArray(); .ToArray();


foreach(var x in enter_vars)
foreach (var x in enter_vars)
{ {
x.graph.prevent_feeding(x); x.graph.prevent_feeding(x);
if (_outer_context != null) if (_outer_context != null)
@@ -163,7 +208,13 @@ namespace Tensorflow.Operations


// Finds the closest enclosing non-None control pivot. // Finds the closest enclosing non-None control pivot.
var outer_context = _outer_context; var outer_context = _outer_context;
while (outer_context != null)
object control_pivot = null;
while (outer_context != null && control_pivot == null)
{

}

if (control_pivot != null)
{ {


} }
@@ -177,31 +228,42 @@ namespace Tensorflow.Operations


var merge_vars = enter_vars var merge_vars = enter_vars
.Select(x => merge(new[] { x, x })) .Select(x => merge(new[] { x, x }))
.Select(m => (Tensor)m)
.ToArray(); .ToArray();


_pivot_for_pred = merge_vars[0]; _pivot_for_pred = merge_vars[0];


// Build the graph for pred. // Build the graph for pred.
var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars); var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars);
// var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays);
var c = ops.convert_to_tensor(pred(merge_vars_with_tensor_arrays[0]));
//var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays, expand_composites: true);
var packed_vars = new LoopVar<TItem>((Tensor)merge_vars_with_tensor_arrays[0],
(TItem)(object)new BodyItemInRnnWhileLoop((Tensor)merge_vars_with_tensor_arrays[1],
new[] { (TensorArray)merge_vars_with_tensor_arrays[2] },
(Tensor)merge_vars_with_tensor_arrays[3]));
var pp = pred(packed_vars);
var c = ops.convert_to_tensor(pp);
_pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond"); _pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond");
var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot)) var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot))
.ToArray(); .ToArray();


// Build the graph for body. // Build the graph for body.
var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray(); var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray();
_pivot_for_body = vars_for_body[0];
// Convert TensorArray flow variables inside the context back into // Convert TensorArray flow variables inside the context back into
// their associated TensorArrays for calling the body. // their associated TensorArrays for calling the body.
var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body);
var body_result = body(packed_vars_for_body[0]);
var vars_for_body_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body);
var packed_vars_for_body = nest.pack_sequence_as2(original_loop_vars, vars_for_body_with_tensor_arrays);
var pre_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);
var body_result = body(packed_vars_for_body);
var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);


// Store body_result to keep track of TensorArrays returned by body // Store body_result to keep track of TensorArrays returned by body
var original_body_result = new[] { body_result };
var original_body_result = body_result;
// Convert TensorArrays returned by body into their flow variables // Convert TensorArrays returned by body into their flow variables
var result = new[] { body_result };

var result = nest.flatten2(body_result)
.Select(x => _convert_tensorarray_to_flow(x))
.ToArray();
// result = ops.convert_n_to_tensor_or_composite(result);
var next_vars = new List<Tensor>(); var next_vars = new List<Tensor>();
foreach (var (m, v) in zip(merge_vars, result)) foreach (var (m, v) in zip(merge_vars, result))
next_vars.Add(_AddNextAndBackEdge(m, v)); next_vars.Add(_AddNextAndBackEdge(m, v));
@@ -218,20 +280,45 @@ namespace Tensorflow.Operations
private void _FixControlInputsAndContext(Tensor[] enters) private void _FixControlInputsAndContext(Tensor[] enters)
{ {
var graph = ops.get_default_graph(); var graph = ops.get_default_graph();
foreach(var e in enters)
foreach(var x in enters)
{ {
var inp_op = e.op.inputs[0].op;
var inp_op = x.op.inputs[0].op;
var control_inputs = graph._control_dependencies_for_inputs(new[] { inp_op }); var control_inputs = graph._control_dependencies_for_inputs(new[] { inp_op });
var outer_control_inputs = new List<Operation>();
foreach(Operation op in control_inputs)
{
// We need to keep control inputs that are in any ancestor
// ControlFlowContext, and within outer WhileContext.
var keep_as_control_input = true;
var op_ctxt = control_flow_util.GetOutputContext(op);
var outer_ctxt = outer_context;
var outer_while_context = outer_ctxt == null ? null : outer_ctxt.GetWhileContext();
while (outer_ctxt != op_ctxt)
{
if (outer_ctxt == null || outer_ctxt == outer_while_context)
{
keep_as_control_input = false;
break;
}
outer_ctxt = outer_ctxt.outer_context;
}
if (keep_as_control_input)
outer_control_inputs.append(op);
}
// op for op in control_inputs if self._IsInOuterContext(op) // op for op in control_inputs if self._IsInOuterContext(op)
var outer_control_inputs = control_inputs.Where(x => _IsInOuterContext(x.op))
/*var outer_control_inputs = control_inputs.Where(x => _IsInOuterContext(x.op))
.Select(x => x.op) .Select(x => x.op)
.ToArray();
e.op._set_control_flow_context(this);
e.op._add_control_inputs(outer_control_inputs);
graph._record_op_seen_by_control_dependencies(e.op);
.ToArray();*/
x.op._set_control_flow_context(this);
x.op._add_control_inputs(outer_control_inputs.ToArray());
graph._record_op_seen_by_control_dependencies(x.op);
} }
} }


/// <summary>
/// Makes the values known to this context.
/// </summary>
/// <param name="values"></param>
private void _InitializeValues(Tensor[] values) private void _InitializeValues(Tensor[] values)
{ {
_values = new HashSet<string>(); _values = new HashSet<string>();
@@ -239,6 +326,332 @@ namespace Tensorflow.Operations
_values.Add(x.name); _values.Add(x.name);
} }


/// <summary>
/// Adds <paramref name="op"/> to this while context: routes its inputs through
/// the loop, removes external control edges, and anchors otherwise-free ops
/// to the control pivot so they only execute when the loop does.
/// </summary>
/// <param name="op">The operation being added to the graph inside this loop.</param>
protected override void _AddOpInternal(Operation op)
{
    // (Removed leftover empty debugging block that matched a hard-coded op name.)
    Operation[] external_inputs = new Operation[0];
    Operation[] control_inputs = new Operation[0];
    if (op.inputs.Length == 0)
    {
        // Remove any external control dependency on this op.
        (control_inputs, external_inputs) = _RemoveExternalControlEdges(op);
        // No data inputs and no remaining control inputs: anchor to the pivot
        // so the op cannot float free of the loop.
        if (control_inputs.Length == 0)
            op._add_control_input(GetControlPivot().op);
        foreach (var x in op.outputs)
            _values.Add(x.name);
    }
    else
    {
        // Route every input through this context: AddValue returns either the
        // tensor itself or an Enter-wrapped version known to this loop.
        foreach (var index in range(len(op.inputs)))
        {
            var x = op.inputs[index];
            var real_x = AddValue(x);
            if (real_x != x)
                op._update_input(index, real_x);
        }

        // Remove any external control dependency on this op.
        (_, external_inputs) = _RemoveExternalControlEdges(op);
        // Add a control dependency to prevent loop invariants from
        // enabling ops that should not be executed.
        _MaybeAddControlDependency(op);
        foreach (Tensor x in op.outputs)
            _values.Add(x.name);
    }

    if (external_inputs.Length > 0)
    {
        // NOTE(review): re-attaching external control edges is not ported yet.
        throw new NotImplementedException("external_inputs.Length > 0");
    }

    // Outputs produced inside a loop (other than top-level loop exits)
    // cannot be fed directly.
    if (_outer_context != null || !IsLoopExit(op))
        foreach (Tensor x in op.outputs)
            op.graph.prevent_feeding(x);

    if (_outer_context != null)
        _outer_context.AddInnerOp(op);
}

/// <summary>
/// Adds a control dependency from the pivot to <paramref name="op"/> when
/// nothing already sequences it inside the loop.
/// </summary>
protected void _MaybeAddControlDependency(Operation op)
{
    // An op is "free" when it has no control inputs and every data input
    // is a loop-constant Enter (SymbolicGradient is treated as free).
    bool IsOpFree(Operation candidate)
    {
        if (candidate.control_inputs.Length > 0)
            return false;

        if (candidate.type == "SymbolicGradient")
            return true;

        foreach (Tensor input in candidate.inputs)
        {
            if (!control_flow_util.IsLoopConstantEnter(input.op))
                return false;
        }

        return true;
    }

    // Anchor free ops to the pivot so they only run when the loop runs.
    if (IsOpFree(op))
        op._add_control_input(GetControlPivot().op);
}

/// <summary>
/// Returns the pivot tensor currently governing op placement: the body pivot
/// once the body is being built, otherwise the pred pivot.
/// </summary>
private Tensor GetControlPivot()
    => _pivot_for_body ?? _pivot_for_pred;

/// <summary>
/// Adds <paramref name="op"/> to the current while context.
/// </summary>
public override void AddOp(Operation op)
    => _AddOpInternal(op);

/// <summary>
/// Adds a loop that counts the number of iterations.
/// </summary>
/// <param name="outer_grad_state">The outer grad state. None if not nested.</param>
/// <returns>The number of iterations taken by the forward loop and the loop index.</returns>
public (Tensor, Tensor) AddForwardLoopCounter(GradLoopState outer_grad_state)
{
    // Counter starts at zero and is incremented once per forward iteration.
    var n = constant_op.constant(0, name: "f_count");
    if (outer_grad_state != null)
        // Nested loops are not ported yet.
        throw new NotImplementedException("AddForwardLoopCounter");

    Enter();
    AddName(n.name);
    // Standard while-loop plumbing: Enter -> Merge -> Switch -> body -> NextIteration.
    var enter_n = _Enter(n,
        _name,
        is_constant: false,
        parallel_iterations: _parallel_iterations,
        name: "f_count");
    _loop_enters.Add(enter_n);

    var m1 = merge(new[] { enter_n, enter_n });
    var merge_n = m1[0];
    var switch_n = @switch (merge_n, _pivot);

    // Body: n = n + 1, fed back through the Merge's second input.
    var index = math_ops.add(switch_n[1], 1);
    var next_n = _NextIteration(index);
    merge_n.op._update_input(1, next_n);

    // The false branch of the Switch exits the loop with the final count.
    var total_iterations = exit(switch_n[0], name: "f_count");
    loop_exits.append(total_iterations);
    ExitResult(new[] { total_iterations });
    Exit();

    return (total_iterations, next_n);
}

/// <summary>
/// Add an accumulation loop for every loop invariant.
///
/// Builds acc = 0; then, inside the backprop loop, acc += grad each
/// iteration; the Exit value of acc is the total gradient for the invariant.
/// </summary>
/// <param name="op">The Enter op for a loop invariant.</param>
/// <param name="grad">The partial gradient of an iteration for a loop invariant.</param>
/// <returns>The gradient for a loop invariant.</returns>
public Tensor AddBackpropAccumulator(Operation op, Tensor grad)
{
    Tensor acc = null;
    Exit();
    // Create a zeros tensor with the right shape for acc. If we don't
    // know the full shape statically, we will have to get the shape
    // dynamically from the forward inference. Getting the shape right
    // for the zeros is only needed for the base case when the loop exits
    // without running any iterations.
    var shape = grad.TensorShape;
    if (shape.is_fully_defined())
    {
        // Static shape: create the zero accumulator in the outer context
        // so it lives outside this gradient loop.
        if (outer_context != null)
            outer_context.Enter();
        acc = constant_op.constant(0, grad.dtype, shape: shape, name: "b_acc");
        if (outer_context != null)
            outer_context.Exit();
    }
    else
    {
        var value = op.inputs[0];
        if (outer_context is WhileContext wc)
        {
            // We are in a nested while loop.
            // Recover the dynamic shape from the forward loop's history.
            var forward_ctxt = grad_state.forward_context;
            forward_ctxt.outer_context.Enter();
            var zeros_shape = array_ops.shape_internal(value, optimize: false);
            forward_ctxt.outer_context.Exit();
            var outer_grad_state = grad_state.outer_grad_state;
            var history_zeros_shape = outer_grad_state.AddForwardAccumulator(zeros_shape);
            outer_context.Enter();
            var real_shape = outer_grad_state.AddBackpropAccumulatedValue(
                history_zeros_shape, zeros_shape);
            acc = array_ops.zeros(real_shape, grad.dtype);
            outer_context.Exit();
        }
        else
        {
            // Not nested: take the shape dynamically in the outer context.
            if (outer_context != null)
                outer_context.Enter();
            var zeros_shape = array_ops.shape_internal(value, optimize: false);
            acc = array_ops.zeros(zeros_shape, grad.dtype);
            if (outer_context != null)
                outer_context.Exit();
        }
        // NOTE(review): this branch always throws after building acc, so the
        // dynamic-shape path is effectively unfinished — confirm before relying on it.
        throw new NotImplementedException("AddBackpropAccumulator");
    }

    Enter();
    AddName(acc.name);
    // Wire the accumulator into this loop: Enter -> Merge -> Switch,
    // add grad on the true branch, feed back via NextIteration.
    var enter_acc = _Enter(
        acc,
        _name,
        is_constant: false,
        parallel_iterations: _parallel_iterations,
        name: "b_acc");
    loop_enters.append(enter_acc);
    var merge_acc = merge(new[] { enter_acc, enter_acc }, name: "b_acc")[0];

    var switch_result = @switch(merge_acc, _pivot);
    var (switch_acc_false, switch_acc_true) = (switch_result[0], switch_result[1]);

    var add_acc = math_ops.add(switch_acc_true, grad);
    var next_acc = _NextIteration(add_acc);
    merge_acc.op._update_input(1, next_acc);

    // The false branch carries the final accumulated value out of the loop.
    var result_acc = exit(switch_acc_false, name: "b_acc");
    loop_exits.append(result_acc);
    ExitResult(new[] { result_acc });
    return result_acc;
}

/// <summary>
/// Add the backprop loop that controls the iterations.
///
/// Builds: while (count >= 1) { count -= 1; } so the backprop loop runs the
/// given number of times; its pivot gates every gradient op added to it.
/// </summary>
/// <param name="count">The number of iterations for backprop.</param>
/// <param name="outer_grad_state">The outer grad state. None if not nested.</param>
/// <returns>The loop index.</returns>
public Tensor AddBackpropLoopCounter(Tensor count, GradLoopState outer_grad_state)
{
    Tensor one = null;
    var in_separate_functions = count.graph != ops.get_default_graph();
    if (in_separate_functions)
        // Brings the count into this graph
        count = array_ops.identity(count);
    else
        // Already in the right graph: create the decrement constant eagerly.
        one = constant_op.constant(1, name: "b_count");

    Enter();
    AddName(count.name);
    var enter_count = _Enter(
        count,
        _name,
        is_constant: false,
        parallel_iterations: _parallel_iterations,
        name: "b_count");
    loop_enters.append(enter_count);

    var merge_count = merge(new[] { enter_count, enter_count })[0];
    _pivot_for_pred = merge_count;
    // In the separate-function case the constant must be created inside this context.
    if (in_separate_functions)
        one = constant_op.constant(1, name: "b_count");
    // Loop while the remaining count is at least one.
    var pred = math_ops.greater_equal(merge_count, one);
    _pivot = gen_control_flow_ops.loop_cond(pred, name: "b_count");
    var switch_count = @switch(merge_count, _pivot);

    // Body: count -= 1, fed back through the Merge's second input.
    var index = math_ops.subtract(switch_count[1], one);
    _pivot_for_body = index;
    var next_count = _NextIteration(index);
    merge_count.op._update_input(1, next_count);

    var final_zero = exit(switch_count[0], name: "b_count");
    loop_exits.append(final_zero);
    // Force the stack pops of i-th execution of an inner loop to be ordered
    // before the pops of (i+1)-th execution of the same inner loop.
    if (outer_grad_state != null)
        throw new NotImplementedException("outer_grad_state");
        //outer_grad_state.grad_sync._add_control_input(final_zero.op);
    ExitResult(new[] { final_zero });
    Exit();
    return next_count;
}

/// <summary>
/// Add `val` to the current context and its outer context recursively.
/// </summary>
/// <param name="val">A tensor from outside this loop that an op inside needs.</param>
/// <returns>
/// The tensor to use in place of <paramref name="val"/> inside the loop:
/// either <paramref name="val"/> itself, a cached/new Enter for it, or a
/// value recovered from the forward loop's history when in a grad context.
/// </returns>
public override Tensor AddValue(Tensor val)
{
    var result = val;
    // Only process tensors we have not seen and that come from outside this context.
    var new_value = !_values.Contains(val.name);
    new_value &= val.op._get_control_flow_context() != this;
    if (new_value)
    {
        _values.Add(val.name);

        // If we are in a grad context and val is from its forward context,
        // use GetRealValue(), which adds the logic to save the history of
        // val in forward.
        var grad_ctxt = ops.get_default_graph()._get_control_flow_context();
        if (grad_ctxt != null)
        {
            grad_ctxt = grad_ctxt.GetWhileContext();
            if (grad_ctxt.grad_state != null)
            {
                var forward_ctxt = val.op.GetWhileContext();
                if (control_flow_util.IsLoopExit(val.op))
                {
                    // Loop-exit values belong to the enclosing forward context;
                    // handling them is not ported yet.
                    forward_ctxt = forward_ctxt.outer_context as WhileContext;
                    if (forward_ctxt != null)
                        forward_ctxt = forward_ctxt.GetWhileContext();
                    throw new NotImplementedException("control_flow_util.IsLoopExit");
                }
                if (forward_ctxt == grad_ctxt.grad_state.forward_context)
                {
                    var real_val = grad_ctxt.grad_state.GetRealValue(val);
                    _external_values[val.name] = real_val;
                    return real_val;
                }
            }
        }

        // Make the value known to the enclosing contexts first.
        if (_outer_context != null)
            result = _outer_context.AddValue(val);

        // Create an Enter to make `result` known to this loop context.
        // Built with no ambient control dependencies so it is not accidentally gated.
        Tensor enter = null;
        tf_with(ops.control_dependencies(null), delegate
        {
            enter = control_flow_ops._Enter(
                result,
                _name,
                is_constant: true,
                parallel_iterations: _parallel_iterations);
            enter.graph.prevent_feeding(enter);
            if (_outer_context != null)
                _outer_context.AddInnerOp(enter.op);
        });

        // Fix the control inputs and control flow context of these enter ops.
        _FixControlInputsAndContext(new[] { enter });
        // Add `enter` in this context.
        _values.Add(enter.name);
        _external_values[val.name] = enter;
        result = enter;
    }
    else
    {
        // Already known: reuse the cached external value if one exists.
        var actual_val = _external_values.ContainsKey(val.name) ? _external_values[val.name] : null;
        if (actual_val != null)
            result = actual_val as Tensor;
    }

    return result;
}

/// <summary>
/// Indicates that this context is a while-loop context.
/// </summary>
public override bool IsWhileContext()
{
    return true;
}

public override WhileContext GetWhileContext() public override WhileContext GetWhileContext()
{ {
return this; return this;


+ 1
- 0
src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs View File

@@ -16,6 +16,7 @@


using System; using System;
using System.Linq; using System.Linq;
using static Tensorflow.Binding;


namespace Tensorflow.Operations.Initializers namespace Tensorflow.Operations.Initializers
{ {


+ 2
- 2
src/TensorFlowNET.Core/Operations/LayerRNNCell.cs View File

@@ -16,9 +16,9 @@


namespace Tensorflow namespace Tensorflow
{ {
public class LayerRNNCell : RNNCell
public class LayerRnnCell : RnnCell
{ {
public LayerRNNCell(bool? _reuse = null,
public LayerRnnCell(bool? _reuse = null,
string name = null, string name = null,
TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: _reuse, TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: _reuse,
name: name, name: name,


+ 49
- 0
src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs View File

@@ -0,0 +1,49 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Operations
{
/// <summary>
/// The loop-carried state of the dynamic-RNN while loop:
/// (time, output TensorArrays, cell state). Supports nest-style
/// flattening/packing for use as a while_loop loop variable.
/// </summary>
internal class BodyItemInRnnWhileLoop : ICanBeFlattened, IPackable<BodyItemInRnnWhileLoop>
{
    /// <summary>
    /// int32 scalar Tensor.
    /// </summary>
    public Tensor time { get; set; }
    /// <summary>
    /// List of `TensorArray`s that represent the output.
    /// </summary>
    public TensorArray[] output_ta_t { get; set; }
    /// <summary>
    /// nested tuple of vector tensors that represent the state.
    /// </summary>
    public Tensor state { get; set; }

    public BodyItemInRnnWhileLoop(Tensor time, TensorArray[] output_ta_t, Tensor state)
    {
        this.time = time;
        this.output_ta_t = output_ta_t;
        this.state = state;
    }

    public static implicit operator (Tensor, TensorArray[], Tensor)(BodyItemInRnnWhileLoop item)
        => (item.time, item.output_ta_t, item.state);

    /// <summary>
    /// Flattens to [time, output_ta_t..., state] for nest-style processing.
    /// </summary>
    public object[] Flatten()
    {
        var elements = new List<object> { time };
        elements.AddRange(output_ta_t);
        elements.Add(state);
        return elements.ToArray();
    }

    /// <summary>
    /// Inverse of <see cref="Flatten"/>: rebuilds the item from a flat
    /// sequence laid out as [time, output_ta_t..., state].
    /// Fixed: the original hard-coded exactly one TensorArray
    /// (sequences[1]), which broke round-tripping for multiple outputs.
    /// </summary>
    public BodyItemInRnnWhileLoop Pack(object[] sequences)
    {
        time = sequences[0] as Tensor;
        // Everything between the leading time and the trailing state is a TensorArray.
        var arrays = new TensorArray[sequences.Length - 2];
        for (int i = 0; i < arrays.Length; i++)
            arrays[i] = sequences[i + 1] as TensorArray;
        output_ta_t = arrays;
        state = sequences[sequences.Length - 1] as Tensor;

        return new BodyItemInRnnWhileLoop(time, output_ta_t, state);
    }
}
}

+ 21
- 1
src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs View File

@@ -244,7 +244,27 @@ namespace Tensorflow.Operations
logits logits
}); });


return _op.outputs[0];
return _op.output;
}
/// <summary>
/// Says whether the targets are in the top `K` predictions.
/// Thin wrapper over the "InTopKV2" op.
/// </summary>
/// <param name="predictions">Predictions tensor passed to the op.</param>
/// <param name="targets">Targets tensor passed to the op.</param>
/// <param name="k">Number of top elements to consider.</param>
/// <param name="name">Optional name for the created op.</param>
/// <returns>A `Tensor` of type `bool`.</returns>
public static Tensor in_top_kv2(Tensor predictions, Tensor targets, int k, string name = null)
{
    var op = _op_def_lib._apply_op_helper("InTopKV2",
        name: name,
        args: new { predictions, targets, k });

    return op.output;
}
public static Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null) public static Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null)


+ 76
- 15
src/TensorFlowNET.Core/Operations/NnOps/rnn.cs View File

@@ -14,6 +14,7 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using NumSharp;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
@@ -24,12 +25,12 @@ namespace Tensorflow.Operations
{ {
internal class rnn internal class rnn
{ {
public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs_tensor,
public static (Tensor, Tensor) dynamic_rnn(RnnCell cell, Tensor inputs_tensor,
Tensor sequence_length = null, Tensor initial_state = null, Tensor sequence_length = null, Tensor initial_state = null,
TF_DataType dtype = TF_DataType.DtInvalid, TF_DataType dtype = TF_DataType.DtInvalid,
int? parallel_iterations = null, bool swap_memory = false, bool time_major = false) int? parallel_iterations = null, bool swap_memory = false, bool time_major = false)
{ {
tf_with(tf.variable_scope("rnn"), scope =>
return tf_with(tf.variable_scope("rnn"), scope =>
{ {
VariableScope varscope = scope; VariableScope varscope = scope;
var flat_input = nest.flatten(inputs_tensor); var flat_input = nest.flatten(inputs_tensor);
@@ -63,9 +64,12 @@ namespace Tensorflow.Operations
swap_memory: swap_memory, swap_memory: swap_memory,
sequence_length: sequence_length, sequence_length: sequence_length,
dtype: dtype); dtype: dtype);
});


throw new NotImplementedException("");
if (!time_major)
outputs = nest.map_structure(_transpose_batch_time, outputs);

return (outputs, final_state);
});
} }


/// <summary> /// <summary>
@@ -79,7 +83,7 @@ namespace Tensorflow.Operations
/// <param name="sequence_length"></param> /// <param name="sequence_length"></param>
/// <param name="dtype"></param> /// <param name="dtype"></param>
/// <returns></returns> /// <returns></returns>
private static (Tensor, Tensor) _dynamic_rnn_loop(RNNCell cell, Tensor inputs, Tensor initial_state,
private static (Tensor, Tensor) _dynamic_rnn_loop(RnnCell cell, Tensor inputs, Tensor initial_state,
int parallel_iterations, bool swap_memory, Tensor sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid) int parallel_iterations, bool swap_memory, Tensor sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid)
{ {
var state = initial_state; var state = initial_state;
@@ -145,7 +149,7 @@ namespace Tensorflow.Operations
{ {
var ta = new TensorArray(dtype: dtype_, var ta = new TensorArray(dtype: dtype_,
size: time_steps, size: time_steps,
element_shape: new[] { element_shape },
element_shape: element_shape,
tensor_array_name: base_name + name); tensor_array_name: base_name + name);
return ta; return ta;
}; };
@@ -170,29 +174,86 @@ namespace Tensorflow.Operations
flat_input_i.dtype)); flat_input_i.dtype));
} }


for (int i = 0; i < input_ta.Count; i++)
input_ta = zip(input_ta, flat_input).Select(x =>
{ {
var (ta, input_) = (input_ta[0], flat_input[0]);
}
var (ta, input_) = (x.Item1, x.Item2);
return ta.unstack(input_);
}).ToList();
} }


// Make sure that we run at least 1 step, if necessary, to ensure // Make sure that we run at least 1 step, if necessary, to ensure
// the TensorArrays pick up the dynamic shape. // the TensorArrays pick up the dynamic shape.
Tensor loop_bound;
Tensor loop_bound = null;
if (in_graph_mode) if (in_graph_mode)
loop_bound = math_ops.minimum( loop_bound = math_ops.minimum(
time_steps, math_ops.maximum(1, max_sequence_length)); time_steps, math_ops.maximum(1, max_sequence_length));


/*Func<Tensor, Tensor> cond = (ctime) =>
Func<BodyItemInRnnWhileLoop, Tensor> cond = (item) =>
{ {
return null;
return item.time < loop_bound;
};

// Take a time step of the dynamic RNN.
Func<BodyItemInRnnWhileLoop, BodyItemInRnnWhileLoop> _time_step = (item) =>
{
Tensor[] input_t = null;
var (time1, output_ta_t, state1) = (item.time, item.output_ta_t, item.state);
if (in_graph_mode)
{
input_t = input_ta.Select(ta => ta.read(time1)).ToArray();
// Restore some shape information
foreach (var (input_, shape) in zip(input_t, inputs_got_shape))
input_.set_shape(shape[new Slice(1)]);
}
else
{
// input_t = tuple(ta[time.numpy()] for ta in input_ta)
}

var input_t_t = nest.pack_sequence_as2(structure: inputs, flat_sequence: input_t);
// Keras RNN cells only accept state as list, even if it's a single tensor.
// var is_keras_rnn_cell = _is_keras_rnn_cell(cell);
Tensor[] outputs = null;
if (sequence_length != null)
throw new NotImplementedException("sequence_length != null");
else
outputs = cell.__call__(input_t_t, state: state1);

var (output, new_state) = (outputs[0], outputs[1]);
// Keras cells always wrap state as list, even if it's a single tensor.
// if(is_keras_rnn_cell && len(new_state)) == 1
// Pack state if using state tuples
outputs = nest.flatten2(output).Select(x => x as Tensor).ToArray();

output_ta_t = zip(output_ta_t, outputs).Select(x =>
{
var(ta, @out) = (x.Item1, x.Item2);
return ta.write(item.time, @out);
}).ToArray();

return new BodyItemInRnnWhileLoop(item.time + 1, output_ta_t, new_state);
}; };


control_flow_ops.while_loop(
var while_loop_result = control_flow_ops.while_loop(
cond: cond, cond: cond,
body = );*/
body: _time_step,
loop_vars: new BodyItemInRnnWhileLoop(time, output_ta.ToArray(), state),
parallel_iterations: parallel_iterations,
maximum_iterations: time_steps,
swap_memory: swap_memory);

(_, TensorArray[] output_final_ta, Tensor final_state) = (while_loop_result.time, while_loop_result.output_ta_t, while_loop_result.state);

// Unpack final output if not using output tuples.
var final_outputs = output_final_ta.Select(ta => ta.stack()).ToArray();
// Restore some shape information
foreach (var (output, output_size) in zip(final_outputs, flat_output_size))
{
var shape = rnn_cell_impl._concat(new[] { const_time_steps, const_batch_size }, output_size, @static: true);
output.set_shape(shape);
}


throw new NotImplementedException("");
return (final_outputs[0], final_state);
} }


private static TensorShape _maybe_tensor_shape_from_tensor(Tensor shape) private static TensorShape _maybe_tensor_shape_from_tensor(Tensor shape)


+ 31
- 2
src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs View File

@@ -20,8 +20,8 @@ namespace Tensorflow.Operations
{ {
public class rnn_cell_impl public class rnn_cell_impl
{ {
public BasicRNNCell BasicRNNCell(int num_units)
=> new BasicRNNCell(num_units);
public BasicRnnCell BasicRNNCell(int num_units)
=> new BasicRnnCell(num_units);


public static Tensor _concat(Tensor prefix, int suffix, bool @static = false) public static Tensor _concat(Tensor prefix, int suffix, bool @static = false)
{ {
@@ -53,5 +53,34 @@ namespace Tensorflow.Operations
return array_ops.concat(new[] { p, s }, 0); return array_ops.concat(new[] { p, s }, 0);
} }
} }

/// <summary>
/// Concatenates an int[] prefix shape with a single suffix dimension.
/// With @static = true returns a static TensorShape; the dynamic path
/// (building a shape tensor) is not implemented yet.
/// </summary>
/// <param name="prefix">Leading dimensions of the resulting shape.</param>
/// <param name="suffix">Single trailing dimension.</param>
/// <param name="static">Whether to build the result statically.</param>
public static TensorShape _concat(int[] prefix, int suffix, bool @static = false)
{
    var p = new TensorShape(prefix);
    var p_static = prefix;
    // NOTE(review): p_tensor is unused on the static path, and creating it
    // adds a Const op to the graph as a side effect — confirm intent before removing.
    var p_tensor = p.is_fully_defined() ? constant_op.constant(p.as_list(), dtype: dtypes.int32) : null;

    var s_tensor_shape = new TensorShape(suffix);
    var s_static = s_tensor_shape.ndim > -1 ?
        s_tensor_shape.dims :
        null;
    var s_tensor = s_tensor_shape.is_fully_defined() ?
        constant_op.constant(s_tensor_shape.dims, dtype: dtypes.int32) :
        null;

    if (@static)
    {
        // Static path: concatenate the known dimensions into one shape.
        if (p_static is null) return null;
        var shape = new TensorShape(p_static).concatenate(s_static);
        return shape;
    }
    else
    {
        // Dynamic path: would concat the shape tensors; not ported yet.
        if (p is null || s_tensor is null)
            throw new ValueError($"Provided a prefix or suffix of None: {prefix} and {suffix}");
        // return array_ops.concat(new[] { p_tensor, s_tensor }, 0);
        throw new NotImplementedException("");
    }
}
} }
} }

+ 14
- 0
src/TensorFlowNET.Core/Operations/OpDefLibrary.cs View File

@@ -228,6 +228,15 @@ namespace Tensorflow
output_types.AddRange(types); output_types.AddRange(types);
} }


// We add an explicit colocation constraint between
// the newly created op and any of its reference-typed inputs.
var must_colocate_inputs = zip(op_def.InputArg, inputs)
.Where(x => x.Item1.IsRef)
.Select(x => x.Item2)
.ToArray();

_MaybeColocateWith(must_colocate_inputs);

// Add Op to graph // Add Op to graph
var op = g.create_op(op_type_name, var op = g.create_op(op_type_name,
inputs.ToArray(), inputs.ToArray(),
@@ -241,6 +250,11 @@ namespace Tensorflow
}); });
} }


private void _MaybeColocateWith(ITensorOrOperation[] inputs)
{

}

private void SetAttrs(string op_type_name, private void SetAttrs(string op_type_name,
ArgDef input_arg, ArgDef input_arg,
OpDef op_def, OpDef op_def,


+ 12
- 5
src/TensorFlowNET.Core/Operations/Operation.Control.cs View File

@@ -15,6 +15,7 @@
******************************************************************************/ ******************************************************************************/
using Tensorflow.Operations; using Tensorflow.Operations;
using static Tensorflow.Binding;
namespace Tensorflow namespace Tensorflow
{ {
@@ -30,11 +31,8 @@ namespace Tensorflow
/// </summary> /// </summary>
public void _control_flow_post_processing() public void _control_flow_post_processing()
{ {
foreach(var input_tensor in inputs)
{
//TODO: implement below code dependency
//control_flow_util.CheckInputFromValidContext(this, input_tensor.op);
}
foreach(Tensor input_tensor in inputs)
control_flow_util.CheckInputFromValidContext(this, input_tensor.op);
if (_control_flow_context != null) if (_control_flow_context != null)
_control_flow_context.AddOp(this); _control_flow_context.AddOp(this);
@@ -54,6 +52,10 @@ namespace Tensorflow
public void _set_control_flow_context(ControlFlowContext ctx) public void _set_control_flow_context(ControlFlowContext ctx)
{ {
if (name.Contains("gradients/rnn/while/basic_rnn_cell/Tanh_grad/TanhGrad/f_acc"))
{
}
_control_flow_context = ctx; _control_flow_context = ctx;
} }
@@ -61,5 +63,10 @@ namespace Tensorflow
{ {
return _control_flow_context; return _control_flow_context;
} }
public WhileContext GetWhileContext()
{
return _control_flow_context as WhileContext;
}
} }
} }

+ 8
- 6
src/TensorFlowNET.Core/Operations/Operation.Input.cs View File

@@ -14,10 +14,12 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/
using Newtonsoft.Json;
using System; using System;
using System.Linq; using System.Linq;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
#if SERIALIZABLE
using Newtonsoft.Json;
#endif
namespace Tensorflow namespace Tensorflow
{ {
@@ -42,14 +44,14 @@ namespace Tensorflow
[JsonIgnore] [JsonIgnore]
#endif #endif
public int NumInputs => c_api.TF_OperationNumInputs(_handle); public int NumInputs => c_api.TF_OperationNumInputs(_handle);
private TF_DataType[] _input_types => _inputs._inputs.Select(x => x.dtype).ToArray();
private TF_DataType[] _input_types => _inputs_val._inputs.Select(x => x.dtype).ToArray();
private InputList _inputs;
private InputList _inputs_val;
public InputList inputs public InputList inputs
{ {
get get
{ {
if (_inputs == null)
if (_inputs_val == null)
{ {
var retval = new Tensor[NumInputs]; var retval = new Tensor[NumInputs];
@@ -60,10 +62,10 @@ namespace Tensorflow
retval[i] = op.outputs[tf_output.index]; retval[i] = op.outputs[tf_output.index];
} }
_inputs = new InputList(retval);
_inputs_val = new InputList(retval);
} }
return _inputs;
return _inputs_val;
} }
} }


+ 13
- 8
src/TensorFlowNET.Core/Operations/Operation.Instance.cs View File

@@ -15,17 +15,14 @@
******************************************************************************/ ******************************************************************************/


using System; using System;
using System.Linq;
using System.Collections.Generic; using System.Collections.Generic;
using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
{ {
public partial class Operation public partial class Operation
{ {
// cache the mapping between managed and unmanaged op
// some data is stored in managed instance, so when
// create Operation by IntPtr, it will lost some data.
private static Dictionary<IntPtr, Operation> OpInstances = new Dictionary<IntPtr, Operation>();

/// <summary> /// <summary>
/// Get operation by handle /// Get operation by handle
/// </summary> /// </summary>
@@ -33,9 +30,17 @@ namespace Tensorflow
/// <returns></returns> /// <returns></returns>
public Operation GetOperation(IntPtr handle) public Operation GetOperation(IntPtr handle)
{ {
return OpInstances.ContainsKey(handle) ?
OpInstances[handle] :
new Operation(handle);
var nodes = tf.get_default_graph()._nodes_by_name;
foreach(var node in nodes.Values)
{
if (node is Operation op)
{
if (op == handle)
return op;
}
}

return null;
} }
} }
} }

+ 3
- 1
src/TensorFlowNET.Core/Operations/Operation.Output.cs View File

@@ -14,10 +14,12 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using Newtonsoft.Json;
using System; using System;
using System.Linq; using System.Linq;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
#if SERIALIZABLE
using Newtonsoft.Json;
#endif
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow


+ 10
- 8
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -21,8 +21,9 @@ using Newtonsoft.Json;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
using System.Linq;
using System.Linq;
using Tensorflow.Util; using Tensorflow.Util;
using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
{ {
@@ -65,7 +66,7 @@ namespace Tensorflow
#if SERIALIZABLE #if SERIALIZABLE
[JsonIgnore] [JsonIgnore]
#endif #endif
public int _id_value;
public int _id_value { get; set; }
#if SERIALIZABLE #if SERIALIZABLE
[JsonIgnore] [JsonIgnore]
#endif #endif
@@ -77,6 +78,7 @@ namespace Tensorflow
#if SERIALIZABLE #if SERIALIZABLE
[JsonIgnore] [JsonIgnore]
#endif #endif
bool _is_stateful;
public NodeDef node_def public NodeDef node_def
{ {
get get
@@ -104,7 +106,6 @@ namespace Tensorflow
_control_flow_context = _graph._get_control_flow_context(); _control_flow_context = _graph._get_control_flow_context();


// Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor. // Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor.
OpInstances[_handle] = this;
} }


/*public Operation(Graph g, string opType, string oper_name) /*public Operation(Graph g, string opType, string oper_name)
@@ -172,16 +173,19 @@ namespace Tensorflow
} }
} }


_id_value = _graph._next_id();
// Dict mapping op name to file and line information for op colocation // Dict mapping op name to file and line information for op colocation
// context managers. // context managers.
_control_flow_context = graph._get_control_flow_context();
_control_flow_context = graph._get_control_flow_context();
// This will be set by self.inputs. // This will be set by self.inputs.
if (op_def == null) if (op_def == null)
op_def = g.GetOpDef(node_def.Op); op_def = g.GetOpDef(node_def.Op);


var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
_handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());
_is_stateful = op_def.IsStateful;


// Initialize self._outputs. // Initialize self._outputs.
output_types = new TF_DataType[NumOutputs]; output_types = new TF_DataType[NumOutputs];
@@ -196,8 +200,6 @@ namespace Tensorflow


if (_handle != IntPtr.Zero) if (_handle != IntPtr.Zero)
_control_flow_post_processing(); _control_flow_post_processing();

OpInstances[_handle] = this;
} }


public void run(FeedItem[] feed_dict = null, Session session = null) public void run(FeedItem[] feed_dict = null, Session session = null)
@@ -304,7 +306,7 @@ namespace Tensorflow
var output = tensor._as_tf_output(); var output = tensor._as_tf_output();


// Reset cached inputs. // Reset cached inputs.
_inputs = null;
_inputs_val = null;
// after the c_api call next time _inputs is accessed // after the c_api call next time _inputs is accessed
// the updated inputs are reloaded from the c_api // the updated inputs are reloaded from the c_api
lock (Locks.ProcessWide) lock (Locks.ProcessWide)


+ 2
- 2
src/TensorFlowNET.Core/Operations/RNNCell.cs View File

@@ -42,7 +42,7 @@ namespace Tensorflow
/// matching structure of Tensors having shape `[batch_size].concatenate(s)` /// matching structure of Tensors having shape `[batch_size].concatenate(s)`
/// for each `s` in `self.batch_size`. /// for each `s` in `self.batch_size`.
/// </summary> /// </summary>
public abstract class RNNCell : Layers.Layer
public abstract class RnnCell : Layers.Layer
{ {
/// <summary> /// <summary>
/// Attribute that indicates whether the cell is a TF RNN cell, due the slight /// Attribute that indicates whether the cell is a TF RNN cell, due the slight
@@ -53,7 +53,7 @@ namespace Tensorflow


public virtual int output_size { get; } public virtual int output_size { get; }


public RNNCell(bool trainable = true,
public RnnCell(bool trainable = true,
string name = null, string name = null,
TF_DataType dtype = TF_DataType.DtInvalid, TF_DataType dtype = TF_DataType.DtInvalid,
bool? _reuse = null) : base(trainable: trainable, bool? _reuse = null) : base(trainable: trainable,


+ 78
- 7
src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs View File

@@ -22,9 +22,10 @@ using static Tensorflow.Binding;


namespace Tensorflow.Operations namespace Tensorflow.Operations
{ {
internal class _GraphTensorArray
public class _GraphTensorArray
{ {
internal TF_DataType _dtype; internal TF_DataType _dtype;
public TF_DataType dtype => _dtype;


/// <summary> /// <summary>
/// Used to keep track of what tensors the TensorArray should be /// Used to keep track of what tensors the TensorArray should be
@@ -32,19 +33,22 @@ namespace Tensorflow.Operations
/// first tensor written to it. /// first tensor written to it.
/// </summary> /// </summary>
bool _colocate_with_first_write_call; bool _colocate_with_first_write_call;
public bool colocate_with_first_write_call => _colocate_with_first_write_call;


bool _infer_shape; bool _infer_shape;
bool _dynamic_size;
List<TensorShape> _element_shape;
public bool infer_shape => _infer_shape;
public bool _dynamic_size;
public List<TensorShape> _element_shape;


List<Tensor> _colocate_with;
public List<Tensor> _colocate_with;


internal Tensor _handle; internal Tensor _handle;
public Tensor handle => _handle;
internal Tensor _flow; internal Tensor _flow;


public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null, public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null,
bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
bool infer_shape = true, TensorShape[] element_shape = null,
bool infer_shape = true, TensorShape element_shape = null,
bool colocate_with_first_write_call = true, string name = null) bool colocate_with_first_write_call = true, string name = null)
{ {
clear_after_read = clear_after_read ?? true; clear_after_read = clear_after_read ?? true;
@@ -68,7 +72,7 @@ namespace Tensorflow.Operations
else else
{ {
_infer_shape = true; _infer_shape = true;
_element_shape = new List<TensorShape> { };
_element_shape = new List<TensorShape> { element_shape };
} }


tf_with(ops.name_scope(name, "TensorArray", new { handle, size, flow }), scope => tf_with(ops.name_scope(name, "TensorArray", new { handle, size, flow }), scope =>
@@ -135,7 +139,7 @@ namespace Tensorflow.Operations


var ta = new TensorArray(_dtype, var ta = new TensorArray(_dtype,
infer_shape:_infer_shape, infer_shape:_infer_shape,
element_shape: _element_shape.ToArray(),
element_shape: _element_shape[0],
dynamic_size: _dynamic_size, dynamic_size: _dynamic_size,
handle: _handle, handle: _handle,
flow: flow_out, flow: flow_out,
@@ -155,5 +159,72 @@ namespace Tensorflow.Operations
{ {
_colocate_with.Add(value); _colocate_with.Add(value);
} }

public Tensor read(Tensor index, string name = null)
{
var value = gen_data_flow_ops.tensor_array_read_v3(
handle: _handle,
index: index,
flow_in: _flow,
dtype: _dtype,
name: name);

if (_element_shape != null)
value.set_shape(_element_shape[0].dims);

return value;
}

public TensorArray write(Tensor index, Tensor value, string name = null)
{
return tf_with(ops.name_scope(name, "TensorArrayWrite", new { _handle, index, value }), delegate
{
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
_maybe_colocate_with(value);
var flow_out = gen_data_flow_ops.tensor_array_write_v3(
handle: _handle,
index: index,
value: value,
flow_in: _flow,
name: name);

return tensor_array_ops.build_ta_with_new_flow(this, flow_out);
});
}

private Tensor size(string name = null)
{
return gen_data_flow_ops.tensor_array_size_v3(_handle, _flow, name: name);
}

public Tensor stack(string name = null)
{
ops.colocate_with(_handle);
return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate
{
return gather(math_ops.range(0, size()), name: name);
});
}

public Tensor gather(Tensor indices, string name = null)
{
var element_shape = new TensorShape();

if (_element_shape.Count > 0)
element_shape = _element_shape[0];

var value = gen_data_flow_ops.tensor_array_gather_v3(
handle: _handle,
indices: indices,
flow_in: _flow,
dtype: _dtype,
name: name,
element_shape: element_shape);

//if (element_shape != null)
//value.set_shape(-1, element_shape.dims);

return value;
}
} }
} }

+ 126
- 34
src/TensorFlowNET.Core/Operations/control_flow_ops.cs View File

@@ -21,6 +21,7 @@ using Tensorflow.Operations;
using Tensorflow.Operations.ControlFlows; using Tensorflow.Operations.ControlFlows;
using util = Tensorflow.control_flow_util; using util = Tensorflow.control_flow_util;
using static Tensorflow.Binding; using static Tensorflow.Binding;
using Tensorflow.Util;


namespace Tensorflow namespace Tensorflow
{ {
@@ -150,27 +151,50 @@ namespace Tensorflow
/// <param name="colocate_gradients_with_ops"></param> /// <param name="colocate_gradients_with_ops"></param>
public static ControlFlowState MaybeCreateControlFlowState(List<Operation> between_op_list, List<Operation> between_ops, bool colocate_gradients_with_ops) public static ControlFlowState MaybeCreateControlFlowState(List<Operation> between_op_list, List<Operation> between_ops, bool colocate_gradients_with_ops)
{ {
var flag = new List<Operation>();
ControlFlowState loop_state = null; ControlFlowState loop_state = null;


foreach (var op in between_op_list)
int pos = 0;
while(pos < between_op_list.Count)
{ {
var op = between_op_list[pos];
if (IsLoopExit(op)) if (IsLoopExit(op))
{ {
if(loop_state == null)
if (loop_state == null)
{ {
loop_state = new ControlFlowState(); loop_state = new ControlFlowState();
} }
if (colocate_gradients_with_ops)
ops.colocate_with(op);
loop_state.AddWhileContext(op, between_op_list, between_ops);
} }
pos++;
} }


return loop_state; return loop_state;
} }


public static bool IsLoopExit(Operation op) public static bool IsLoopExit(Operation op)
=> op.OpType == "Exit" || op.OpType == "RefExit";

public static bool IsLoopSwitch(Operation op)
{
if(IsSwitch(op))
{
var ctxt = op._get_control_flow_context();
return ctxt != null && ctxt.IsWhileContext() && !IsCondSwitch(op);
}
return false;
}

public static bool IsCondSwitch(Operation op)
{ {
return op.OpType == "Exit" || op.OpType == "RefExit";
throw new NotImplementedException("IsCondSwitch");
} }


public static bool IsSwitch(Operation op)
=> op.type == "Switch" || op.type == "RefSwitch";

public static Tensor[] tuple(Tensor[] tensors, string name = null, Operation[] control_inputs = null) public static Tensor[] tuple(Tensor[] tensors, string name = null, Operation[] control_inputs = null)
{ {
return tf_with(ops.name_scope(name, "tuple", tensors), scope => return tf_with(ops.name_scope(name, "tuple", tensors), scope =>
@@ -223,15 +247,10 @@ namespace Tensorflow
//TODO: missing original code //TODO: missing original code
//if context.executing_eagerly(): //if context.executing_eagerly():
// return output_tensor // return output_tensor
var values = new List<object>();
values.AddRange(dependencies);
values.Add(output_tensor);

return tf_with(ops.name_scope(name, "control_dependency", values), scope =>
return tf_with(ops.name_scope(name, "control_dependency", new { dependencies, output_tensor }), scope =>
{ {
name = scope; name = scope;
// TODO: missing original code
//with ops.colocate_with(output_tensor):
ops.colocate_with(output_tensor);
{ {
return tf_with(ops.control_dependencies(dependencies), ctl => return tf_with(ops.control_dependencies(dependencies), ctl =>
{ {
@@ -251,12 +270,16 @@ namespace Tensorflow
return gen_array_ops.identity(data, name: name); return gen_array_ops.identity(data, name: name);
} }


public static void _SetShapeInvariants(Tensor[] input_vars, Tensor[] enter_vars, TensorShape shapes = null)
public static void _SetShapeInvariants(Tensor[] input_vars, Tensor[] enter_vars, TensorShape[] shapes = null)
{ {
if (shapes == null) if (shapes == null)
return; return;


throw new NotImplementedException("_SetShapeInvariants");
var flat_shapes = nest.flatten2(shapes);
foreach (var (inp, var, shape) in zip(input_vars, enter_vars, flat_shapes))
{
var.set_shape(shape);
}
} }


/// <summary> /// <summary>
@@ -426,14 +449,15 @@ namespace Tensorflow
var merges = zip(res_f_flat, res_t_flat) var merges = zip(res_f_flat, res_t_flat)
.Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })) .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 }))
.Select(m => (Tensor)m)
.ToArray(); .ToArray();


merges = _convert_flows_to_tensorarrays(new Tensor[] { (Tensor)orig_res_t }, merges);
var merges2 = _convert_flows_to_tensorarrays(new ITensorOrTensorArray[] { (Tensor)orig_res_t }, merges);


ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t);
ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f);


return merges[0];
return new Tensor(IntPtr.Zero);
}); });
} }


@@ -473,22 +497,29 @@ namespace Tensorflow
var res_f_flat = res_f; var res_f_flat = res_f;


var merges = zip(res_f_flat, res_t_flat) var merges = zip(res_f_flat, res_t_flat)
.Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 }))
.Select(pair => merge(new [] { pair.Item1, pair.Item2 }))
.Select(m => (Tensor)m)
.ToArray(); .ToArray();


merges = _convert_flows_to_tensorarrays(orig_res_t, merges);
var merges2 = _convert_flows_to_tensorarrays(orig_res_t.Select(x => (ITensorOrTensorArray)x).ToArray(), merges);


ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t);
ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f);


return merges;
return new[] { new Tensor(IntPtr.Zero) };
}); });
} }


public static Tensor[] _convert_flows_to_tensorarrays<T>(T[] tensors_or_tensorarrays, Tensor[] tensors_or_flows)
public static ITensorOrTensorArray[] _convert_flows_to_tensorarrays(ITensorOrTensorArray[] tensors_or_tensorarrays, Tensor[] tensors_or_flows)
{ {
// zip(tensors_or_tensorarrays, tensors_or_flows).Select((ta, t_or_flow) => ta).ToArray();
return tensors_or_flows;
return zip(tensors_or_tensorarrays, tensors_or_flows).Select(x =>
{
var (ta, t_or_flow) = (x.Item1, x.Item2);
if (ta is TensorArray ta_1)
return tensor_array_ops.build_ta_with_new_flow(ta_1, t_or_flow) as ITensorOrTensorArray;
else
return t_or_flow as ITensorOrTensorArray;
}).ToArray();
} }


/// <summary> /// <summary>
@@ -508,7 +539,7 @@ namespace Tensorflow
/// <param name="inputs">inputs: The input tensors, at most one of which is available.</param> /// <param name="inputs">inputs: The input tensors, at most one of which is available.</param>
/// <param name="name">A name for this operation (optional).</param> /// <param name="name">A name for this operation (optional).</param>
/// <returns></returns> /// <returns></returns>
public static Tensor merge(Tensor[] inputs, string name = null)
public static MergeOutput merge(Tensor[] inputs, string name = null)
{ {
if (inputs.Any(x => x == null)) if (inputs.Any(x => x == null))
throw new ValueError($"At least one of the merge inputs is null: {inputs}"); throw new ValueError($"At least one of the merge inputs is null: {inputs}");
@@ -518,7 +549,7 @@ namespace Tensorflow
inputs = inputs.Select(inp => inputs = inputs.Select(inp =>
ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref: true)) ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref: true))
.ToArray(); .ToArray();
return gen_control_flow_ops.merge(inputs, name)[0];
return gen_control_flow_ops.merge(inputs, name);
}); });
} }


@@ -591,18 +622,18 @@ namespace Tensorflow
/// <param name="body"></param> /// <param name="body"></param>
/// <param name="loop_vars"></param> /// <param name="loop_vars"></param>
/// <param name="i"></param> /// <param name="i"></param>
public static Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars,
TensorShape shape_invariants = null,
public static TItem while_loop<TItem>(Func<TItem, Tensor> cond, Func<TItem, TItem> body, TItem loop_vars,
TensorShape[] shape_invariants = null,
int parallel_iterations = 10, int parallel_iterations = 10,
bool back_prop = true, bool back_prop = true,
bool swap_memory = false, bool swap_memory = false,
string name = null, string name = null,
int? maximum_iterations = null,
Tensor maximum_iterations = null,
bool return_same_structure = false) bool return_same_structure = false)
{ {
tf_with(ops.name_scope(name, "while", loop_vars), scope =>
return tf_with(ops.name_scope(name, "while", loop_vars), scope =>
{ {
if (loop_vars == null || loop_vars.Length == 0)
if (loop_vars == null)
throw new ValueError("No loop variables provided"); throw new ValueError("No loop variables provided");
if (cond == null) if (cond == null)
throw new ValueError("cond must be callable."); throw new ValueError("cond must be callable.");
@@ -611,6 +642,38 @@ namespace Tensorflow
if (parallel_iterations < 1) if (parallel_iterations < 1)
throw new ValueError("parallel_iterations must be a positive integer."); throw new ValueError("parallel_iterations must be a positive integer.");


var try_to_pack = loop_vars is Tensor && !return_same_structure;
var counter = constant_op.constant(0, dtype: maximum_iterations.dtype, name: "iteration_counter");
var orig_cond = cond;
var orig_body = body;

LoopVar<TItem> loop_vars_1 = null;
Func<LoopVar<TItem>, LoopVar<TItem>> body_buildloop = null;
Func<LoopVar<TItem>, Tensor> cond_buildloop = null;

if (try_to_pack)
{

}
else
{
loop_vars_1 = new LoopVar<TItem>(counter, loop_vars);
cond_buildloop = (item) =>
{
var (i, lv) = (item.Counter, item.Item);
var oc = orig_cond(lv);
return math_ops.logical_and(i < maximum_iterations, oc);
};

body_buildloop = (item) =>
{
var (i, lv) = (item.Counter, item.Item);
var ob = orig_body(lv);
return new LoopVar<TItem>(i + 1, ob);
};
}
try_to_pack = false;

var loop_context = new WhileContext( var loop_context = new WhileContext(
maximum_iterations: maximum_iterations, maximum_iterations: maximum_iterations,
parallel_iterations: parallel_iterations, parallel_iterations: parallel_iterations,
@@ -620,17 +683,46 @@ namespace Tensorflow
if (loop_context.outer_context == null) if (loop_context.outer_context == null)
ops.add_to_collection(tf.GraphKeys.WHILE_CONTEXT, loop_context); ops.add_to_collection(tf.GraphKeys.WHILE_CONTEXT, loop_context);


var results = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants,
var results = loop_context.BuildLoop(cond_buildloop, body_buildloop, loop_vars_1, shape_invariants,
return_same_structure); return_same_structure);


if (maximum_iterations != null)
return results[1];
else
return results[0];
//if (maximum_iterations != null)
return results.Item;
//else
//return results;
}); });

throw new NotImplementedException("while_loop");
} }


/// <summary>
/// Creates or finds a child frame, and makes `data` available to it.
/// </summary>
/// <param name="data"></param>
/// <param name="frame_name"></param>
/// <param name="is_constant"></param>
/// <param name="parallel_iterations"></param>
/// <param name="use_ref"></param>
/// <param name="use_input_shape"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor _Enter(Tensor data, string frame_name,
bool is_constant = false,
int parallel_iterations = 10,
bool use_ref = true,
bool use_input_shape = true,
string name = null)
{
Tensor result;
data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref: true);
if (data.dtype.is_ref_dtype() && use_ref)
throw new NotImplementedException("_Enter");
else
result = gen_control_flow_ops.enter(
data, frame_name, is_constant, parallel_iterations, name: name);

if (use_input_shape)
result.set_shape(data.TensorShape);

return result;
}
} }
} }

+ 133
- 0
src/TensorFlowNET.Core/Operations/control_flow_util.py.cs View File

@@ -14,7 +14,10 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System;
using System.Linq;
using Tensorflow.Operations; using Tensorflow.Operations;
using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
{ {
@@ -28,6 +31,26 @@ namespace Tensorflow
public static bool IsLoopExit(Operation op) public static bool IsLoopExit(Operation op)
{ {
return op.type == "Exit" || op.type == "RefExit"; return op.type == "Exit" || op.type == "RefExit";
}
/// <summary>
/// Returns true if `op` is an Enter.
/// </summary>
/// <param name="op"></param>
/// <returns></returns>
public static bool IsLoopEnter(Operation op)
{
return op.type == "Enter" || op.type == "RefEnter";
}

/// <summary>
/// Return true iff op is a loop invariant.
/// </summary>
/// <param name="op"></param>
/// <returns></returns>
public static bool IsLoopConstantEnter(Operation op)
{
return IsLoopEnter(op) && op.get_attr<bool>("is_constant");
} }


/// <summary> /// <summary>
@@ -38,6 +61,45 @@ namespace Tensorflow
public static bool IsSwitch(Operation op) public static bool IsSwitch(Operation op)
{ {
return op.type == "Switch" || op.type == "RefSwitch"; return op.type == "Switch" || op.type == "RefSwitch";
}
public static WhileContext GetWhileContext(Operation op)
=> op.GetWhileContext();

public static bool IsCondSwitch(Operation op)
{
if (!IsSwitch(op))
return false;
if (op.outputs == null || op.outputs.Length == 0)
return false;
// Switch nodes are not part of the cond control flow context that they
// represent, so consider the consumers of its outputs to determine if it is
// cond switch or not. A switch is a cond switch iff all its consumers are in
// cond contexts.
var is_cond_switch = true;
foreach(var o in op.outputs)
{
foreach(var c in o.consumers())
{
var ctxt = c._get_control_flow_context();
if (IsLoopEnter(c))
ctxt = ctxt.outer_context;
is_cond_switch = is_cond_switch &&(ctxt != null && ctxt.IsCondContext());
}
}
return is_cond_switch;
}

public static bool IsLoopSwitch(Operation op)
{
if (IsSwitch(op))
{
var ctxt = op._get_control_flow_context();
return ctxt != null && ctxt.IsWhileContext() && !IsCondSwitch(op);
}
return false;
} }


/// <summary> /// <summary>
@@ -53,5 +115,76 @@ namespace Tensorflow
ctxt = ctxt.outer_context; ctxt = ctxt.outer_context;
return ctxt; return ctxt;
} }

public static void CheckInputFromValidContext(Operation op, Operation input_op)
{
var op_ctxt = op._get_control_flow_context();
var input_ctxt = GetOutputContext(input_op);
var valid = false;
if (input_ctxt == null)
valid = true;
else if (op_ctxt == input_ctxt)
valid = true;
else
{
var while_ctxt = GetContainingWhileContext(op_ctxt);
var input_while_ctxt = GetContainingWhileContext(input_ctxt);
if (while_ctxt == null)
{
throw new NotImplementedException("CheckInputFromValidContext");
}
else if (IsContainingContext(while_ctxt, input_while_ctxt))
{
// input_op is in a while loop which contains op's while loop (or not in a
// while loop at all).
valid = true;
}
else if (while_ctxt.grad_state != null &&
IsContainingContext(while_ctxt.grad_state.forward_context,
input_while_ctxt))
{
valid = true;
}
else
throw new NotImplementedException("CheckInputFromValidContext");
}
if (!valid)
{
throw new NotImplementedException("CheckInputFromValidContext");
}
}
public static Operation GetLoopConstantEnter(Tensor value)
{
var id_ops = new string[] { "Switch", "RefSwitch", "Identity", "RefIdentity" };
var op = value.op;
while (id_ops.Contains(op.type))
op = op.inputs[0].op;
return IsLoopConstantEnter(op) ? op : null;
}

public static bool IsContainingContext(WhileContext ctxt, WhileContext maybe_containing_ctxt)
{
while(ctxt != maybe_containing_ctxt)
{
if (ctxt == null)
return false;
ctxt = ctxt.outer_context as WhileContext;
}
return true;
}

public static WhileContext GetContainingWhileContext(ControlFlowContext ctxt, ControlFlowContext stop_ctxt = null)
{
while (ctxt != null)
{
if (ctxt.IsWhileContext() || ctxt == stop_ctxt)
return ctxt as WhileContext;
ctxt = ctxt.outer_context;
}
return null;
}
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -183,7 +183,7 @@ namespace Tensorflow
{ {
var _op = _op_def_lib._apply_op_helper("Identity", name, new { input }); var _op = _op_def_lib._apply_op_helper("Identity", name, new { input });


return _op.outputs[0];
return _op.output;
} }


public static Tensor invert_permutation(Tensor x, string name = null) public static Tensor invert_permutation(Tensor x, string name = null)


+ 15
- 4
src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs View File

@@ -14,12 +14,23 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using Tensorflow.Operations;

namespace Tensorflow namespace Tensorflow
{ {
public class gen_control_flow_ops public class gen_control_flow_ops
{ {
public static OpDefLibrary _op_def_lib = new OpDefLibrary(); public static OpDefLibrary _op_def_lib = new OpDefLibrary();


public static Operation control_trigger(string name = null)
{
var _op = _op_def_lib._apply_op_helper("ControlTrigger", name, new
{
});

return _op;
}

/// <summary> /// <summary>
/// Creates or finds a child frame, and makes `data` available to the child frame. /// Creates or finds a child frame, and makes `data` available to the child frame.
/// </summary> /// </summary>
@@ -148,18 +159,18 @@ namespace Tensorflow
return new []{_op.outputs[0], _op.outputs[1]}; return new []{_op.outputs[0], _op.outputs[1]};
} }


public static Tensor[] ref_merge(Tensor[] inputs, string name = null)
public static MergeOutput ref_merge(Tensor[] inputs, string name = null)
{ {
var _op = _op_def_lib._apply_op_helper("RefMerge", name, new { inputs }); var _op = _op_def_lib._apply_op_helper("RefMerge", name, new { inputs });


return _op.outputs;
return new MergeOutput(_op.outputs);
} }


public static Tensor[] merge(Tensor[] inputs, string name = null)
public static MergeOutput merge(Tensor[] inputs, string name = null)
{ {
var _op = _op_def_lib._apply_op_helper("Merge", name, new { inputs }); var _op = _op_def_lib._apply_op_helper("Merge", name, new { inputs });


return _op.outputs;
return new MergeOutput(_op.outputs);
} }
} }
} }

+ 100
- 5
src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs View File

@@ -28,12 +28,9 @@ namespace Tensorflow
} }


public static (Tensor, Tensor) tensor_array_v3<T>(T size, TF_DataType dtype = TF_DataType.DtInvalid, public static (Tensor, Tensor) tensor_array_v3<T>(T size, TF_DataType dtype = TF_DataType.DtInvalid,
TensorShape[] element_shape = null, bool dynamic_size = false, bool clear_after_read = true,
bool identical_element_shapes = false, string tensor_array_name = "tensor_array_name", string name = null)
TensorShape element_shape = null, bool dynamic_size = false, bool clear_after_read = true,
bool identical_element_shapes = false, string tensor_array_name = "", string name = null)
{ {
if (tensor_array_name == null)
tensor_array_name = string.Empty;

var _op = _op_def_lib._apply_op_helper("TensorArrayV3", name, new var _op = _op_def_lib._apply_op_helper("TensorArrayV3", name, new
{ {
size, size,
@@ -201,5 +198,103 @@ namespace Tensorflow


return _op.outputs; return _op.outputs;
} }

/// <summary>
/// Read an element from the TensorArray into output `value`.
/// </summary>
/// <param name="handle"></param>
/// <param name="index"></param>
/// <param name="flow_in"></param>
/// <param name="dtype"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor tensor_array_read_v3(Tensor handle, Tensor index, Tensor flow_in, TF_DataType dtype, string name = null)
{
var _op = _op_def_lib._apply_op_helper("TensorArrayReadV3", name, new
{
handle,
index,
flow_in,
dtype
});

return _op.output;
}

public static Tensor tensor_array_write_v3(Tensor handle, Tensor index, Tensor value, Tensor flow_in, string name = null)
{
var _op = _op_def_lib._apply_op_helper("TensorArrayWriteV3", name, new
{
handle,
index,
value,
flow_in
});

return _op.output;
}

public static Tensor tensor_array_size_v3(Tensor handle, Tensor flow_in, string name = null)
{
var _op = _op_def_lib._apply_op_helper("TensorArraySizeV3", name, new
{
handle,
flow_in
});

return _op.output;
}

public static Tensor tensor_array_gather_v3(Tensor handle, Tensor indices, Tensor flow_in,
TF_DataType dtype, TensorShape element_shape = null, string name = null)
{
var _op = _op_def_lib._apply_op_helper("TensorArrayGatherV3", name, new
{
handle,
indices,
dtype,
element_shape,
flow_in
});

return _op.output;
}

public static Tensor stack_v2(Tensor max_size, TF_DataType elem_type, string stack_name = "",
string name = null)
{
var _op = _op_def_lib._apply_op_helper("StackV2", name, new
{
max_size,
elem_type,
stack_name
});

return _op.output;
}

public static Tensor stack_push_v2(Tensor handle, Tensor elem, bool swap_memory = false,
string name = null)
{
var _op = _op_def_lib._apply_op_helper("StackPushV2", name, new
{
handle,
elem,
swap_memory
});

return _op.output;
}

public static Tensor stack_pop_v2(Tensor handle, TF_DataType elem_type, string name = null)
{
var _op = _op_def_lib._apply_op_helper("StackPopV2", name, new
{
handle,
elem_type
});

return _op.output;
}
} }
} }

+ 4
- 2
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -14,6 +14,8 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/
using static Tensorflow.Binding;
namespace Tensorflow namespace Tensorflow
{ {
public static class gen_math_ops public static class gen_math_ops
@@ -280,7 +282,7 @@ namespace Tensorflow
/// <param name="dy"></param> /// <param name="dy"></param>
/// <param name="name"></param> /// <param name="name"></param>
/// <returns></returns> /// <returns></returns>
public static Tensor tanh_grad(Tensor y, Tensor dy, string name = "TanhGrad")
public static Tensor tanh_grad(Tensor y, Tensor dy, string name = null)
=> _op_def_lib._apply_op_helper("TanhGrad", name: name, args: new { y, dy }).output; => _op_def_lib._apply_op_helper("TanhGrad", name: name, args: new { y, dy }).output;
public static Tensor floor(Tensor x, string name = null) public static Tensor floor(Tensor x, string name = null)
@@ -566,7 +568,7 @@ namespace Tensorflow
{ {
var _op = _op_def_lib._apply_op_helper("MatMul", name, args: new { a, b, transpose_a, transpose_b }); var _op = _op_def_lib._apply_op_helper("MatMul", name, args: new { a, b, transpose_a, transpose_b });
return _op.outputs[0];
return _op.output;
} }
/// <summary> /// <summary>


+ 19
- 0
src/TensorFlowNET.Core/Operations/math_ops.cs View File

@@ -159,6 +159,8 @@ namespace Tensorflow
}); });
} }


public static Tensor greater_equal<Tx, Ty>(Tx x, Ty y, string name = null)
=> gen_math_ops.greater_equal<Tx, Ty>(x, y, name: name);
public static Tensor equal<Tx, Ty>(Tx x, Ty y, string name = null) public static Tensor equal<Tx, Ty>(Tx x, Ty y, string name = null)
=> gen_math_ops.equal(x, y, name: name); => gen_math_ops.equal(x, y, name: name);


@@ -543,6 +545,23 @@ namespace Tensorflow
public static Tensor maximum<Tx, Ty>(Tx x, Ty y, string name = null) public static Tensor maximum<Tx, Ty>(Tx x, Ty y, string name = null)
=> gen_math_ops.maximum(x, y, name: name); => gen_math_ops.maximum(x, y, name: name);


/// <summary>
/// Multiplies matrix `a` by matrix `b`, producing `a` * `b`.
/// </summary>
/// <param name="a"></param>
/// <param name="b"></param>
/// <param name="transpose_a">If `True`, `a` is transposed before multiplication.</param>
/// <param name="transpose_b">If `True`, `b` is transposed before multiplication.</param>
/// <param name="adjoint_a">If `True`, `a` is conjugated and transposed before multiplication.</param>
/// <param name="adjoint_b">If `True`, `b` is conjugated and transposed before multiplication.</param>
/// <param name="a_is_sparse">If `True`, `a` is treated as a sparse matrix.</param>
/// <param name="b_is_sparse">If `True`, `b` is treated as a sparse matrix.</param>
/// <param name="name">Name for the operation (optional).</param>
/// <returns>
/// A `Tensor` of the same type as `a` and `b` where each inner-most matrix is
/// the product of the corresponding matrices in `a` and `b`, e.g. if all
/// transpose or adjoint attributes are `False`:
/// </returns>
public static Tensor matmul(Tensor a, Tensor b, public static Tensor matmul(Tensor a, Tensor b,
bool transpose_a = false, bool transpose_b = false, bool transpose_a = false, bool transpose_b = false,
bool adjoint_a = false, bool adjoint_b = false, bool adjoint_a = false, bool adjoint_b = false,


+ 8
- 0
src/TensorFlowNET.Core/Operations/nn_ops.cs View File

@@ -111,6 +111,14 @@ namespace Tensorflow
return noise_shape; return noise_shape;
} }


public static Tensor in_top_k(Tensor predictions, Tensor targets, int k, string name = null)
{
return tf_with(ops.name_scope(name, "in_top_k"), delegate
{
return gen_nn_ops.in_top_kv2(predictions, targets, k, name: name);
});
}

public static Tensor log_softmax(Tensor logits, int axis = -1, string name = null) public static Tensor log_softmax(Tensor logits, int axis = -1, string name = null)
{ {
return _softmax(logits, gen_nn_ops.log_softmax, axis, name); return _softmax(logits, gen_nn_ops.log_softmax, axis, name);


+ 1
- 1
src/TensorFlowNET.Core/Operations/random_ops.py.cs View File

@@ -71,7 +71,7 @@ namespace Tensorflow
return tf_with(ops.name_scope(name, "random_uniform", new { shape, minval, maxval }), scope => return tf_with(ops.name_scope(name, "random_uniform", new { shape, minval, maxval }), scope =>
{ {
name = scope; name = scope;
var tensorShape = _ShapeTensor(shape);
var tensorShape = tensor_util.shape_tensor(shape);
var minTensor = ops.convert_to_tensor(minval, dtype: dtype, name: "min"); var minTensor = ops.convert_to_tensor(minval, dtype: dtype, name: "min");
var maxTensor = ops.convert_to_tensor(maxval, dtype: dtype, name: "max"); var maxTensor = ops.convert_to_tensor(maxval, dtype: dtype, name: "max");
var rnd = gen_random_ops.random_uniform(tensorShape, dtype); var rnd = gen_random_ops.random_uniform(tensorShape, dtype);


+ 52
- 0
src/TensorFlowNET.Core/Operations/tensor_array_ops.cs View File

@@ -0,0 +1,52 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Operations;

namespace Tensorflow
{
public class tensor_array_ops
{
/// <summary>
/// Builds a TensorArray with a new `flow` tensor.
/// </summary>
/// <param name="old_ta"></param>
/// <param name="flow"></param>
/// <returns></returns>
public static TensorArray build_ta_with_new_flow(TensorArray old_ta, Tensor flow)
{
var impl = old_ta._implementation;

var new_ta = new TensorArray(
dtype: impl.dtype,
handle: impl.handle,
flow: flow,
infer_shape: impl.infer_shape,
colocate_with_first_write_call: impl.colocate_with_first_write_call);

var new_impl = new_ta._implementation;
new_impl._dynamic_size = impl._dynamic_size;
new_impl._colocate_with = impl._colocate_with;
new_impl._element_shape = impl._element_shape;
return new_ta;
}

public static TensorArray build_ta_with_new_flow(_GraphTensorArray old_ta, Tensor flow)
{
var impl = old_ta;

var new_ta = new TensorArray(
dtype: impl.dtype,
handle: impl.handle,
flow: flow,
infer_shape: impl.infer_shape,
colocate_with_first_write_call: impl.colocate_with_first_write_call);

var new_impl = new_ta._implementation;
new_impl._dynamic_size = impl._dynamic_size;
new_impl._colocate_with = impl._colocate_with;
new_impl._element_shape = impl._element_shape;
return new_ta;
}
}
}

+ 645
- 78
src/TensorFlowNET.Core/Protobuf/Config.cs View File

@@ -27,10 +27,10 @@ namespace Tensorflow {
"CiV0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvY29uZmlnLnByb3RvEgp0ZW5z", "CiV0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvY29uZmlnLnByb3RvEgp0ZW5z",
"b3JmbG93Gip0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL2Nvc3RfZ3JhcGgu", "b3JmbG93Gip0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL2Nvc3RfZ3JhcGgu",
"cHJvdG8aJXRlbnNvcmZsb3cvY29yZS9mcmFtZXdvcmsvZ3JhcGgucHJvdG8a", "cHJvdG8aJXRlbnNvcmZsb3cvY29yZS9mcmFtZXdvcmsvZ3JhcGgucHJvdG8a",
"KnRlbnNvcmZsb3cvY29yZS9mcmFtZXdvcmsvc3RlcF9zdGF0cy5wcm90bxok",
"dGVuc29yZmxvdy9jb3JlL3Byb3RvYnVmL2RlYnVnLnByb3RvGiZ0ZW5zb3Jm",
"bG93L2NvcmUvcHJvdG9idWYvY2x1c3Rlci5wcm90bxoudGVuc29yZmxvdy9j",
"b3JlL3Byb3RvYnVmL3Jld3JpdGVyX2NvbmZpZy5wcm90byKtBAoKR1BVT3B0",
"KnRlbnNvcmZsb3cvY29yZS9mcmFtZXdvcmsvc3RlcF9zdGF0cy5wcm90bxom",
"dGVuc29yZmxvdy9jb3JlL3Byb3RvYnVmL2NsdXN0ZXIucHJvdG8aJHRlbnNv",
"cmZsb3cvY29yZS9wcm90b2J1Zi9kZWJ1Zy5wcm90bxoudGVuc29yZmxvdy9j",
"b3JlL3Byb3RvYnVmL3Jld3JpdGVyX2NvbmZpZy5wcm90byK3BQoKR1BVT3B0",
"aW9ucxInCh9wZXJfcHJvY2Vzc19ncHVfbWVtb3J5X2ZyYWN0aW9uGAEgASgB", "aW9ucxInCh9wZXJfcHJvY2Vzc19ncHVfbWVtb3J5X2ZyYWN0aW9uGAEgASgB",
"EhQKDGFsbG93X2dyb3d0aBgEIAEoCBIWCg5hbGxvY2F0b3JfdHlwZRgCIAEo", "EhQKDGFsbG93X2dyb3d0aBgEIAEoCBIWCg5hbGxvY2F0b3JfdHlwZRgCIAEo",
"CRIfChdkZWZlcnJlZF9kZWxldGlvbl9ieXRlcxgDIAEoAxIbChN2aXNpYmxl", "CRIfChdkZWZlcnJlZF9kZWxldGlvbl9ieXRlcxgDIAEoAxIbChN2aXNpYmxl",
@@ -38,89 +38,102 @@ namespace Tensorflow {
"ZWNzGAYgASgFEiQKHHBvbGxpbmdfaW5hY3RpdmVfZGVsYXlfbXNlY3MYByAB", "ZWNzGAYgASgFEiQKHHBvbGxpbmdfaW5hY3RpdmVfZGVsYXlfbXNlY3MYByAB",
"KAUSHAoUZm9yY2VfZ3B1X2NvbXBhdGlibGUYCCABKAgSOQoMZXhwZXJpbWVu", "KAUSHAoUZm9yY2VfZ3B1X2NvbXBhdGlibGUYCCABKAgSOQoMZXhwZXJpbWVu",
"dGFsGAkgASgLMiMudGVuc29yZmxvdy5HUFVPcHRpb25zLkV4cGVyaW1lbnRh", "dGFsGAkgASgLMiMudGVuc29yZmxvdy5HUFVPcHRpb25zLkV4cGVyaW1lbnRh",
"bBrmAQoMRXhwZXJpbWVudGFsEksKD3ZpcnR1YWxfZGV2aWNlcxgBIAMoCzIy",
"bBrwAgoMRXhwZXJpbWVudGFsEksKD3ZpcnR1YWxfZGV2aWNlcxgBIAMoCzIy",
"LnRlbnNvcmZsb3cuR1BVT3B0aW9ucy5FeHBlcmltZW50YWwuVmlydHVhbERl", "LnRlbnNvcmZsb3cuR1BVT3B0aW9ucy5FeHBlcmltZW50YWwuVmlydHVhbERl",
"dmljZXMSGgoSdXNlX3VuaWZpZWRfbWVtb3J5GAIgASgIEiMKG251bV9kZXZf", "dmljZXMSGgoSdXNlX3VuaWZpZWRfbWVtb3J5GAIgASgIEiMKG251bV9kZXZf",
"dG9fZGV2X2NvcHlfc3RyZWFtcxgDIAEoBRIdChVjb2xsZWN0aXZlX3Jpbmdf", "dG9fZGV2X2NvcHlfc3RyZWFtcxgDIAEoBRIdChVjb2xsZWN0aXZlX3Jpbmdf",
"b3JkZXIYBCABKAkaKQoOVmlydHVhbERldmljZXMSFwoPbWVtb3J5X2xpbWl0",
"X21iGAEgAygCIoUDChBPcHRpbWl6ZXJPcHRpb25zEisKI2RvX2NvbW1vbl9z",
"dWJleHByZXNzaW9uX2VsaW1pbmF0aW9uGAEgASgIEhsKE2RvX2NvbnN0YW50",
"X2ZvbGRpbmcYAiABKAgSJAocbWF4X2ZvbGRlZF9jb25zdGFudF9pbl9ieXRl",
"cxgGIAEoAxIcChRkb19mdW5jdGlvbl9pbmxpbmluZxgEIAEoCBI1CglvcHRf",
"bGV2ZWwYAyABKA4yIi50ZW5zb3JmbG93Lk9wdGltaXplck9wdGlvbnMuTGV2",
"ZWwSRQoQZ2xvYmFsX2ppdF9sZXZlbBgFIAEoDjIrLnRlbnNvcmZsb3cuT3B0",
"aW1pemVyT3B0aW9ucy5HbG9iYWxKaXRMZXZlbCIgCgVMZXZlbBIGCgJMMRAA",
"Eg8KAkwwEP///////////wEiQwoOR2xvYmFsSml0TGV2ZWwSCwoHREVGQVVM",
"VBAAEhAKA09GRhD///////////8BEggKBE9OXzEQARIICgRPTl8yEAIi7gIK",
"DEdyYXBoT3B0aW9ucxIeChZlbmFibGVfcmVjdl9zY2hlZHVsaW5nGAIgASgI",
"EjcKEW9wdGltaXplcl9vcHRpb25zGAMgASgLMhwudGVuc29yZmxvdy5PcHRp",
"bWl6ZXJPcHRpb25zEhgKEGJ1aWxkX2Nvc3RfbW9kZWwYBCABKAMSHgoWYnVp",
"bGRfY29zdF9tb2RlbF9hZnRlchgJIAEoAxIUCgxpbmZlcl9zaGFwZXMYBSAB",
"KAgSGgoScGxhY2VfcHJ1bmVkX2dyYXBoGAYgASgIEiAKGGVuYWJsZV9iZmxv",
"YXQxNl9zZW5kcmVjdhgHIAEoCBIVCg10aW1lbGluZV9zdGVwGAggASgFEjMK",
"D3Jld3JpdGVfb3B0aW9ucxgKIAEoCzIaLnRlbnNvcmZsb3cuUmV3cml0ZXJD",
"b25maWdKBAgBEAJSJXNraXBfY29tbW9uX3N1YmV4cHJlc3Npb25fZWxpbWlu",
"YXRpb24iQQoVVGhyZWFkUG9vbE9wdGlvblByb3RvEhMKC251bV90aHJlYWRz",
"GAEgASgFEhMKC2dsb2JhbF9uYW1lGAIgASgJImwKClJQQ09wdGlvbnMSJAoc",
"dXNlX3JwY19mb3JfaW5wcm9jZXNzX21hc3RlchgBIAEoCBIdChVjb21wcmVz",
"c2lvbl9hbGdvcml0aG0YAiABKAkSGQoRY29tcHJlc3Npb25fbGV2ZWwYAyAB",
"KAUi3wYKC0NvbmZpZ1Byb3RvEj4KDGRldmljZV9jb3VudBgBIAMoCzIoLnRl",
"bnNvcmZsb3cuQ29uZmlnUHJvdG8uRGV2aWNlQ291bnRFbnRyeRIkChxpbnRy",
"YV9vcF9wYXJhbGxlbGlzbV90aHJlYWRzGAIgASgFEiQKHGludGVyX29wX3Bh",
"cmFsbGVsaXNtX3RocmVhZHMYBSABKAUSHwoXdXNlX3Blcl9zZXNzaW9uX3Ro",
"cmVhZHMYCSABKAgSRwocc2Vzc2lvbl9pbnRlcl9vcF90aHJlYWRfcG9vbBgM",
"IAMoCzIhLnRlbnNvcmZsb3cuVGhyZWFkUG9vbE9wdGlvblByb3RvEhgKEHBs",
"YWNlbWVudF9wZXJpb2QYAyABKAUSFgoOZGV2aWNlX2ZpbHRlcnMYBCADKAkS",
"KwoLZ3B1X29wdGlvbnMYBiABKAsyFi50ZW5zb3JmbG93LkdQVU9wdGlvbnMS",
"HAoUYWxsb3dfc29mdF9wbGFjZW1lbnQYByABKAgSHAoUbG9nX2RldmljZV9w",
"bGFjZW1lbnQYCCABKAgSLwoNZ3JhcGhfb3B0aW9ucxgKIAEoCzIYLnRlbnNv",
"cmZsb3cuR3JhcGhPcHRpb25zEh8KF29wZXJhdGlvbl90aW1lb3V0X2luX21z",
"GAsgASgDEisKC3JwY19vcHRpb25zGA0gASgLMhYudGVuc29yZmxvdy5SUENP",
"cHRpb25zEisKC2NsdXN0ZXJfZGVmGA4gASgLMhYudGVuc29yZmxvdy5DbHVz",
"dGVyRGVmEh0KFWlzb2xhdGVfc2Vzc2lvbl9zdGF0ZRgPIAEoCBI6CgxleHBl",
"cmltZW50YWwYECABKAsyJC50ZW5zb3JmbG93LkNvbmZpZ1Byb3RvLkV4cGVy",
"aW1lbnRhbBoyChBEZXZpY2VDb3VudEVudHJ5EgsKA2tleRgBIAEoCRINCgV2",
"YWx1ZRgCIAEoBToCOAEagwEKDEV4cGVyaW1lbnRhbBIfChdjb2xsZWN0aXZl",
"X2dyb3VwX2xlYWRlchgBIAEoCRIVCg1leGVjdXRvcl90eXBlGAMgASgJEhoK",
"EnJlY3ZfYnVmX21heF9jaHVuaxgEIAEoBRIZChF1c2VfbnVtYV9hZmZpbml0",
"eRgFIAEoCEoECAIQAyLYAwoKUnVuT3B0aW9ucxI2Cgt0cmFjZV9sZXZlbBgB",
"IAEoDjIhLnRlbnNvcmZsb3cuUnVuT3B0aW9ucy5UcmFjZUxldmVsEhUKDXRp",
"bWVvdXRfaW5fbXMYAiABKAMSHAoUaW50ZXJfb3BfdGhyZWFkX3Bvb2wYAyAB",
"KAUSHwoXb3V0cHV0X3BhcnRpdGlvbl9ncmFwaHMYBSABKAgSLwoNZGVidWdf",
"b3B0aW9ucxgGIAEoCzIYLnRlbnNvcmZsb3cuRGVidWdPcHRpb25zEioKInJl",
"cG9ydF90ZW5zb3JfYWxsb2NhdGlvbnNfdXBvbl9vb20YByABKAgSOQoMZXhw",
"ZXJpbWVudGFsGAggASgLMiMudGVuc29yZmxvdy5SdW5PcHRpb25zLkV4cGVy",
"aW1lbnRhbBpKCgxFeHBlcmltZW50YWwSHAoUY29sbGVjdGl2ZV9ncmFwaF9r",
"ZXkYASABKAMSHAoUdXNlX3J1bl9oYW5kbGVyX3Bvb2wYAiABKAgiUgoKVHJh",
"Y2VMZXZlbBIMCghOT19UUkFDRRAAEhIKDlNPRlRXQVJFX1RSQUNFEAESEgoO",
"SEFSRFdBUkVfVFJBQ0UQAhIOCgpGVUxMX1RSQUNFEANKBAgEEAUilgEKC1J1",
"bk1ldGFkYXRhEikKCnN0ZXBfc3RhdHMYASABKAsyFS50ZW5zb3JmbG93LlN0",
"ZXBTdGF0cxIsCgpjb3N0X2dyYXBoGAIgASgLMhgudGVuc29yZmxvdy5Db3N0",
"R3JhcGhEZWYSLgoQcGFydGl0aW9uX2dyYXBocxgDIAMoCzIULnRlbnNvcmZs",
"b3cuR3JhcGhEZWYiOgoQVGVuc29yQ29ubmVjdGlvbhITCgtmcm9tX3RlbnNv",
"chgBIAEoCRIRCgl0b190ZW5zb3IYAiABKAkisAMKD0NhbGxhYmxlT3B0aW9u",
"cxIMCgRmZWVkGAEgAygJEg0KBWZldGNoGAIgAygJEg4KBnRhcmdldBgDIAMo",
"CRIrCgtydW5fb3B0aW9ucxgEIAEoCzIWLnRlbnNvcmZsb3cuUnVuT3B0aW9u",
"cxI3ChF0ZW5zb3JfY29ubmVjdGlvbhgFIAMoCzIcLnRlbnNvcmZsb3cuVGVu",
"c29yQ29ubmVjdGlvbhJCCgxmZWVkX2RldmljZXMYBiADKAsyLC50ZW5zb3Jm",
"bG93LkNhbGxhYmxlT3B0aW9ucy5GZWVkRGV2aWNlc0VudHJ5EkQKDWZldGNo",
"X2RldmljZXMYByADKAsyLS50ZW5zb3JmbG93LkNhbGxhYmxlT3B0aW9ucy5G",
"ZXRjaERldmljZXNFbnRyeRIXCg9mZXRjaF9za2lwX3N5bmMYCCABKAgaMgoQ",
"RmVlZERldmljZXNFbnRyeRILCgNrZXkYASABKAkSDQoFdmFsdWUYAiABKAk6",
"AjgBGjMKEUZldGNoRGV2aWNlc0VudHJ5EgsKA2tleRgBIAEoCRINCgV2YWx1",
"ZRgCIAEoCToCOAFCLQoYb3JnLnRlbnNvcmZsb3cuZnJhbWV3b3JrQgxDb25m",
"aWdQcm90b3NQAfgBAWIGcHJvdG8z"));
"b3JkZXIYBCABKAkSHQoVdGltZXN0YW1wZWRfYWxsb2NhdG9yGAUgASgIEiMK",
"G2tlcm5lbF90cmFja2VyX21heF9pbnRlcnZhbBgHIAEoBRIgChhrZXJuZWxf",
"dHJhY2tlcl9tYXhfYnl0ZXMYCCABKAUSIgoaa2VybmVsX3RyYWNrZXJfbWF4",
"X3BlbmRpbmcYCSABKAUaKQoOVmlydHVhbERldmljZXMSFwoPbWVtb3J5X2xp",
"bWl0X21iGAEgAygCIoUDChBPcHRpbWl6ZXJPcHRpb25zEisKI2RvX2NvbW1v",
"bl9zdWJleHByZXNzaW9uX2VsaW1pbmF0aW9uGAEgASgIEhsKE2RvX2NvbnN0",
"YW50X2ZvbGRpbmcYAiABKAgSJAocbWF4X2ZvbGRlZF9jb25zdGFudF9pbl9i",
"eXRlcxgGIAEoAxIcChRkb19mdW5jdGlvbl9pbmxpbmluZxgEIAEoCBI1Cglv",
"cHRfbGV2ZWwYAyABKA4yIi50ZW5zb3JmbG93Lk9wdGltaXplck9wdGlvbnMu",
"TGV2ZWwSRQoQZ2xvYmFsX2ppdF9sZXZlbBgFIAEoDjIrLnRlbnNvcmZsb3cu",
"T3B0aW1pemVyT3B0aW9ucy5HbG9iYWxKaXRMZXZlbCIgCgVMZXZlbBIGCgJM",
"MRAAEg8KAkwwEP///////////wEiQwoOR2xvYmFsSml0TGV2ZWwSCwoHREVG",
"QVVMVBAAEhAKA09GRhD///////////8BEggKBE9OXzEQARIICgRPTl8yEAIi",
"7gIKDEdyYXBoT3B0aW9ucxIeChZlbmFibGVfcmVjdl9zY2hlZHVsaW5nGAIg",
"ASgIEjcKEW9wdGltaXplcl9vcHRpb25zGAMgASgLMhwudGVuc29yZmxvdy5P",
"cHRpbWl6ZXJPcHRpb25zEhgKEGJ1aWxkX2Nvc3RfbW9kZWwYBCABKAMSHgoW",
"YnVpbGRfY29zdF9tb2RlbF9hZnRlchgJIAEoAxIUCgxpbmZlcl9zaGFwZXMY",
"BSABKAgSGgoScGxhY2VfcHJ1bmVkX2dyYXBoGAYgASgIEiAKGGVuYWJsZV9i",
"ZmxvYXQxNl9zZW5kcmVjdhgHIAEoCBIVCg10aW1lbGluZV9zdGVwGAggASgF",
"EjMKD3Jld3JpdGVfb3B0aW9ucxgKIAEoCzIaLnRlbnNvcmZsb3cuUmV3cml0",
"ZXJDb25maWdKBAgBEAJSJXNraXBfY29tbW9uX3N1YmV4cHJlc3Npb25fZWxp",
"bWluYXRpb24iQQoVVGhyZWFkUG9vbE9wdGlvblByb3RvEhMKC251bV90aHJl",
"YWRzGAEgASgFEhMKC2dsb2JhbF9uYW1lGAIgASgJImwKClJQQ09wdGlvbnMS",
"JAocdXNlX3JwY19mb3JfaW5wcm9jZXNzX21hc3RlchgBIAEoCBIdChVjb21w",
"cmVzc2lvbl9hbGdvcml0aG0YAiABKAkSGQoRY29tcHJlc3Npb25fbGV2ZWwY",
"AyABKAUisggKC0NvbmZpZ1Byb3RvEj4KDGRldmljZV9jb3VudBgBIAMoCzIo",
"LnRlbnNvcmZsb3cuQ29uZmlnUHJvdG8uRGV2aWNlQ291bnRFbnRyeRIkChxp",
"bnRyYV9vcF9wYXJhbGxlbGlzbV90aHJlYWRzGAIgASgFEiQKHGludGVyX29w",
"X3BhcmFsbGVsaXNtX3RocmVhZHMYBSABKAUSHwoXdXNlX3Blcl9zZXNzaW9u",
"X3RocmVhZHMYCSABKAgSRwocc2Vzc2lvbl9pbnRlcl9vcF90aHJlYWRfcG9v",
"bBgMIAMoCzIhLnRlbnNvcmZsb3cuVGhyZWFkUG9vbE9wdGlvblByb3RvEhgK",
"EHBsYWNlbWVudF9wZXJpb2QYAyABKAUSFgoOZGV2aWNlX2ZpbHRlcnMYBCAD",
"KAkSKwoLZ3B1X29wdGlvbnMYBiABKAsyFi50ZW5zb3JmbG93LkdQVU9wdGlv",
"bnMSHAoUYWxsb3dfc29mdF9wbGFjZW1lbnQYByABKAgSHAoUbG9nX2Rldmlj",
"ZV9wbGFjZW1lbnQYCCABKAgSLwoNZ3JhcGhfb3B0aW9ucxgKIAEoCzIYLnRl",
"bnNvcmZsb3cuR3JhcGhPcHRpb25zEh8KF29wZXJhdGlvbl90aW1lb3V0X2lu",
"X21zGAsgASgDEisKC3JwY19vcHRpb25zGA0gASgLMhYudGVuc29yZmxvdy5S",
"UENPcHRpb25zEisKC2NsdXN0ZXJfZGVmGA4gASgLMhYudGVuc29yZmxvdy5D",
"bHVzdGVyRGVmEh0KFWlzb2xhdGVfc2Vzc2lvbl9zdGF0ZRgPIAEoCBI6Cgxl",
"eHBlcmltZW50YWwYECABKAsyJC50ZW5zb3JmbG93LkNvbmZpZ1Byb3RvLkV4",
"cGVyaW1lbnRhbBoyChBEZXZpY2VDb3VudEVudHJ5EgsKA2tleRgBIAEoCRIN",
"CgV2YWx1ZRgCIAEoBToCOAEa1gIKDEV4cGVyaW1lbnRhbBIfChdjb2xsZWN0",
"aXZlX2dyb3VwX2xlYWRlchgBIAEoCRIVCg1leGVjdXRvcl90eXBlGAMgASgJ",
"EhoKEnJlY3ZfYnVmX21heF9jaHVuaxgEIAEoBRIZChF1c2VfbnVtYV9hZmZp",
"bml0eRgFIAEoCBI1Ci1jb2xsZWN0aXZlX2RldGVybWluaXN0aWNfc2VxdWVu",
"dGlhbF9leGVjdXRpb24YBiABKAgSFwoPY29sbGVjdGl2ZV9uY2NsGAcgASgI",
"EjYKLnNoYXJlX3Nlc3Npb25fc3RhdGVfaW5fY2x1c3RlcnNwZWNfcHJvcGFn",
"YXRpb24YCCABKAgSHwoXZGlzYWJsZV90aHJlYWRfc3Bpbm5pbmcYCSABKAgS",
"KAogc2hhcmVfY2x1c3Rlcl9kZXZpY2VzX2luX3Nlc3Npb24YCiABKAhKBAgC",
"EAMi2AMKClJ1bk9wdGlvbnMSNgoLdHJhY2VfbGV2ZWwYASABKA4yIS50ZW5z",
"b3JmbG93LlJ1bk9wdGlvbnMuVHJhY2VMZXZlbBIVCg10aW1lb3V0X2luX21z",
"GAIgASgDEhwKFGludGVyX29wX3RocmVhZF9wb29sGAMgASgFEh8KF291dHB1",
"dF9wYXJ0aXRpb25fZ3JhcGhzGAUgASgIEi8KDWRlYnVnX29wdGlvbnMYBiAB",
"KAsyGC50ZW5zb3JmbG93LkRlYnVnT3B0aW9ucxIqCiJyZXBvcnRfdGVuc29y",
"X2FsbG9jYXRpb25zX3Vwb25fb29tGAcgASgIEjkKDGV4cGVyaW1lbnRhbBgI",
"IAEoCzIjLnRlbnNvcmZsb3cuUnVuT3B0aW9ucy5FeHBlcmltZW50YWwaSgoM",
"RXhwZXJpbWVudGFsEhwKFGNvbGxlY3RpdmVfZ3JhcGhfa2V5GAEgASgDEhwK",
"FHVzZV9ydW5faGFuZGxlcl9wb29sGAIgASgIIlIKClRyYWNlTGV2ZWwSDAoI",
"Tk9fVFJBQ0UQABISCg5TT0ZUV0FSRV9UUkFDRRABEhIKDkhBUkRXQVJFX1RS",
"QUNFEAISDgoKRlVMTF9UUkFDRRADSgQIBBAFIocDCgtSdW5NZXRhZGF0YRIp",
"CgpzdGVwX3N0YXRzGAEgASgLMhUudGVuc29yZmxvdy5TdGVwU3RhdHMSLAoK",
"Y29zdF9ncmFwaBgCIAEoCzIYLnRlbnNvcmZsb3cuQ29zdEdyYXBoRGVmEi4K",
"EHBhcnRpdGlvbl9ncmFwaHMYAyADKAsyFC50ZW5zb3JmbG93LkdyYXBoRGVm",
"Ej8KD2Z1bmN0aW9uX2dyYXBocxgEIAMoCzImLnRlbnNvcmZsb3cuUnVuTWV0",
"YWRhdGEuRnVuY3Rpb25HcmFwaHMarQEKDkZ1bmN0aW9uR3JhcGhzEi4KEHBh",
"cnRpdGlvbl9ncmFwaHMYASADKAsyFC50ZW5zb3JmbG93LkdyYXBoRGVmEjQK",
"FnByZV9vcHRpbWl6YXRpb25fZ3JhcGgYAiABKAsyFC50ZW5zb3JmbG93Lkdy",
"YXBoRGVmEjUKF3Bvc3Rfb3B0aW1pemF0aW9uX2dyYXBoGAMgASgLMhQudGVu",
"c29yZmxvdy5HcmFwaERlZiI6ChBUZW5zb3JDb25uZWN0aW9uEhMKC2Zyb21f",
"dGVuc29yGAEgASgJEhEKCXRvX3RlbnNvchgCIAEoCSKwAwoPQ2FsbGFibGVP",
"cHRpb25zEgwKBGZlZWQYASADKAkSDQoFZmV0Y2gYAiADKAkSDgoGdGFyZ2V0",
"GAMgAygJEisKC3J1bl9vcHRpb25zGAQgASgLMhYudGVuc29yZmxvdy5SdW5P",
"cHRpb25zEjcKEXRlbnNvcl9jb25uZWN0aW9uGAUgAygLMhwudGVuc29yZmxv",
"dy5UZW5zb3JDb25uZWN0aW9uEkIKDGZlZWRfZGV2aWNlcxgGIAMoCzIsLnRl",
"bnNvcmZsb3cuQ2FsbGFibGVPcHRpb25zLkZlZWREZXZpY2VzRW50cnkSRAoN",
"ZmV0Y2hfZGV2aWNlcxgHIAMoCzItLnRlbnNvcmZsb3cuQ2FsbGFibGVPcHRp",
"b25zLkZldGNoRGV2aWNlc0VudHJ5EhcKD2ZldGNoX3NraXBfc3luYxgIIAEo",
"CBoyChBGZWVkRGV2aWNlc0VudHJ5EgsKA2tleRgBIAEoCRINCgV2YWx1ZRgC",
"IAEoCToCOAEaMwoRRmV0Y2hEZXZpY2VzRW50cnkSCwoDa2V5GAEgASgJEg0K",
"BXZhbHVlGAIgASgJOgI4AUItChhvcmcudGVuc29yZmxvdy5mcmFtZXdvcmtC",
"DENvbmZpZ1Byb3Rvc1AB+AEBYgZwcm90bzM="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { global::Tensorflow.CostGraphReflection.Descriptor, global::Tensorflow.GraphReflection.Descriptor, global::Tensorflow.StepStatsReflection.Descriptor, global::Tensorflow.DebugReflection.Descriptor, global::Tensorflow.ClusterReflection.Descriptor, global::Tensorflow.RewriterConfigReflection.Descriptor, },
new pbr::FileDescriptor[] { global::Tensorflow.CostGraphReflection.Descriptor, global::Tensorflow.GraphReflection.Descriptor, global::Tensorflow.StepStatsReflection.Descriptor, global::Tensorflow.ClusterReflection.Descriptor, global::Tensorflow.DebugReflection.Descriptor, global::Tensorflow.RewriterConfigReflection.Descriptor, },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GPUOptions), global::Tensorflow.GPUOptions.Parser, new[]{ "PerProcessGpuMemoryFraction", "AllowGrowth", "AllocatorType", "DeferredDeletionBytes", "VisibleDeviceList", "PollingActiveDelayUsecs", "PollingInactiveDelayMsecs", "ForceGpuCompatible", "Experimental" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GPUOptions.Types.Experimental), global::Tensorflow.GPUOptions.Types.Experimental.Parser, new[]{ "VirtualDevices", "UseUnifiedMemory", "NumDevToDevCopyStreams", "CollectiveRingOrder" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GPUOptions.Types.Experimental.Types.VirtualDevices), global::Tensorflow.GPUOptions.Types.Experimental.Types.VirtualDevices.Parser, new[]{ "MemoryLimitMb" }, null, null, null)})}),
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GPUOptions), global::Tensorflow.GPUOptions.Parser, new[]{ "PerProcessGpuMemoryFraction", "AllowGrowth", "AllocatorType", "DeferredDeletionBytes", "VisibleDeviceList", "PollingActiveDelayUsecs", "PollingInactiveDelayMsecs", "ForceGpuCompatible", "Experimental" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GPUOptions.Types.Experimental), global::Tensorflow.GPUOptions.Types.Experimental.Parser, new[]{ "VirtualDevices", "UseUnifiedMemory", "NumDevToDevCopyStreams", "CollectiveRingOrder", "TimestampedAllocator", "KernelTrackerMaxInterval", "KernelTrackerMaxBytes", "KernelTrackerMaxPending" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GPUOptions.Types.Experimental.Types.VirtualDevices), global::Tensorflow.GPUOptions.Types.Experimental.Types.VirtualDevices.Parser, new[]{ "MemoryLimitMb" }, null, null, null)})}),
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.OptimizerOptions), global::Tensorflow.OptimizerOptions.Parser, new[]{ "DoCommonSubexpressionElimination", "DoConstantFolding", "MaxFoldedConstantInBytes", "DoFunctionInlining", "OptLevel", "GlobalJitLevel" }, null, new[]{ typeof(global::Tensorflow.OptimizerOptions.Types.Level), typeof(global::Tensorflow.OptimizerOptions.Types.GlobalJitLevel) }, null), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.OptimizerOptions), global::Tensorflow.OptimizerOptions.Parser, new[]{ "DoCommonSubexpressionElimination", "DoConstantFolding", "MaxFoldedConstantInBytes", "DoFunctionInlining", "OptLevel", "GlobalJitLevel" }, null, new[]{ typeof(global::Tensorflow.OptimizerOptions.Types.Level), typeof(global::Tensorflow.OptimizerOptions.Types.GlobalJitLevel) }, null),
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GraphOptions), global::Tensorflow.GraphOptions.Parser, new[]{ "EnableRecvScheduling", "OptimizerOptions", "BuildCostModel", "BuildCostModelAfter", "InferShapes", "PlacePrunedGraph", "EnableBfloat16Sendrecv", "TimelineStep", "RewriteOptions" }, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GraphOptions), global::Tensorflow.GraphOptions.Parser, new[]{ "EnableRecvScheduling", "OptimizerOptions", "BuildCostModel", "BuildCostModelAfter", "InferShapes", "PlacePrunedGraph", "EnableBfloat16Sendrecv", "TimelineStep", "RewriteOptions" }, null, null, null),
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ThreadPoolOptionProto), global::Tensorflow.ThreadPoolOptionProto.Parser, new[]{ "NumThreads", "GlobalName" }, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ThreadPoolOptionProto), global::Tensorflow.ThreadPoolOptionProto.Parser, new[]{ "NumThreads", "GlobalName" }, null, null, null),
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RPCOptions), global::Tensorflow.RPCOptions.Parser, new[]{ "UseRpcForInprocessMaster", "CompressionAlgorithm", "CompressionLevel" }, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RPCOptions), global::Tensorflow.RPCOptions.Parser, new[]{ "UseRpcForInprocessMaster", "CompressionAlgorithm", "CompressionLevel" }, null, null, null),
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ConfigProto), global::Tensorflow.ConfigProto.Parser, new[]{ "DeviceCount", "IntraOpParallelismThreads", "InterOpParallelismThreads", "UsePerSessionThreads", "SessionInterOpThreadPool", "PlacementPeriod", "DeviceFilters", "GpuOptions", "AllowSoftPlacement", "LogDevicePlacement", "GraphOptions", "OperationTimeoutInMs", "RpcOptions", "ClusterDef", "IsolateSessionState", "Experimental" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ConfigProto.Types.Experimental), global::Tensorflow.ConfigProto.Types.Experimental.Parser, new[]{ "CollectiveGroupLeader", "ExecutorType", "RecvBufMaxChunk", "UseNumaAffinity" }, null, null, null)}),
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ConfigProto), global::Tensorflow.ConfigProto.Parser, new[]{ "DeviceCount", "IntraOpParallelismThreads", "InterOpParallelismThreads", "UsePerSessionThreads", "SessionInterOpThreadPool", "PlacementPeriod", "DeviceFilters", "GpuOptions", "AllowSoftPlacement", "LogDevicePlacement", "GraphOptions", "OperationTimeoutInMs", "RpcOptions", "ClusterDef", "IsolateSessionState", "Experimental" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ConfigProto.Types.Experimental), global::Tensorflow.ConfigProto.Types.Experimental.Parser, new[]{ "CollectiveGroupLeader", "ExecutorType", "RecvBufMaxChunk", "UseNumaAffinity", "CollectiveDeterministicSequentialExecution", "CollectiveNccl", "ShareSessionStateInClusterspecPropagation", "DisableThreadSpinning", "ShareClusterDevicesInSession" }, null, null, null)}),
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RunOptions), global::Tensorflow.RunOptions.Parser, new[]{ "TraceLevel", "TimeoutInMs", "InterOpThreadPool", "OutputPartitionGraphs", "DebugOptions", "ReportTensorAllocationsUponOom", "Experimental" }, null, new[]{ typeof(global::Tensorflow.RunOptions.Types.TraceLevel) }, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RunOptions.Types.Experimental), global::Tensorflow.RunOptions.Types.Experimental.Parser, new[]{ "CollectiveGraphKey", "UseRunHandlerPool" }, null, null, null)}), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RunOptions), global::Tensorflow.RunOptions.Parser, new[]{ "TraceLevel", "TimeoutInMs", "InterOpThreadPool", "OutputPartitionGraphs", "DebugOptions", "ReportTensorAllocationsUponOom", "Experimental" }, null, new[]{ typeof(global::Tensorflow.RunOptions.Types.TraceLevel) }, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RunOptions.Types.Experimental), global::Tensorflow.RunOptions.Types.Experimental.Parser, new[]{ "CollectiveGraphKey", "UseRunHandlerPool" }, null, null, null)}),
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RunMetadata), global::Tensorflow.RunMetadata.Parser, new[]{ "StepStats", "CostGraph", "PartitionGraphs" }, null, null, null),
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RunMetadata), global::Tensorflow.RunMetadata.Parser, new[]{ "StepStats", "CostGraph", "PartitionGraphs", "FunctionGraphs" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RunMetadata.Types.FunctionGraphs), global::Tensorflow.RunMetadata.Types.FunctionGraphs.Parser, new[]{ "PartitionGraphs", "PreOptimizationGraph", "PostOptimizationGraph" }, null, null, null)}),
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TensorConnection), global::Tensorflow.TensorConnection.Parser, new[]{ "FromTensor", "ToTensor" }, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TensorConnection), global::Tensorflow.TensorConnection.Parser, new[]{ "FromTensor", "ToTensor" }, null, null, null),
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CallableOptions), global::Tensorflow.CallableOptions.Parser, new[]{ "Feed", "Fetch", "Target", "RunOptions", "TensorConnection", "FeedDevices", "FetchDevices", "FetchSkipSync" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, null, }) new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CallableOptions), global::Tensorflow.CallableOptions.Parser, new[]{ "Feed", "Fetch", "Target", "RunOptions", "TensorConnection", "FeedDevices", "FetchDevices", "FetchSkipSync" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, null, })
})); }));
@@ -605,6 +618,10 @@ namespace Tensorflow {
useUnifiedMemory_ = other.useUnifiedMemory_; useUnifiedMemory_ = other.useUnifiedMemory_;
numDevToDevCopyStreams_ = other.numDevToDevCopyStreams_; numDevToDevCopyStreams_ = other.numDevToDevCopyStreams_;
collectiveRingOrder_ = other.collectiveRingOrder_; collectiveRingOrder_ = other.collectiveRingOrder_;
timestampedAllocator_ = other.timestampedAllocator_;
kernelTrackerMaxInterval_ = other.kernelTrackerMaxInterval_;
kernelTrackerMaxBytes_ = other.kernelTrackerMaxBytes_;
kernelTrackerMaxPending_ = other.kernelTrackerMaxPending_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
} }


@@ -703,6 +720,77 @@ namespace Tensorflow {
} }
} }


/// <summary>Field number for the "timestamped_allocator" field.</summary>
public const int TimestampedAllocatorFieldNumber = 5;
private bool timestampedAllocator_;
/// <summary>
/// If true then extra work is done by GPUDevice and GPUBFCAllocator to
/// keep track of when GPU memory is freed and when kernels actually
/// complete so that we can know when a nominally free memory chunk
/// is really not subject to pending use.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool TimestampedAllocator {
get { return timestampedAllocator_; }
set {
timestampedAllocator_ = value;
}
}

/// <summary>Field number for the "kernel_tracker_max_interval" field.</summary>
public const int KernelTrackerMaxIntervalFieldNumber = 7;
private int kernelTrackerMaxInterval_;
/// <summary>
/// Parameters for GPUKernelTracker. By default no kernel tracking is done.
/// Note that timestamped_allocator is only effective if some tracking is
/// specified.
///
/// If kernel_tracker_max_interval = n > 0, then a tracking event
/// is inserted after every n kernels without an event.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int KernelTrackerMaxInterval {
get { return kernelTrackerMaxInterval_; }
set {
kernelTrackerMaxInterval_ = value;
}
}

/// <summary>Field number for the "kernel_tracker_max_bytes" field.</summary>
public const int KernelTrackerMaxBytesFieldNumber = 8;
private int kernelTrackerMaxBytes_;
/// <summary>
/// If kernel_tracker_max_bytes = n > 0, then a tracking event is
/// inserted after every series of kernels allocating a sum of
/// memory >= n. If one kernel allocates b * n bytes, then one
/// event will be inserted after it, but it will count as b against
/// the pending limit.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int KernelTrackerMaxBytes {
get { return kernelTrackerMaxBytes_; }
set {
kernelTrackerMaxBytes_ = value;
}
}

/// <summary>Field number for the "kernel_tracker_max_pending" field.</summary>
public const int KernelTrackerMaxPendingFieldNumber = 9;
private int kernelTrackerMaxPending_;
/// <summary>
/// If kernel_tracker_max_pending > 0 then no more than this many
/// tracking events can be outstanding at a time. An attempt to
/// launch an additional kernel will stall until an event
/// completes.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int KernelTrackerMaxPending {
get { return kernelTrackerMaxPending_; }
set {
kernelTrackerMaxPending_ = value;
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute] [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) { public override bool Equals(object other) {
return Equals(other as Experimental); return Equals(other as Experimental);
@@ -720,6 +808,10 @@ namespace Tensorflow {
if (UseUnifiedMemory != other.UseUnifiedMemory) return false; if (UseUnifiedMemory != other.UseUnifiedMemory) return false;
if (NumDevToDevCopyStreams != other.NumDevToDevCopyStreams) return false; if (NumDevToDevCopyStreams != other.NumDevToDevCopyStreams) return false;
if (CollectiveRingOrder != other.CollectiveRingOrder) return false; if (CollectiveRingOrder != other.CollectiveRingOrder) return false;
if (TimestampedAllocator != other.TimestampedAllocator) return false;
if (KernelTrackerMaxInterval != other.KernelTrackerMaxInterval) return false;
if (KernelTrackerMaxBytes != other.KernelTrackerMaxBytes) return false;
if (KernelTrackerMaxPending != other.KernelTrackerMaxPending) return false;
return Equals(_unknownFields, other._unknownFields); return Equals(_unknownFields, other._unknownFields);
} }


@@ -730,6 +822,10 @@ namespace Tensorflow {
if (UseUnifiedMemory != false) hash ^= UseUnifiedMemory.GetHashCode(); if (UseUnifiedMemory != false) hash ^= UseUnifiedMemory.GetHashCode();
if (NumDevToDevCopyStreams != 0) hash ^= NumDevToDevCopyStreams.GetHashCode(); if (NumDevToDevCopyStreams != 0) hash ^= NumDevToDevCopyStreams.GetHashCode();
if (CollectiveRingOrder.Length != 0) hash ^= CollectiveRingOrder.GetHashCode(); if (CollectiveRingOrder.Length != 0) hash ^= CollectiveRingOrder.GetHashCode();
if (TimestampedAllocator != false) hash ^= TimestampedAllocator.GetHashCode();
if (KernelTrackerMaxInterval != 0) hash ^= KernelTrackerMaxInterval.GetHashCode();
if (KernelTrackerMaxBytes != 0) hash ^= KernelTrackerMaxBytes.GetHashCode();
if (KernelTrackerMaxPending != 0) hash ^= KernelTrackerMaxPending.GetHashCode();
if (_unknownFields != null) { if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode(); hash ^= _unknownFields.GetHashCode();
} }
@@ -756,6 +852,22 @@ namespace Tensorflow {
output.WriteRawTag(34); output.WriteRawTag(34);
output.WriteString(CollectiveRingOrder); output.WriteString(CollectiveRingOrder);
} }
if (TimestampedAllocator != false) {
output.WriteRawTag(40);
output.WriteBool(TimestampedAllocator);
}
if (KernelTrackerMaxInterval != 0) {
output.WriteRawTag(56);
output.WriteInt32(KernelTrackerMaxInterval);
}
if (KernelTrackerMaxBytes != 0) {
output.WriteRawTag(64);
output.WriteInt32(KernelTrackerMaxBytes);
}
if (KernelTrackerMaxPending != 0) {
output.WriteRawTag(72);
output.WriteInt32(KernelTrackerMaxPending);
}
if (_unknownFields != null) { if (_unknownFields != null) {
_unknownFields.WriteTo(output); _unknownFields.WriteTo(output);
} }
@@ -774,6 +886,18 @@ namespace Tensorflow {
if (CollectiveRingOrder.Length != 0) { if (CollectiveRingOrder.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(CollectiveRingOrder); size += 1 + pb::CodedOutputStream.ComputeStringSize(CollectiveRingOrder);
} }
if (TimestampedAllocator != false) {
size += 1 + 1;
}
if (KernelTrackerMaxInterval != 0) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(KernelTrackerMaxInterval);
}
if (KernelTrackerMaxBytes != 0) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(KernelTrackerMaxBytes);
}
if (KernelTrackerMaxPending != 0) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(KernelTrackerMaxPending);
}
if (_unknownFields != null) { if (_unknownFields != null) {
size += _unknownFields.CalculateSize(); size += _unknownFields.CalculateSize();
} }
@@ -795,6 +919,18 @@ namespace Tensorflow {
if (other.CollectiveRingOrder.Length != 0) { if (other.CollectiveRingOrder.Length != 0) {
CollectiveRingOrder = other.CollectiveRingOrder; CollectiveRingOrder = other.CollectiveRingOrder;
} }
if (other.TimestampedAllocator != false) {
TimestampedAllocator = other.TimestampedAllocator;
}
if (other.KernelTrackerMaxInterval != 0) {
KernelTrackerMaxInterval = other.KernelTrackerMaxInterval;
}
if (other.KernelTrackerMaxBytes != 0) {
KernelTrackerMaxBytes = other.KernelTrackerMaxBytes;
}
if (other.KernelTrackerMaxPending != 0) {
KernelTrackerMaxPending = other.KernelTrackerMaxPending;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
} }


@@ -822,6 +958,22 @@ namespace Tensorflow {
CollectiveRingOrder = input.ReadString(); CollectiveRingOrder = input.ReadString();
break; break;
} }
case 40: {
TimestampedAllocator = input.ReadBool();
break;
}
case 56: {
KernelTrackerMaxInterval = input.ReadInt32();
break;
}
case 64: {
KernelTrackerMaxBytes = input.ReadInt32();
break;
}
case 72: {
KernelTrackerMaxPending = input.ReadInt32();
break;
}
} }
} }
} }
@@ -2189,6 +2341,7 @@ namespace Tensorflow {
/// inter_op_parallelism_threads available in each process. /// inter_op_parallelism_threads available in each process.
/// ///
/// 0 means the system picks an appropriate number. /// 0 means the system picks an appropriate number.
/// Negative means all operations are performed in caller's thread.
/// ///
/// Note that the first Session created in the process sets the /// Note that the first Session created in the process sets the
/// number of threads for all future sessions unless use_per_session_threads is /// number of threads for all future sessions unless use_per_session_threads is
@@ -2397,7 +2550,8 @@ namespace Tensorflow {
private bool isolateSessionState_; private bool isolateSessionState_;
/// <summary> /// <summary>
/// If true, any resources such as Variables used in the session will not be /// If true, any resources such as Variables used in the session will not be
/// shared with other sessions.
/// shared with other sessions. However, when clusterspec propagation is
/// enabled, this field is ignored and sessions are always isolated.
/// </summary> /// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool IsolateSessionState { public bool IsolateSessionState {
@@ -2787,6 +2941,11 @@ namespace Tensorflow {
executorType_ = other.executorType_; executorType_ = other.executorType_;
recvBufMaxChunk_ = other.recvBufMaxChunk_; recvBufMaxChunk_ = other.recvBufMaxChunk_;
useNumaAffinity_ = other.useNumaAffinity_; useNumaAffinity_ = other.useNumaAffinity_;
collectiveDeterministicSequentialExecution_ = other.collectiveDeterministicSequentialExecution_;
collectiveNccl_ = other.collectiveNccl_;
shareSessionStateInClusterspecPropagation_ = other.shareSessionStateInClusterspecPropagation_;
disableThreadSpinning_ = other.disableThreadSpinning_;
shareClusterDevicesInSession_ = other.shareClusterDevicesInSession_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
} }


@@ -2856,6 +3015,103 @@ namespace Tensorflow {
} }
} }


/// <summary>Field number for the "collective_deterministic_sequential_execution" field.</summary>
public const int CollectiveDeterministicSequentialExecutionFieldNumber = 6;
private bool collectiveDeterministicSequentialExecution_;
/// <summary>
/// If true, make collective op execution order sequential and deterministic
/// for potentially concurrent collective instances.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool CollectiveDeterministicSequentialExecution {
get { return collectiveDeterministicSequentialExecution_; }
set {
collectiveDeterministicSequentialExecution_ = value;
}
}

/// <summary>Field number for the "collective_nccl" field.</summary>
public const int CollectiveNcclFieldNumber = 7;
private bool collectiveNccl_;
/// <summary>
/// If true, use NCCL for CollectiveOps. This feature is highly
/// experimental.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool CollectiveNccl {
get { return collectiveNccl_; }
set {
collectiveNccl_ = value;
}
}

/// <summary>Field number for the "share_session_state_in_clusterspec_propagation" field.</summary>
public const int ShareSessionStateInClusterspecPropagationFieldNumber = 8;
private bool shareSessionStateInClusterspecPropagation_;
/// <summary>
/// In the following, session state means the value of a variable, elements
/// in a hash table, or any other resource, accessible by worker sessions
/// held by a TF server.
///
/// When ClusterSpec propagation is enabled, the value of
/// isolate_session_state is ignored when deciding whether to share session
/// states in a TF server (for backwards compatibility reasons).
/// - If share_session_state_in_clusterspec_propagation is true, the session
/// states are shared.
/// - If share_session_state_in_clusterspec_propagation is false, session
/// states are isolated.
///
/// When clusterspec propagation is not used, the value of
/// share_session_state_in_clusterspec_propagation is ignored when deciding
/// whether to share session states in a TF server.
/// - If isolate_session_state is true, session states are isolated.
/// - If isolate_session_state is false, session states are shared.
///
/// TODO(b/129330037): Add a single API that consistently treats
/// isolate_session_state and ClusterSpec propagation.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool ShareSessionStateInClusterspecPropagation {
get { return shareSessionStateInClusterspecPropagation_; }
set {
shareSessionStateInClusterspecPropagation_ = value;
}
}

/// <summary>Field number for the "disable_thread_spinning" field.</summary>
public const int DisableThreadSpinningFieldNumber = 9;
private bool disableThreadSpinning_;
/// <summary>
/// If using a direct session, disable spinning while waiting for work in
/// the thread pool. This may result in higher latency for completing ops,
/// but in the case where there is a lot of spinning may result in lower
/// CPU usage.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool DisableThreadSpinning {
get { return disableThreadSpinning_; }
set {
disableThreadSpinning_ = value;
}
}

/// <summary>Field number for the "share_cluster_devices_in_session" field.</summary>
public const int ShareClusterDevicesInSessionFieldNumber = 10;
private bool shareClusterDevicesInSession_;
/// <summary>
/// When true, WorkerSessions are created with device attributes from the
/// full cluster.
/// This is helpful when a worker wants to partition a graph
/// (for example during a PartitionedCallOp).
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool ShareClusterDevicesInSession {
get { return shareClusterDevicesInSession_; }
set {
shareClusterDevicesInSession_ = value;
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute] [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) { public override bool Equals(object other) {
return Equals(other as Experimental); return Equals(other as Experimental);
@@ -2873,6 +3129,11 @@ namespace Tensorflow {
if (ExecutorType != other.ExecutorType) return false; if (ExecutorType != other.ExecutorType) return false;
if (RecvBufMaxChunk != other.RecvBufMaxChunk) return false; if (RecvBufMaxChunk != other.RecvBufMaxChunk) return false;
if (UseNumaAffinity != other.UseNumaAffinity) return false; if (UseNumaAffinity != other.UseNumaAffinity) return false;
if (CollectiveDeterministicSequentialExecution != other.CollectiveDeterministicSequentialExecution) return false;
if (CollectiveNccl != other.CollectiveNccl) return false;
if (ShareSessionStateInClusterspecPropagation != other.ShareSessionStateInClusterspecPropagation) return false;
if (DisableThreadSpinning != other.DisableThreadSpinning) return false;
if (ShareClusterDevicesInSession != other.ShareClusterDevicesInSession) return false;
return Equals(_unknownFields, other._unknownFields); return Equals(_unknownFields, other._unknownFields);
} }


@@ -2883,6 +3144,11 @@ namespace Tensorflow {
if (ExecutorType.Length != 0) hash ^= ExecutorType.GetHashCode(); if (ExecutorType.Length != 0) hash ^= ExecutorType.GetHashCode();
if (RecvBufMaxChunk != 0) hash ^= RecvBufMaxChunk.GetHashCode(); if (RecvBufMaxChunk != 0) hash ^= RecvBufMaxChunk.GetHashCode();
if (UseNumaAffinity != false) hash ^= UseNumaAffinity.GetHashCode(); if (UseNumaAffinity != false) hash ^= UseNumaAffinity.GetHashCode();
if (CollectiveDeterministicSequentialExecution != false) hash ^= CollectiveDeterministicSequentialExecution.GetHashCode();
if (CollectiveNccl != false) hash ^= CollectiveNccl.GetHashCode();
if (ShareSessionStateInClusterspecPropagation != false) hash ^= ShareSessionStateInClusterspecPropagation.GetHashCode();
if (DisableThreadSpinning != false) hash ^= DisableThreadSpinning.GetHashCode();
if (ShareClusterDevicesInSession != false) hash ^= ShareClusterDevicesInSession.GetHashCode();
if (_unknownFields != null) { if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode(); hash ^= _unknownFields.GetHashCode();
} }
@@ -2912,6 +3178,26 @@ namespace Tensorflow {
output.WriteRawTag(40); output.WriteRawTag(40);
output.WriteBool(UseNumaAffinity); output.WriteBool(UseNumaAffinity);
} }
if (CollectiveDeterministicSequentialExecution != false) {
output.WriteRawTag(48);
output.WriteBool(CollectiveDeterministicSequentialExecution);
}
if (CollectiveNccl != false) {
output.WriteRawTag(56);
output.WriteBool(CollectiveNccl);
}
if (ShareSessionStateInClusterspecPropagation != false) {
output.WriteRawTag(64);
output.WriteBool(ShareSessionStateInClusterspecPropagation);
}
if (DisableThreadSpinning != false) {
output.WriteRawTag(72);
output.WriteBool(DisableThreadSpinning);
}
if (ShareClusterDevicesInSession != false) {
output.WriteRawTag(80);
output.WriteBool(ShareClusterDevicesInSession);
}
if (_unknownFields != null) { if (_unknownFields != null) {
_unknownFields.WriteTo(output); _unknownFields.WriteTo(output);
} }
@@ -2932,6 +3218,21 @@ namespace Tensorflow {
if (UseNumaAffinity != false) { if (UseNumaAffinity != false) {
size += 1 + 1; size += 1 + 1;
} }
if (CollectiveDeterministicSequentialExecution != false) {
size += 1 + 1;
}
if (CollectiveNccl != false) {
size += 1 + 1;
}
if (ShareSessionStateInClusterspecPropagation != false) {
size += 1 + 1;
}
if (DisableThreadSpinning != false) {
size += 1 + 1;
}
if (ShareClusterDevicesInSession != false) {
size += 1 + 1;
}
if (_unknownFields != null) { if (_unknownFields != null) {
size += _unknownFields.CalculateSize(); size += _unknownFields.CalculateSize();
} }
@@ -2955,6 +3256,21 @@ namespace Tensorflow {
if (other.UseNumaAffinity != false) { if (other.UseNumaAffinity != false) {
UseNumaAffinity = other.UseNumaAffinity; UseNumaAffinity = other.UseNumaAffinity;
} }
if (other.CollectiveDeterministicSequentialExecution != false) {
CollectiveDeterministicSequentialExecution = other.CollectiveDeterministicSequentialExecution;
}
if (other.CollectiveNccl != false) {
CollectiveNccl = other.CollectiveNccl;
}
if (other.ShareSessionStateInClusterspecPropagation != false) {
ShareSessionStateInClusterspecPropagation = other.ShareSessionStateInClusterspecPropagation;
}
if (other.DisableThreadSpinning != false) {
DisableThreadSpinning = other.DisableThreadSpinning;
}
if (other.ShareClusterDevicesInSession != false) {
ShareClusterDevicesInSession = other.ShareClusterDevicesInSession;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
} }


@@ -2982,6 +3298,26 @@ namespace Tensorflow {
UseNumaAffinity = input.ReadBool(); UseNumaAffinity = input.ReadBool();
break; break;
} }
case 48: {
CollectiveDeterministicSequentialExecution = input.ReadBool();
break;
}
case 56: {
CollectiveNccl = input.ReadBool();
break;
}
case 64: {
ShareSessionStateInClusterspecPropagation = input.ReadBool();
break;
}
case 72: {
DisableThreadSpinning = input.ReadBool();
break;
}
case 80: {
ShareClusterDevicesInSession = input.ReadBool();
break;
}
} }
} }
} }
@@ -3553,6 +3889,7 @@ namespace Tensorflow {
stepStats_ = other.stepStats_ != null ? other.stepStats_.Clone() : null; stepStats_ = other.stepStats_ != null ? other.stepStats_.Clone() : null;
costGraph_ = other.costGraph_ != null ? other.costGraph_.Clone() : null; costGraph_ = other.costGraph_ != null ? other.costGraph_.Clone() : null;
partitionGraphs_ = other.partitionGraphs_.Clone(); partitionGraphs_ = other.partitionGraphs_.Clone();
functionGraphs_ = other.functionGraphs_.Clone();
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
} }


@@ -3604,6 +3941,28 @@ namespace Tensorflow {
get { return partitionGraphs_; } get { return partitionGraphs_; }
} }


/// <summary>Field number for the "function_graphs" field.</summary>
public const int FunctionGraphsFieldNumber = 4;
private static readonly pb::FieldCodec<global::Tensorflow.RunMetadata.Types.FunctionGraphs> _repeated_functionGraphs_codec
= pb::FieldCodec.ForMessage(34, global::Tensorflow.RunMetadata.Types.FunctionGraphs.Parser);
private readonly pbc::RepeatedField<global::Tensorflow.RunMetadata.Types.FunctionGraphs> functionGraphs_ = new pbc::RepeatedField<global::Tensorflow.RunMetadata.Types.FunctionGraphs>();
/// <summary>
/// This is only populated for graphs that are run as functions in TensorFlow
/// V2. There will be an entry below for each function that is traced.
/// The main use cases of the post_optimization_graph and the partition_graphs
/// is to give the caller insight into the graphs that were actually run by the
/// runtime. Additional information (such as those in step_stats) will match
/// these graphs.
/// We also include the pre_optimization_graph since it is usually easier to
/// read, and is helpful in situations where the caller wants to get a high
/// level idea of what the built graph looks like (since the various graph
/// optimization passes might change the structure of the graph significantly).
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<global::Tensorflow.RunMetadata.Types.FunctionGraphs> FunctionGraphs {
get { return functionGraphs_; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute] [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) { public override bool Equals(object other) {
return Equals(other as RunMetadata); return Equals(other as RunMetadata);
@@ -3620,6 +3979,7 @@ namespace Tensorflow {
if (!object.Equals(StepStats, other.StepStats)) return false; if (!object.Equals(StepStats, other.StepStats)) return false;
if (!object.Equals(CostGraph, other.CostGraph)) return false; if (!object.Equals(CostGraph, other.CostGraph)) return false;
if(!partitionGraphs_.Equals(other.partitionGraphs_)) return false; if(!partitionGraphs_.Equals(other.partitionGraphs_)) return false;
if(!functionGraphs_.Equals(other.functionGraphs_)) return false;
return Equals(_unknownFields, other._unknownFields); return Equals(_unknownFields, other._unknownFields);
} }


@@ -3629,6 +3989,7 @@ namespace Tensorflow {
if (stepStats_ != null) hash ^= StepStats.GetHashCode(); if (stepStats_ != null) hash ^= StepStats.GetHashCode();
if (costGraph_ != null) hash ^= CostGraph.GetHashCode(); if (costGraph_ != null) hash ^= CostGraph.GetHashCode();
hash ^= partitionGraphs_.GetHashCode(); hash ^= partitionGraphs_.GetHashCode();
hash ^= functionGraphs_.GetHashCode();
if (_unknownFields != null) { if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode(); hash ^= _unknownFields.GetHashCode();
} }
@@ -3651,6 +4012,7 @@ namespace Tensorflow {
output.WriteMessage(CostGraph); output.WriteMessage(CostGraph);
} }
partitionGraphs_.WriteTo(output, _repeated_partitionGraphs_codec); partitionGraphs_.WriteTo(output, _repeated_partitionGraphs_codec);
functionGraphs_.WriteTo(output, _repeated_functionGraphs_codec);
if (_unknownFields != null) { if (_unknownFields != null) {
_unknownFields.WriteTo(output); _unknownFields.WriteTo(output);
} }
@@ -3666,6 +4028,7 @@ namespace Tensorflow {
size += 1 + pb::CodedOutputStream.ComputeMessageSize(CostGraph); size += 1 + pb::CodedOutputStream.ComputeMessageSize(CostGraph);
} }
size += partitionGraphs_.CalculateSize(_repeated_partitionGraphs_codec); size += partitionGraphs_.CalculateSize(_repeated_partitionGraphs_codec);
size += functionGraphs_.CalculateSize(_repeated_functionGraphs_codec);
if (_unknownFields != null) { if (_unknownFields != null) {
size += _unknownFields.CalculateSize(); size += _unknownFields.CalculateSize();
} }
@@ -3690,6 +4053,7 @@ namespace Tensorflow {
CostGraph.MergeFrom(other.CostGraph); CostGraph.MergeFrom(other.CostGraph);
} }
partitionGraphs_.Add(other.partitionGraphs_); partitionGraphs_.Add(other.partitionGraphs_);
functionGraphs_.Add(other.functionGraphs_);
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
} }


@@ -3719,9 +4083,212 @@ namespace Tensorflow {
partitionGraphs_.AddEntriesFrom(input, _repeated_partitionGraphs_codec); partitionGraphs_.AddEntriesFrom(input, _repeated_partitionGraphs_codec);
break; break;
} }
case 34: {
functionGraphs_.AddEntriesFrom(input, _repeated_functionGraphs_codec);
break;
}
}
}
}

#region Nested types
/// <summary>Container for nested types declared in the RunMetadata message type.</summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static partial class Types {
public sealed partial class FunctionGraphs : pb::IMessage<FunctionGraphs> {
private static readonly pb::MessageParser<FunctionGraphs> _parser = new pb::MessageParser<FunctionGraphs>(() => new FunctionGraphs());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pb::MessageParser<FunctionGraphs> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.RunMetadata.Descriptor.NestedTypes[0]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public FunctionGraphs() {
OnConstruction();
}

partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public FunctionGraphs(FunctionGraphs other) : this() {
partitionGraphs_ = other.partitionGraphs_.Clone();
preOptimizationGraph_ = other.preOptimizationGraph_ != null ? other.preOptimizationGraph_.Clone() : null;
postOptimizationGraph_ = other.postOptimizationGraph_ != null ? other.postOptimizationGraph_.Clone() : null;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public FunctionGraphs Clone() {
return new FunctionGraphs(this);
}

/// <summary>Field number for the "partition_graphs" field.</summary>
public const int PartitionGraphsFieldNumber = 1;
private static readonly pb::FieldCodec<global::Tensorflow.GraphDef> _repeated_partitionGraphs_codec
= pb::FieldCodec.ForMessage(10, global::Tensorflow.GraphDef.Parser);
private readonly pbc::RepeatedField<global::Tensorflow.GraphDef> partitionGraphs_ = new pbc::RepeatedField<global::Tensorflow.GraphDef>();
/// <summary>
/// TODO(nareshmodi): Include some sort of function/cache-key identifier?
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<global::Tensorflow.GraphDef> PartitionGraphs {
get { return partitionGraphs_; }
}

/// <summary>Field number for the "pre_optimization_graph" field.</summary>
public const int PreOptimizationGraphFieldNumber = 2;
private global::Tensorflow.GraphDef preOptimizationGraph_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public global::Tensorflow.GraphDef PreOptimizationGraph {
get { return preOptimizationGraph_; }
set {
preOptimizationGraph_ = value;
}
}

/// <summary>Field number for the "post_optimization_graph" field.</summary>
public const int PostOptimizationGraphFieldNumber = 3;
private global::Tensorflow.GraphDef postOptimizationGraph_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public global::Tensorflow.GraphDef PostOptimizationGraph {
get { return postOptimizationGraph_; }
set {
postOptimizationGraph_ = value;
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as FunctionGraphs);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool Equals(FunctionGraphs other) {
if (ReferenceEquals(other, null)) {
return false;
}
if (ReferenceEquals(other, this)) {
return true;
}
if(!partitionGraphs_.Equals(other.partitionGraphs_)) return false;
if (!object.Equals(PreOptimizationGraph, other.PreOptimizationGraph)) return false;
if (!object.Equals(PostOptimizationGraph, other.PostOptimizationGraph)) return false;
return Equals(_unknownFields, other._unknownFields);
} }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
hash ^= partitionGraphs_.GetHashCode();
if (preOptimizationGraph_ != null) hash ^= PreOptimizationGraph.GetHashCode();
if (postOptimizationGraph_ != null) hash ^= PostOptimizationGraph.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
return hash;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
partitionGraphs_.WriteTo(output, _repeated_partitionGraphs_codec);
if (preOptimizationGraph_ != null) {
output.WriteRawTag(18);
output.WriteMessage(PreOptimizationGraph);
}
if (postOptimizationGraph_ != null) {
output.WriteRawTag(26);
output.WriteMessage(PostOptimizationGraph);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int CalculateSize() {
int size = 0;
size += partitionGraphs_.CalculateSize(_repeated_partitionGraphs_codec);
if (preOptimizationGraph_ != null) {
size += 1 + pb::CodedOutputStream.ComputeMessageSize(PreOptimizationGraph);
}
if (postOptimizationGraph_ != null) {
size += 1 + pb::CodedOutputStream.ComputeMessageSize(PostOptimizationGraph);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
return size;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(FunctionGraphs other) {
if (other == null) {
return;
}
partitionGraphs_.Add(other.partitionGraphs_);
if (other.preOptimizationGraph_ != null) {
if (preOptimizationGraph_ == null) {
preOptimizationGraph_ = new global::Tensorflow.GraphDef();
}
PreOptimizationGraph.MergeFrom(other.PreOptimizationGraph);
}
if (other.postOptimizationGraph_ != null) {
if (postOptimizationGraph_ == null) {
postOptimizationGraph_ = new global::Tensorflow.GraphDef();
}
PostOptimizationGraph.MergeFrom(other.PostOptimizationGraph);
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(pb::CodedInputStream input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
break;
case 10: {
partitionGraphs_.AddEntriesFrom(input, _repeated_partitionGraphs_codec);
break;
}
case 18: {
if (preOptimizationGraph_ == null) {
preOptimizationGraph_ = new global::Tensorflow.GraphDef();
}
input.ReadMessage(preOptimizationGraph_);
break;
}
case 26: {
if (postOptimizationGraph_ == null) {
postOptimizationGraph_ = new global::Tensorflow.GraphDef();
}
input.ReadMessage(postOptimizationGraph_);
break;
}
}
}
}

} }

} }
#endregion


} }




+ 8
- 7
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -36,19 +36,20 @@ namespace Tensorflow
protected byte[] _target; protected byte[] _target;
public Graph graph => _graph; public Graph graph => _graph;


public BaseSession(string target = "", Graph g = null, SessionOptions opts = null, Status status = null)
public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null)
{ {
_graph = g ?? ops.get_default_graph(); _graph = g ?? ops.get_default_graph();
_graph.as_default(); _graph.as_default();
_target = Encoding.UTF8.GetBytes(target); _target = Encoding.UTF8.GetBytes(target);


SessionOptions lopts = opts ?? new SessionOptions();

lock (Locks.ProcessWide)
using (var opts = new SessionOptions(target, config))
{ {
status = status ?? new Status();
_handle = c_api.TF_NewSession(_graph, opts ?? lopts, status);
status.Check(true);
lock (Locks.ProcessWide)
{
status = status ?? new Status();
_handle = c_api.TF_NewSession(_graph, opts, status);
status.Check(true);
}
} }
} }




+ 1
- 1
src/TensorFlowNET.Core/Sessions/Session.cs View File

@@ -32,7 +32,7 @@ namespace Tensorflow
_handle = handle; _handle = handle;
} }


public Session(Graph g, SessionOptions opts = null, Status s = null) : base("", g, opts, s)
public Session(Graph g, ConfigProto config = null, Status s = null) : base("", g, config, s)
{ } { }


public Session as_default() public Session as_default()


+ 8
- 5
src/TensorFlowNET.Core/Sessions/SessionOptions.cs View File

@@ -20,11 +20,14 @@ using System.Runtime.InteropServices;


namespace Tensorflow namespace Tensorflow
{ {
public class SessionOptions : DisposableObject
internal class SessionOptions : DisposableObject
{ {
public SessionOptions()
public SessionOptions(string target = "", ConfigProto config = null)
{ {
_handle = c_api.TF_NewSessionOptions(); _handle = c_api.TF_NewSessionOptions();
c_api.TF_SetTarget(_handle, target);
if (config != null)
SetConfig(config);
} }


public SessionOptions(IntPtr handle) public SessionOptions(IntPtr handle)
@@ -35,10 +38,10 @@ namespace Tensorflow
protected override void DisposeUnmanagedResources(IntPtr handle) protected override void DisposeUnmanagedResources(IntPtr handle)
=> c_api.TF_DeleteSessionOptions(handle); => c_api.TF_DeleteSessionOptions(handle);


public void SetConfig(ConfigProto config)
private void SetConfig(ConfigProto config)
{ {
var bytes = config.ToByteArray(); //TODO! we can use WriteTo
var proto = Marshal.AllocHGlobal(bytes.Length); //TODO! potential memory leak
var bytes = config.ToByteArray();
var proto = Marshal.AllocHGlobal(bytes.Length);
Marshal.Copy(bytes, 0, proto, bytes.Length); Marshal.Copy(bytes, 0, proto, bytes.Length);


using (var status = new Status()) using (var status = new Status())


+ 0
- 10
src/TensorFlowNET.Core/Sessions/TF_SessionOptions.cs View File

@@ -1,10 +0,0 @@
using System.Runtime.InteropServices;

namespace Tensorflow
{
[StructLayout(LayoutKind.Sequential)]
public struct TF_SessionOptions
{
public SessionOptions options;
}
}

+ 2
- 2
src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs View File

@@ -28,9 +28,9 @@ namespace Tensorflow
{ {
private Func<List<NDArray>, object> _contraction_fn; private Func<List<NDArray>, object> _contraction_fn;


public _ElementFetchMapper(object[] fetches, Func<List<NDArray>, object> contraction_fn)
public _ElementFetchMapper(object[] fetches, Func<List<NDArray>, object> contraction_fn, Graph graph = null)
{ {
var g = ops.get_default_graph();
var g = graph ?? ops.get_default_graph();


foreach(var fetch in fetches) foreach(var fetch in fetches)
{ {


+ 1
- 1
src/TensorFlowNET.Core/Sessions/_FetchHandler.cs View File

@@ -34,7 +34,7 @@ namespace Tensorflow


public _FetchHandler(Graph graph, object fetches, Dictionary<object, object> feeds = null, Action feed_handles = null) public _FetchHandler(Graph graph, object fetches, Dictionary<object, object> feeds = null, Action feed_handles = null)
{ {
_fetch_mapper = _FetchMapper.for_fetch(fetches);
_fetch_mapper = _FetchMapper.for_fetch(fetches, graph: graph);
foreach(var fetch in _fetch_mapper.unique_fetches()) foreach(var fetch in _fetch_mapper.unique_fetches())
{ {
switch (fetch) switch (fetch)


+ 2
- 2
src/TensorFlowNET.Core/Sessions/_FetchMapper.cs View File

@@ -25,7 +25,7 @@ namespace Tensorflow
{ {
protected List<ITensorOrOperation> _unique_fetches = new List<ITensorOrOperation>(); protected List<ITensorOrOperation> _unique_fetches = new List<ITensorOrOperation>();
protected List<int[]> _value_indices = new List<int[]>(); protected List<int[]> _value_indices = new List<int[]>();
public static _FetchMapper for_fetch(object fetch)
public static _FetchMapper for_fetch(object fetch, Graph graph = null)
{ {
var fetches = fetch.GetType().IsArray ? (object[])fetch : new object[] { fetch }; var fetches = fetch.GetType().IsArray ? (object[])fetch : new object[] { fetch };


@@ -34,7 +34,7 @@ namespace Tensorflow
if (fetch.GetType().IsArray) if (fetch.GetType().IsArray)
return new _ListFetchMapper(fetches); return new _ListFetchMapper(fetches);
else else
return new _ElementFetchMapper(fetches, (List<NDArray> fetched_vals) => fetched_vals[0]);
return new _ElementFetchMapper(fetches, (List<NDArray> fetched_vals) => fetched_vals[0], graph: graph);
} }


public virtual NDArray[] build_results(List<NDArray> values) public virtual NDArray[] build_results(List<NDArray> values)


+ 4
- 1
src/TensorFlowNET.Core/Sessions/c_api.session.cs View File

@@ -116,6 +116,9 @@ namespace Tensorflow
/// <param name="proto_len">size_t</param> /// <param name="proto_len">size_t</param>
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern unsafe void TF_SetConfig(IntPtr options, IntPtr proto, ulong proto_len, IntPtr status);
public static extern void TF_SetConfig(IntPtr options, IntPtr proto, ulong proto_len, IntPtr status);

[DllImport(TensorFlowLibName)]
public static extern void TF_SetTarget(IntPtr options, string target);
} }
} }

+ 11
- 22
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -5,7 +5,7 @@
<AssemblyName>TensorFlow.NET</AssemblyName> <AssemblyName>TensorFlow.NET</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace> <RootNamespace>Tensorflow</RootNamespace>
<TargetTensorFlow>1.14.0</TargetTensorFlow> <TargetTensorFlow>1.14.0</TargetTensorFlow>
<Version>0.11.8</Version>
<Version>0.12.0</Version>
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
<Company>SciSharp STACK</Company> <Company>SciSharp STACK</Company>
<GeneratePackageOnBuild>true</GeneratePackageOnBuild> <GeneratePackageOnBuild>true</GeneratePackageOnBuild>
@@ -16,25 +16,15 @@
<PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&amp;v=4</PackageIconUrl> <PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&amp;v=4</PackageIconUrl>
<PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags> <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags>
<Description>Google's TensorFlow full binding in .NET Standard. <Description>Google's TensorFlow full binding in .NET Standard.
Docs: https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.11.8.0</AssemblyVersion>
<PackageReleaseNotes>Changes since v0.10.0:
1. Upgrade NumSharp to v0.20.3.
2. Add DisposableObject class to manage object lifetime.
3. Add tf.no_op, tf.nn.in_top_k, tf.GraphKeys and tf.trainable_variables.
4. Change tensorflow to non-static class in order to execute some initialization process.
5. Overload session.run(), make syntax simpler.
6. Add Local Response Normalization.
7. Add tf.image related APIs.
8. Add tf.random_normal, tf.constant, tf.pad, tf.shape, tf.image.resize_nearest_neighbor.
9. MultiThread is safe.
10. Support n-dim indexing for tensor.
11. Add RegisterNoGradients
12. Add CumsumGrad, BroadcastToGrad.
13. Return VariableV1 instead of RefVariable.
14. Add Tensor overload to GradientDescentOptimizer.</PackageReleaseNotes>
Building, training and infering deep learning models.
https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.12.0.0</AssemblyVersion>
<PackageReleaseNotes>Changes since v0.11.0:
1: Add ICanBeFlattened for nest.flatten2.
2: Complete the WhileContext.
3: Add tf.nn.rnn_cell.BasicRNNCell and tf.nn.dynamic_rnn.</PackageReleaseNotes>
<LangVersion>7.3</LangVersion> <LangVersion>7.3</LangVersion>
<FileVersion>0.11.8.0</FileVersion>
<FileVersion>0.12.0.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile> <PackageLicenseFile>LICENSE</PackageLicenseFile>
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
<SignAssembly>true</SignAssembly> <SignAssembly>true</SignAssembly>
@@ -43,7 +33,7 @@ Docs: https://tensorflownet.readthedocs.io</Description>


<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<DefineConstants>TRACE;DEBUG;SERIALIZABLE</DefineConstants>
<DefineConstants>TRACE;DEBUG;SERIALIZABLE_</DefineConstants>
</PropertyGroup> </PropertyGroup>


<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'"> <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'">
@@ -65,8 +55,7 @@ Docs: https://tensorflownet.readthedocs.io</Description>
</ItemGroup> </ItemGroup>


<ItemGroup> <ItemGroup>
<PackageReference Include="Google.Protobuf" Version="3.5.1" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.2" />
<PackageReference Include="Google.Protobuf" Version="3.10.0" />
<PackageReference Include="NumSharp" Version="0.20.4" /> <PackageReference Include="NumSharp" Version="0.20.4" />
</ItemGroup> </ItemGroup>




+ 2
- 0
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -25,7 +25,9 @@ using System.Text;
using NumSharp.Backends; using NumSharp.Backends;
using NumSharp.Backends.Unmanaged; using NumSharp.Backends.Unmanaged;
using static Tensorflow.c_api; using static Tensorflow.c_api;
#if SERIALIZABLE
using Newtonsoft.Json; using Newtonsoft.Json;
#endif


namespace Tensorflow namespace Tensorflow
{ {


+ 15
- 0
src/TensorFlowNET.Core/Tensors/Tensor.Flatten.cs View File

@@ -0,0 +1,15 @@
using NumSharp;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public partial class Tensor
{
public object[] Flatten()
{
return new Tensor[] { this };
}
}
}

+ 15
- 0
src/TensorFlowNET.Core/Tensors/Tensor.Pack.cs View File

@@ -0,0 +1,15 @@
using NumSharp;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public partial class Tensor
{
public Tensor Pack(object[] sequences)
{
return sequences[0] as Tensor;
}
}
}

+ 10
- 3
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -28,7 +28,9 @@ using NumSharp.Backends;
using NumSharp.Backends.Unmanaged; using NumSharp.Backends.Unmanaged;
using NumSharp.Utilities; using NumSharp.Utilities;
using Tensorflow.Framework; using Tensorflow.Framework;
#if SERIALIZABLE
using Newtonsoft.Json; using Newtonsoft.Json;
#endif


namespace Tensorflow namespace Tensorflow
{ {
@@ -37,7 +39,12 @@ namespace Tensorflow
/// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes.
/// </summary> /// </summary>
[SuppressMessage("ReSharper", "ConvertToAutoProperty")] [SuppressMessage("ReSharper", "ConvertToAutoProperty")]
public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike
public partial class Tensor : DisposableObject,
ITensorOrOperation,
_TensorLike,
ITensorOrTensorArray,
IPackable<Tensor>,
ICanBeFlattened
{ {
private readonly int _id; private readonly int _id;
private readonly Operation _op; private readonly Operation _op;
@@ -95,7 +102,7 @@ namespace Tensorflow
[JsonIgnore] [JsonIgnore]
#endif #endif
public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize; public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize;
private IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle);
public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle);
public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out);
#if SERIALIZABLE #if SERIALIZABLE
[JsonIgnore] [JsonIgnore]
@@ -176,7 +183,7 @@ namespace Tensorflow
/// </summary> /// </summary>
public void set_shape(TensorShape shape) public void set_shape(TensorShape shape)
{ {
this.shape = shape.rank > 0 ? shape.dims : null;
this.shape = shape.rank >= 0 ? shape.dims : null;
} }


/// <summary> /// <summary>


src/TensorFlowNET.Core/Operations/TensorArray.cs → src/TensorFlowNET.Core/Tensors/TensorArray.cs View File

@@ -17,8 +17,9 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using Tensorflow.Operations;


namespace Tensorflow.Operations
namespace Tensorflow
{ {
/// <summary> /// <summary>
/// TensorArray is designed to hide an underlying implementation object /// TensorArray is designed to hide an underlying implementation object
@@ -29,9 +30,9 @@ namespace Tensorflow.Operations
/// `while_loop` and `map_fn`. It supports gradient back-propagation via special /// `while_loop` and `map_fn`. It supports gradient back-propagation via special
/// "flow" control flow dependencies. /// "flow" control flow dependencies.
/// </summary> /// </summary>
public class TensorArray
public class TensorArray : ITensorOrTensorArray
{ {
_GraphTensorArray _implementation;
internal _GraphTensorArray _implementation;


public TF_DataType dtype => _implementation._dtype; public TF_DataType dtype => _implementation._dtype;
public Tensor handle => _implementation._handle; public Tensor handle => _implementation._handle;
@@ -39,7 +40,7 @@ namespace Tensorflow.Operations


public TensorArray(TF_DataType dtype, Tensor size = default, bool? clear_after_read = null, bool? dynamic_size = null, public TensorArray(TF_DataType dtype, Tensor size = default, bool? clear_after_read = null, bool? dynamic_size = null,
string tensor_array_name = null, Tensor handle = null, Tensor flow = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
bool infer_shape = true, TensorShape[] element_shape = null,
bool infer_shape = true, TensorShape element_shape = null,
bool colocate_with_first_write_call = true, string name = null) bool colocate_with_first_write_call = true, string name = null)
{ {
_implementation = new _GraphTensorArray(dtype, _implementation = new _GraphTensorArray(dtype,
@@ -57,5 +58,14 @@ namespace Tensorflow.Operations


public TensorArray unstack(Tensor value, string name = null) public TensorArray unstack(Tensor value, string name = null)
=> _implementation.unstack(value, name: name); => _implementation.unstack(value, name: name);

public Tensor read(Tensor index, string name = null)
=> _implementation.read(index, name: name);

public TensorArray write(Tensor index, Tensor value, string name = null)
=> _implementation.write(index, value, name: name);

public Tensor stack(string name = null)
=> _implementation.stack(name: name);
} }
} }

+ 7
- 2
src/TensorFlowNET.Core/Tensors/TensorShape.cs View File

@@ -1,10 +1,12 @@
using Newtonsoft.Json;
using NumSharp;
using NumSharp;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
using System.Linq; using System.Linq;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
#if SERIALIZABLE
using Newtonsoft.Json;
#endif
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
@@ -123,6 +125,9 @@ namespace Tensorflow
{ {
get get
{ {
if (!slice.Stop.HasValue)
slice.Stop = dims.Length - slice.Start + 1;

if (slice.Start.HasValue == false || slice.Length.HasValue == false) if (slice.Start.HasValue == false || slice.Length.HasValue == false)
throw new ArgumentException("Slice must has Start and Length."); throw new ArgumentException("Slice must has Start and Length.");




+ 3
- 0
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -33,6 +33,9 @@ namespace Tensorflow
public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32? public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32?
public static TF_DataType float16 = TF_DataType.TF_HALF; public static TF_DataType float16 = TF_DataType.TF_HALF;
public static TF_DataType float64 = TF_DataType.TF_DOUBLE; public static TF_DataType float64 = TF_DataType.TF_DOUBLE;
public static TF_DataType complex = TF_DataType.TF_COMPLEX;
public static TF_DataType complex64 = TF_DataType.TF_COMPLEX64;
public static TF_DataType complex128 = TF_DataType.TF_COMPLEX128;
public static TF_DataType variant = TF_DataType.TF_VARIANT; public static TF_DataType variant = TF_DataType.TF_VARIANT;
public static TF_DataType resource = TF_DataType.TF_RESOURCE; public static TF_DataType resource = TF_DataType.TF_RESOURCE;




+ 5
- 0
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -335,5 +335,10 @@ namespace Tensorflow


return shape; return shape;
} }

public static Tensor shape_tensor(int[] shape)
{
return ops.convert_to_tensor(shape, dtype: TF_DataType.TF_INT32, name: "shape");
}
} }
} }

+ 9
- 2
src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs View File

@@ -35,22 +35,29 @@ namespace Tensorflow.Train
/// for changing these values across different invocations of optimizer /// for changing these values across different invocations of optimizer
/// functions. /// functions.
/// </remarks> /// </remarks>
private bool _useTensor;
public GradientDescentOptimizer(float learning_rate, bool use_locking = false, string name = "GradientDescent") public GradientDescentOptimizer(float learning_rate, bool use_locking = false, string name = "GradientDescent")
: base(learning_rate, use_locking, name) : base(learning_rate, use_locking, name)
{ {
_lr = learning_rate; _lr = learning_rate;
_useTensor = false;
} }
public GradientDescentOptimizer(Tensor learning_rate, bool use_locking = false, string name = "GradientDescent") public GradientDescentOptimizer(Tensor learning_rate, bool use_locking = false, string name = "GradientDescent")
: base(learning_rate, use_locking, name) : base(learning_rate, use_locking, name)
{ {
_lr_t = learning_rate; _lr_t = learning_rate;
_useTensor = true;
} }


public override void _prepare() public override void _prepare()
{ {
var lr = _call_if_callable(_lr);
_lr_t = ops.convert_to_tensor(lr, name: "learning_rate");
if(!_useTensor)
{
var lr = _call_if_callable(_lr);
_lr_t = ops.convert_to_tensor(lr, name: "learning_rate");
}
} }
} }
} }

+ 20
- 2
src/TensorFlowNET.Core/Util/nest.py.cs View File

@@ -19,6 +19,7 @@ using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using NumSharp; using NumSharp;
using Tensorflow.Operations;
namespace Tensorflow.Util namespace Tensorflow.Util
{ {
@@ -221,9 +222,14 @@ namespace Tensorflow.Util
return list; return list;
} }
public static object[] flatten2(ICanBeFlattened structure)
=> structure.Flatten();
public static T[] flatten2<T>(T[] structure)
=> structure;
private static void _flatten_recursive<T>(T obj, List<T> list) private static void _flatten_recursive<T>(T obj, List<T> list)
{ {
switch(obj) switch(obj)
{ {
case IDictionary dict: case IDictionary dict:
@@ -395,6 +401,10 @@ namespace Tensorflow.Util
private static int len(IEnumerable<object> x) => x.Count(); private static int len(IEnumerable<object> x) => x.Count();
public static T pack_sequence_as2<T>(T structure, object[] flat_sequence, bool expand_composites = false)
where T : IPackable<T>
=> structure.Pack(flat_sequence);
/// <summary> /// <summary>
/// Returns a given flattened sequence packed into a given structure. /// Returns a given flattened sequence packed into a given structure.
/// If `structure` is a scalar, `flat_sequence` must be a single-element list; /// If `structure` is a scalar, `flat_sequence` must be a single-element list;
@@ -418,7 +428,7 @@ namespace Tensorflow.Util
/// <returns> `flat_sequence` converted to have the same recursive structure as /// <returns> `flat_sequence` converted to have the same recursive structure as
/// `structure`. /// `structure`.
/// </returns> /// </returns>
public static object pack_sequence_as(object structure, IEnumerable<object> flat_sequence)
public static object pack_sequence_as(object structure, IEnumerable<object> flat_sequence, bool expand_composites = false)
{ {
List<object> flat = null; List<object> flat = null;
if (flat_sequence is List<object>) if (flat_sequence is List<object>)
@@ -516,6 +526,14 @@ namespace Tensorflow.Util
return pack_sequence_as(structure, mapped_flat_structure) as Tensor; return pack_sequence_as(structure, mapped_flat_structure) as Tensor;
} }
public static Tensor map_structure2<T>(Func<T, Tensor> func, T structure)
{
var flat_structure = flatten(structure);
var mapped_flat_structure = flat_structure.Select(func).ToList();
return pack_sequence_as(structure, mapped_flat_structure) as Tensor;
}
/// <summary> /// <summary>
/// Same as map_structure, but with only one structure (no combining of multiple structures) /// Same as map_structure, but with only one structure (no combining of multiple structures)
/// </summary> /// </summary>


+ 51
- 48
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -133,66 +133,69 @@ namespace Tensorflow
if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES)) if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES))
collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES); collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES);


ops.init_scope();
var values = init_from_fn ? new object[0] : new object[] { initial_value };
tf_with(ops.name_scope(name, "Variable", values), scope =>
tf_with(ops.init_scope2(), delegate
{ {
name = scope;
if (init_from_fn)
var values = init_from_fn ? new object[0] : new object[] { initial_value };
tf_with(ops.name_scope(name, "Variable", values), scope =>
{ {
// Use attr_scope and device(None) to simulate the behavior of
// colocate_with when the variable we want to colocate with doesn't
// yet exist.
string true_name = ops.name_from_scope_name(name);
var attr = new AttrValue
name = scope;

if (init_from_fn)
{ {
List = new AttrValue.Types.ListValue()
};
attr.List.S.Add(ByteString.CopyFromUtf8($"loc:{true_name}"));
tf_with(ops.name_scope("Initializer"), scope2 =>
// Use attr_scope and device(None) to simulate the behavior of
// colocate_with when the variable we want to colocate with doesn't
// yet exist.
string true_name = ops.name_from_scope_name(name);
var attr = new AttrValue
{
List = new AttrValue.Types.ListValue()
};
attr.List.S.Add(ByteString.CopyFromUtf8($"loc:{true_name}"));
tf_with(ops.name_scope("Initializer"), scope2 =>
{
_initial_value = (initial_value as Func<Tensor>)();
_initial_value = ops.convert_to_tensor(_initial_value, name: "initial_value", dtype: dtype);
});
_variable = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name);
}
// Or get the initial value from a Tensor or Python object.
else
{ {
_initial_value = (initial_value as Func<Tensor>)();
_initial_value = ops.convert_to_tensor(_initial_value, name: "initial_value", dtype: dtype);
});
_variable = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name);
}
// Or get the initial value from a Tensor or Python object.
else
{
_initial_value = ops.convert_to_tensor(initial_value, name: "initial_value", dtype: dtype);
_initial_value = ops.convert_to_tensor(initial_value, name: "initial_value", dtype: dtype);


var shape = _initial_value.shape;
dtype = _initial_value.dtype;
_variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), scope);
}
var shape = _initial_value.shape;
dtype = _initial_value.dtype;
_variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), scope);
}


// Manually overrides the variable's shape with the initial value's.
if (validate_shape)
{
var initial_value_shape = _initial_value.TensorShape;
if (!initial_value_shape.is_fully_defined())
throw new ValueError($"initial_value must have a shape specified: {_initial_value}");
}
// Manually overrides the variable's shape with the initial value's.
if (validate_shape)
{
var initial_value_shape = _initial_value.TensorShape;
if (!initial_value_shape.is_fully_defined())
throw new ValueError($"initial_value must have a shape specified: {_initial_value}");
}


// If 'initial_value' makes use of other variables, make sure we don't
// have an issue if these other variables aren't initialized first by
// using their initialized_value() method.
var _initial_value2 = _try_guard_against_uninitialized_dependencies(name, _initial_value);
// If 'initial_value' makes use of other variables, make sure we don't
// have an issue if these other variables aren't initialized first by
// using their initialized_value() method.
var _initial_value2 = _try_guard_against_uninitialized_dependencies(name, _initial_value);


_initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op;
_initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op;


if (!String.IsNullOrEmpty(caching_device))
{
if (!String.IsNullOrEmpty(caching_device))
{


}
else
{
ops.colocate_with(_initializer_op);
}
else
{
ops.colocate_with(_initializer_op);


_snapshot = gen_array_ops.identity(_variable, name = "read");
}
_snapshot = gen_array_ops.identity(_variable, name = "read");
}


ops.add_to_collections(collections, this as VariableV1);
ops.add_to_collections(collections, this as VariableV1);
});
}); });
} }




+ 22
- 9
src/TensorFlowNET.Core/ops.cs View File

@@ -186,12 +186,7 @@ namespace Tensorflow
/// operations constructed within the context. /// operations constructed within the context.
/// </returns> /// </returns>
public static _ControlDependenciesController control_dependencies(object[] control_inputs) public static _ControlDependenciesController control_dependencies(object[] control_inputs)
{
return get_default_graph().control_dependencies(control_inputs);
}

public static _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs)
=> control_dependencies(control_inputs == null ? null : control_inputs.OfType<object>().ToArray());
=> get_default_graph().control_dependencies(control_inputs);


/// <summary> /// <summary>
/// Creates a TF_Operation. /// Creates a TF_Operation.
@@ -212,9 +207,9 @@ namespace Tensorflow
{ {
var op_desc = graph.NewOperation(node_def.Op, node_def.Name); var op_desc = graph.NewOperation(node_def.Op, node_def.Name);


//TODO: Implement TF_SetDevice
//if node_def.device:
// c_api.TF_SetDevice(op_desc, compat.as_str(node_def.device))
if (!string.IsNullOrEmpty(node_def.Device))
c_api.TF_SetDevice(op_desc, node_def.Device);
// Add inputs // Add inputs
foreach (var op_input in inputs) foreach (var op_input in inputs)
{ {
@@ -310,6 +305,22 @@ namespace Tensorflow
}); });
} }


public static IObjectLife init_scope2()
{
// Retrieve the active name scope: entering an `init_scope` preserves
// the name scope of the current context.
var default_graph = get_default_graph();
var scope = default_graph.get_name_scope();
if (!String.IsNullOrEmpty(scope) && !scope.EndsWith("/"))
// Names that end with trailing slashes are treated by `name_scope` as
// absolute.
scope += "/";
// inner_device_stack = default_graph._device_function_stack
// var outer_context = default_graph.as_default;

return ops.control_dependencies(null);
}

private static int uid_number = 0; private static int uid_number = 0;


/// <summary> /// <summary>
@@ -508,6 +519,8 @@ namespace Tensorflow
return null; return null;
case TensorShape ts: case TensorShape ts:
return constant_op.constant(ts.dims, dtype: dtype, name: name); return constant_op.constant(ts.dims, dtype: dtype, name: name);
case int[] dims:
return constant_op.constant(dims, dtype: dtype, name: name);
case object[] objects: case object[] objects:
return array_ops._autopacking_conversion_function(objects, dtype: dtype, name: name); return array_ops._autopacking_conversion_function(objects, dtype: dtype, name: name);
default: default:


+ 3
- 0
src/TensorFlowNET.Core/ops.name_scope.cs View File

@@ -45,7 +45,10 @@ namespace Tensorflow
public void __enter__() public void __enter__()
{ {
_name = _name ?? _default_name; _name = _name ?? _default_name;
if (_name.EndsWith("basic_r_n_n_cell"))
{


}
Graph g = null; Graph g = null;


if (_values is List<Tensor> vList) if (_values is List<Tensor> vList)


+ 4
- 4
src/TensorFlowNET.Core/tensorflow.cs View File

@@ -93,14 +93,14 @@ namespace Tensorflow
return new Session().as_default(); return new Session().as_default();
} }


public Session Session(Graph graph, SessionOptions opts = null)
public Session Session(Graph graph, ConfigProto config = null)
{ {
return new Session(graph, opts: opts).as_default();
return new Session(graph, config: config).as_default();
} }


public Session Session(SessionOptions opts)
public Session Session(ConfigProto config)
{ {
return new Session(null, opts).as_default();
return new Session(null, config).as_default();
} }


public void __init__() public void __init__()


+ 2
- 3
test/TensorFlowNET.UnitTest/CSession.cs View File

@@ -25,9 +25,8 @@ namespace TensorFlowNET.UnitTest
{ {
lock (Locks.ProcessWide) lock (Locks.ProcessWide)
{ {
var opts = new SessionOptions();
opts.SetConfig(new ConfigProto {InterOpParallelismThreads = 4});
session_ = new Session(graph, opts, s);
var config = new ConfigProto {InterOpParallelismThreads = 4};
session_ = new Session(graph, config, s);
} }
} }




+ 3
- 10
test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs View File

@@ -18,10 +18,10 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
var i = constant_op.constant(0, name: "i"); var i = constant_op.constant(0, name: "i");
var c = new Func<Tensor, Tensor>(x => tf.less(x, 10, name: "c")); var c = new Func<Tensor, Tensor>(x => tf.less(x, 10, name: "c"));
var b = new Func<Tensor, Tensor>(x => tf.add(x, 1, name: "c")); var b = new Func<Tensor, Tensor>(x => tf.add(x, 1, name: "c"));
var r = control_flow_ops.while_loop(c, b, new[] { i });
var r = control_flow_ops.while_loop(c, b, i);
} }
private void _testWhileContextHelper(int? maximum_iterations = null)
private void _testWhileContextHelper(int maximum_iterations)
{ {
// TODO: implement missing code dependencies // TODO: implement missing code dependencies
using (var sess = this.cached_session()) using (var sess = this.cached_session())
@@ -30,7 +30,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, 10, name: "c")); var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, 10, name: "c"));
var b = new Func<Tensor, Tensor>(x => gen_math_ops.add(x, 1, name: "c")); var b = new Func<Tensor, Tensor>(x => gen_math_ops.add(x, 1, name: "c"));
control_flow_ops.while_loop( control_flow_ops.while_loop(
c, b, new[] { i }, maximum_iterations: maximum_iterations);
c, b, i , maximum_iterations: tf.constant(maximum_iterations));
foreach (Operation op in sess.graph.get_operations()) foreach (Operation op in sess.graph.get_operations())
{ {
var control_flow_context = op._get_control_flow_context(); var control_flow_context = op._get_control_flow_context();
@@ -42,13 +42,6 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
} }
} }
[Ignore("TODO")]
[TestMethod]
public void testWhileContext()
{
_testWhileContextHelper();
}
[Ignore("TODO")] [Ignore("TODO")]
[TestMethod] [TestMethod]
public void testWhileContextWithMaximumIterations() public void testWhileContextWithMaximumIterations()


Loading…
Cancel
Save