From e9d2e90280d82d6933634b1ec7bfb15399b8d361 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Fri, 18 Oct 2019 06:34:10 -0500 Subject: [PATCH 01/41] SERIALIZABLE_ --- src/TensorFlowNET.Core/Operations/Operation.Input.cs | 4 +++- src/TensorFlowNET.Core/Operations/Operation.Output.cs | 4 +++- src/TensorFlowNET.Core/Operations/Operation.cs | 6 +++--- src/TensorFlowNET.Core/TensorFlowNET.Core.csproj | 5 ++--- src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs | 4 +++- src/TensorFlowNET.Core/Tensors/Tensor.cs | 4 +++- src/TensorFlowNET.Core/Tensors/TensorShape.cs | 6 ++++-- 7 files changed, 21 insertions(+), 12 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/Operation.Input.cs b/src/TensorFlowNET.Core/Operations/Operation.Input.cs index f518c726..7d905f3e 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Input.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Input.cs @@ -14,10 +14,12 @@ limitations under the License. ******************************************************************************/ -using Newtonsoft.Json; using System; using System.Linq; using System.Runtime.InteropServices; +#if SERIALIZABLE +using System.Text.Json.Serialization; +#endif namespace Tensorflow { diff --git a/src/TensorFlowNET.Core/Operations/Operation.Output.cs b/src/TensorFlowNET.Core/Operations/Operation.Output.cs index f4dcdfd6..e68513b5 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Output.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Output.cs @@ -14,10 +14,12 @@ limitations under the License. 
******************************************************************************/ -using Newtonsoft.Json; using System; using System.Linq; using System.Runtime.InteropServices; +#if SERIALIZABLE +using System.Text.Json.Serialization; +#endif using static Tensorflow.Binding; namespace Tensorflow diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 2e653e51..c270afc3 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -16,12 +16,12 @@ using Google.Protobuf.Collections; #if SERIALIZABLE -using Newtonsoft.Json; +using System.Text.Json.Serialization; #endif using System; using System.Collections.Generic; using System.IO; -using System.Linq; +using System.Linq; using Tensorflow.Util; namespace Tensorflow @@ -65,7 +65,7 @@ namespace Tensorflow #if SERIALIZABLE [JsonIgnore] #endif - public int _id_value; + public int _id_value { get; set; } #if SERIALIZABLE [JsonIgnore] #endif diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index bb5d889c..f7b7d424 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -43,7 +43,7 @@ Docs: https://tensorflownet.readthedocs.io true - TRACE;DEBUG;SERIALIZABLE + TRACE;DEBUG;SERIALIZABLE_ @@ -65,8 +65,7 @@ Docs: https://tensorflownet.readthedocs.io - - + diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index 7c5054d3..9b9e31da 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -25,7 +25,9 @@ using System.Text; using NumSharp.Backends; using NumSharp.Backends.Unmanaged; using static Tensorflow.c_api; -using Newtonsoft.Json; +#if SERIALIZABLE +using System.Text.Json.Serialization; +#endif namespace Tensorflow { diff --git 
a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index ef053651..485ee4f1 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -28,7 +28,9 @@ using NumSharp.Backends; using NumSharp.Backends.Unmanaged; using NumSharp.Utilities; using Tensorflow.Framework; -using Newtonsoft.Json; +#if SERIALIZABLE +using System.Text.Json.Serialization; +#endif namespace Tensorflow { diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs index 80bb31c1..02e57800 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs @@ -1,10 +1,12 @@ -using Newtonsoft.Json; -using NumSharp; +using NumSharp; using System; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Runtime.CompilerServices; +#if SERIALIZABLE +using System.Text.Json.Serialization; +#endif using static Tensorflow.Binding; namespace Tensorflow From 27ab492cadf583627dfdaad9d8b346aff946708f Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Fri, 18 Oct 2019 06:52:14 -0500 Subject: [PATCH 02/41] change to Newtonsoft.Json. 
--- src/TensorFlowNET.Core/Operations/Operation.Input.cs | 2 +- src/TensorFlowNET.Core/Operations/Operation.Output.cs | 2 +- src/TensorFlowNET.Core/Operations/Operation.cs | 2 +- src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs | 2 +- src/TensorFlowNET.Core/Tensors/Tensor.cs | 2 +- src/TensorFlowNET.Core/Tensors/TensorShape.cs | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/Operation.Input.cs b/src/TensorFlowNET.Core/Operations/Operation.Input.cs index 7d905f3e..57ac8271 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Input.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Input.cs @@ -18,7 +18,7 @@ using System; using System.Linq; using System.Runtime.InteropServices; #if SERIALIZABLE -using System.Text.Json.Serialization; +using Newtonsoft.Json; #endif namespace Tensorflow diff --git a/src/TensorFlowNET.Core/Operations/Operation.Output.cs b/src/TensorFlowNET.Core/Operations/Operation.Output.cs index e68513b5..77bf68a1 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Output.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Output.cs @@ -18,7 +18,7 @@ using System; using System.Linq; using System.Runtime.InteropServices; #if SERIALIZABLE -using System.Text.Json.Serialization; +using Newtonsoft.Json; #endif using static Tensorflow.Binding; diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index c270afc3..3b40c95a 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -16,7 +16,7 @@ using Google.Protobuf.Collections; #if SERIALIZABLE -using System.Text.Json.Serialization; +using Newtonsoft.Json; #endif using System; using System.Collections.Generic; diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index 9b9e31da..39c272b4 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ 
b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -26,7 +26,7 @@ using NumSharp.Backends; using NumSharp.Backends.Unmanaged; using static Tensorflow.c_api; #if SERIALIZABLE -using System.Text.Json.Serialization; +using Newtonsoft.Json; #endif namespace Tensorflow diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 485ee4f1..01290644 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -29,7 +29,7 @@ using NumSharp.Backends.Unmanaged; using NumSharp.Utilities; using Tensorflow.Framework; #if SERIALIZABLE -using System.Text.Json.Serialization; +using Newtonsoft.Json; #endif namespace Tensorflow diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs index 02e57800..bfd90a75 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs @@ -5,7 +5,7 @@ using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Runtime.CompilerServices; #if SERIALIZABLE -using System.Text.Json.Serialization; +using Newtonsoft.Json; #endif using static Tensorflow.Binding; From 203b0e2820002090d2ea93b12948b7eaa3d5a041 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Fri, 18 Oct 2019 07:14:31 -0500 Subject: [PATCH 03/41] add back Graph.IEnumerable --- src/TensorFlowNET.Core/Graphs/Graph.cs | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 9b6906aa..87a1424f 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -75,7 +75,10 @@ namespace Tensorflow /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. /// /// https://www.tensorflow.org/guide/graphs

https://www.tensorflow.org/api_docs/python/tf/Graph
- public partial class Graph : DisposableObject//, IEnumerable + public partial class Graph : DisposableObject, +#if !SERIALIZABLE + IEnumerable +#endif { private Dictionary _nodes_by_id; public Dictionary _nodes_by_name; @@ -524,17 +527,19 @@ namespace Tensorflow } return debugString;*/ - } - - /*private IEnumerable GetEnumerable() + } + +#if !SERIALIZABLE + private IEnumerable GetEnumerable() => c_api_util.tf_operations(this); IEnumerator IEnumerable.GetEnumerator() => GetEnumerable().GetEnumerator(); IEnumerator IEnumerable.GetEnumerator() - => throw new NotImplementedException();*/ - + => throw new NotImplementedException(); +#endif + public static implicit operator IntPtr(Graph graph) { return graph._handle; From f0ff82a94c90861ddeb0c98b76954e5da40f3f2d Mon Sep 17 00:00:00 2001 From: Harshitha Parnandi Venkata Date: Fri, 18 Oct 2019 17:24:45 -0700 Subject: [PATCH 04/41] Fixed a bug that overwrites the learning rate when sent as a Tensor. --- .../Train/GradientDescentOptimizer.cs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs b/src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs index 1a2821bb..d4682066 100644 --- a/src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs +++ b/src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs @@ -35,22 +35,29 @@ namespace Tensorflow.Train /// for changing these values across different invocations of optimizer /// functions. 
/// + private bool _useTensor; public GradientDescentOptimizer(float learning_rate, bool use_locking = false, string name = "GradientDescent") : base(learning_rate, use_locking, name) { _lr = learning_rate; + _useTensor = false; } public GradientDescentOptimizer(Tensor learning_rate, bool use_locking = false, string name = "GradientDescent") : base(learning_rate, use_locking, name) { _lr_t = learning_rate; + _useTensor = true; } public override void _prepare() { - var lr = _call_if_callable(_lr); - _lr_t = ops.convert_to_tensor(lr, name: "learning_rate"); + if(!_useTensor) + { + var lr = _call_if_callable(_lr); + _lr_t = ops.convert_to_tensor(lr, name: "learning_rate"); + } + } } } From 6931d5c42592ee86b610dab1d12fddd2560b8697 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Mon, 21 Oct 2019 06:03:22 -0500 Subject: [PATCH 05/41] change Session(ConfigProto). --- .../Sessions/BaseSession.cs | 15 ++++++----- src/TensorFlowNET.Core/Sessions/Session.cs | 2 +- .../Sessions/SessionOptions.cs | 13 ++++++---- .../Sessions/TF_DeprecatedSession.cs | 26 ------------------- .../Sessions/TF_SessionOptions.cs | 10 ------- .../Sessions/c_api.session.cs | 5 +++- 6 files changed, 21 insertions(+), 50 deletions(-) delete mode 100644 src/TensorFlowNET.Core/Sessions/TF_DeprecatedSession.cs delete mode 100644 src/TensorFlowNET.Core/Sessions/TF_SessionOptions.cs diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 1701c625..bb37956c 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -36,19 +36,20 @@ namespace Tensorflow protected byte[] _target; public Graph graph => _graph; - public BaseSession(string target = "", Graph g = null, SessionOptions opts = null, Status status = null) + public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null) { _graph = g ?? 
ops.get_default_graph(); _graph.as_default(); _target = Encoding.UTF8.GetBytes(target); - SessionOptions lopts = opts ?? new SessionOptions(); - - lock (Locks.ProcessWide) + using (var opts = new SessionOptions(target, config)) { - status = status ?? new Status(); - _handle = c_api.TF_NewSession(_graph, opts ?? lopts, status); - status.Check(true); + lock (Locks.ProcessWide) + { + status = status ?? new Status(); + _handle = c_api.TF_NewSession(_graph, opts, status); + status.Check(true); + } } } diff --git a/src/TensorFlowNET.Core/Sessions/Session.cs b/src/TensorFlowNET.Core/Sessions/Session.cs index a89d94dc..caa669d3 100644 --- a/src/TensorFlowNET.Core/Sessions/Session.cs +++ b/src/TensorFlowNET.Core/Sessions/Session.cs @@ -32,7 +32,7 @@ namespace Tensorflow _handle = handle; } - public Session(Graph g, SessionOptions opts = null, Status s = null) : base("", g, opts, s) + public Session(Graph g, ConfigProto config = null, Status s = null) : base("", g, config, s) { } public Session as_default() diff --git a/src/TensorFlowNET.Core/Sessions/SessionOptions.cs b/src/TensorFlowNET.Core/Sessions/SessionOptions.cs index 112543fe..0e64033c 100644 --- a/src/TensorFlowNET.Core/Sessions/SessionOptions.cs +++ b/src/TensorFlowNET.Core/Sessions/SessionOptions.cs @@ -20,11 +20,14 @@ using System.Runtime.InteropServices; namespace Tensorflow { - public class SessionOptions : DisposableObject + internal class SessionOptions : DisposableObject { - public SessionOptions() + public SessionOptions(string target = "", ConfigProto config = null) { _handle = c_api.TF_NewSessionOptions(); + c_api.TF_SetTarget(_handle, target); + if (config != null) + SetConfig(config); } public SessionOptions(IntPtr handle) @@ -35,10 +38,10 @@ namespace Tensorflow protected override void DisposeUnmanagedResources(IntPtr handle) => c_api.TF_DeleteSessionOptions(handle); - public void SetConfig(ConfigProto config) + private void SetConfig(ConfigProto config) { - var bytes = config.ToByteArray(); //TODO! 
we can use WriteTo - var proto = Marshal.AllocHGlobal(bytes.Length); //TODO! potential memory leak + var bytes = config.ToByteArray(); + var proto = Marshal.AllocHGlobal(bytes.Length); Marshal.Copy(bytes, 0, proto, bytes.Length); using (var status = new Status()) diff --git a/src/TensorFlowNET.Core/Sessions/TF_DeprecatedSession.cs b/src/TensorFlowNET.Core/Sessions/TF_DeprecatedSession.cs deleted file mode 100644 index baab71a0..00000000 --- a/src/TensorFlowNET.Core/Sessions/TF_DeprecatedSession.cs +++ /dev/null @@ -1,26 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-******************************************************************************/ - -using System.Runtime.InteropServices; - -namespace Tensorflow.Sessions -{ - [StructLayout(LayoutKind.Sequential)] - public struct TF_DeprecatedSession - { - Session session; - } -} diff --git a/src/TensorFlowNET.Core/Sessions/TF_SessionOptions.cs b/src/TensorFlowNET.Core/Sessions/TF_SessionOptions.cs deleted file mode 100644 index 35b27f12..00000000 --- a/src/TensorFlowNET.Core/Sessions/TF_SessionOptions.cs +++ /dev/null @@ -1,10 +0,0 @@ -using System.Runtime.InteropServices; - -namespace Tensorflow -{ - [StructLayout(LayoutKind.Sequential)] - public struct TF_SessionOptions - { - public SessionOptions options; - } -} diff --git a/src/TensorFlowNET.Core/Sessions/c_api.session.cs b/src/TensorFlowNET.Core/Sessions/c_api.session.cs index 3ed60435..713d0d5f 100644 --- a/src/TensorFlowNET.Core/Sessions/c_api.session.cs +++ b/src/TensorFlowNET.Core/Sessions/c_api.session.cs @@ -116,6 +116,9 @@ namespace Tensorflow /// size_t /// TF_Status* [DllImport(TensorFlowLibName)] - public static extern unsafe void TF_SetConfig(IntPtr options, IntPtr proto, ulong proto_len, IntPtr status); + public static extern void TF_SetConfig(IntPtr options, IntPtr proto, ulong proto_len, IntPtr status); + + [DllImport(TensorFlowLibName)] + public static extern void TF_SetTarget(IntPtr options, string target); } } From a589552ca938787b232dfe5caaca22fe490168b5 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Mon, 21 Oct 2019 06:03:53 -0500 Subject: [PATCH 06/41] update ConfigProto. 
--- src/TensorFlowNET.Core/Protobuf/Config.cs | 723 +++++++++++++++++++--- 1 file changed, 645 insertions(+), 78 deletions(-) diff --git a/src/TensorFlowNET.Core/Protobuf/Config.cs b/src/TensorFlowNET.Core/Protobuf/Config.cs index 7eb798e3..be63c8c7 100644 --- a/src/TensorFlowNET.Core/Protobuf/Config.cs +++ b/src/TensorFlowNET.Core/Protobuf/Config.cs @@ -27,10 +27,10 @@ namespace Tensorflow { "CiV0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvY29uZmlnLnByb3RvEgp0ZW5z", "b3JmbG93Gip0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL2Nvc3RfZ3JhcGgu", "cHJvdG8aJXRlbnNvcmZsb3cvY29yZS9mcmFtZXdvcmsvZ3JhcGgucHJvdG8a", - "KnRlbnNvcmZsb3cvY29yZS9mcmFtZXdvcmsvc3RlcF9zdGF0cy5wcm90bxok", - "dGVuc29yZmxvdy9jb3JlL3Byb3RvYnVmL2RlYnVnLnByb3RvGiZ0ZW5zb3Jm", - "bG93L2NvcmUvcHJvdG9idWYvY2x1c3Rlci5wcm90bxoudGVuc29yZmxvdy9j", - "b3JlL3Byb3RvYnVmL3Jld3JpdGVyX2NvbmZpZy5wcm90byKtBAoKR1BVT3B0", + "KnRlbnNvcmZsb3cvY29yZS9mcmFtZXdvcmsvc3RlcF9zdGF0cy5wcm90bxom", + "dGVuc29yZmxvdy9jb3JlL3Byb3RvYnVmL2NsdXN0ZXIucHJvdG8aJHRlbnNv", + "cmZsb3cvY29yZS9wcm90b2J1Zi9kZWJ1Zy5wcm90bxoudGVuc29yZmxvdy9j", + "b3JlL3Byb3RvYnVmL3Jld3JpdGVyX2NvbmZpZy5wcm90byK3BQoKR1BVT3B0", "aW9ucxInCh9wZXJfcHJvY2Vzc19ncHVfbWVtb3J5X2ZyYWN0aW9uGAEgASgB", "EhQKDGFsbG93X2dyb3d0aBgEIAEoCBIWCg5hbGxvY2F0b3JfdHlwZRgCIAEo", "CRIfChdkZWZlcnJlZF9kZWxldGlvbl9ieXRlcxgDIAEoAxIbChN2aXNpYmxl", @@ -38,89 +38,102 @@ namespace Tensorflow { "ZWNzGAYgASgFEiQKHHBvbGxpbmdfaW5hY3RpdmVfZGVsYXlfbXNlY3MYByAB", "KAUSHAoUZm9yY2VfZ3B1X2NvbXBhdGlibGUYCCABKAgSOQoMZXhwZXJpbWVu", "dGFsGAkgASgLMiMudGVuc29yZmxvdy5HUFVPcHRpb25zLkV4cGVyaW1lbnRh", - "bBrmAQoMRXhwZXJpbWVudGFsEksKD3ZpcnR1YWxfZGV2aWNlcxgBIAMoCzIy", + "bBrwAgoMRXhwZXJpbWVudGFsEksKD3ZpcnR1YWxfZGV2aWNlcxgBIAMoCzIy", "LnRlbnNvcmZsb3cuR1BVT3B0aW9ucy5FeHBlcmltZW50YWwuVmlydHVhbERl", "dmljZXMSGgoSdXNlX3VuaWZpZWRfbWVtb3J5GAIgASgIEiMKG251bV9kZXZf", "dG9fZGV2X2NvcHlfc3RyZWFtcxgDIAEoBRIdChVjb2xsZWN0aXZlX3Jpbmdf", - "b3JkZXIYBCABKAkaKQoOVmlydHVhbERldmljZXMSFwoPbWVtb3J5X2xpbWl0", - "X21iGAEgAygCIoUDChBPcHRpbWl6ZXJPcHRpb25zEisKI2RvX2NvbW1vbl9z", 
- "dWJleHByZXNzaW9uX2VsaW1pbmF0aW9uGAEgASgIEhsKE2RvX2NvbnN0YW50", - "X2ZvbGRpbmcYAiABKAgSJAocbWF4X2ZvbGRlZF9jb25zdGFudF9pbl9ieXRl", - "cxgGIAEoAxIcChRkb19mdW5jdGlvbl9pbmxpbmluZxgEIAEoCBI1CglvcHRf", - "bGV2ZWwYAyABKA4yIi50ZW5zb3JmbG93Lk9wdGltaXplck9wdGlvbnMuTGV2", - "ZWwSRQoQZ2xvYmFsX2ppdF9sZXZlbBgFIAEoDjIrLnRlbnNvcmZsb3cuT3B0", - "aW1pemVyT3B0aW9ucy5HbG9iYWxKaXRMZXZlbCIgCgVMZXZlbBIGCgJMMRAA", - "Eg8KAkwwEP///////////wEiQwoOR2xvYmFsSml0TGV2ZWwSCwoHREVGQVVM", - "VBAAEhAKA09GRhD///////////8BEggKBE9OXzEQARIICgRPTl8yEAIi7gIK", - "DEdyYXBoT3B0aW9ucxIeChZlbmFibGVfcmVjdl9zY2hlZHVsaW5nGAIgASgI", - "EjcKEW9wdGltaXplcl9vcHRpb25zGAMgASgLMhwudGVuc29yZmxvdy5PcHRp", - "bWl6ZXJPcHRpb25zEhgKEGJ1aWxkX2Nvc3RfbW9kZWwYBCABKAMSHgoWYnVp", - "bGRfY29zdF9tb2RlbF9hZnRlchgJIAEoAxIUCgxpbmZlcl9zaGFwZXMYBSAB", - "KAgSGgoScGxhY2VfcHJ1bmVkX2dyYXBoGAYgASgIEiAKGGVuYWJsZV9iZmxv", - "YXQxNl9zZW5kcmVjdhgHIAEoCBIVCg10aW1lbGluZV9zdGVwGAggASgFEjMK", - "D3Jld3JpdGVfb3B0aW9ucxgKIAEoCzIaLnRlbnNvcmZsb3cuUmV3cml0ZXJD", - "b25maWdKBAgBEAJSJXNraXBfY29tbW9uX3N1YmV4cHJlc3Npb25fZWxpbWlu", - "YXRpb24iQQoVVGhyZWFkUG9vbE9wdGlvblByb3RvEhMKC251bV90aHJlYWRz", - "GAEgASgFEhMKC2dsb2JhbF9uYW1lGAIgASgJImwKClJQQ09wdGlvbnMSJAoc", - "dXNlX3JwY19mb3JfaW5wcm9jZXNzX21hc3RlchgBIAEoCBIdChVjb21wcmVz", - "c2lvbl9hbGdvcml0aG0YAiABKAkSGQoRY29tcHJlc3Npb25fbGV2ZWwYAyAB", - "KAUi3wYKC0NvbmZpZ1Byb3RvEj4KDGRldmljZV9jb3VudBgBIAMoCzIoLnRl", - "bnNvcmZsb3cuQ29uZmlnUHJvdG8uRGV2aWNlQ291bnRFbnRyeRIkChxpbnRy", - "YV9vcF9wYXJhbGxlbGlzbV90aHJlYWRzGAIgASgFEiQKHGludGVyX29wX3Bh", - "cmFsbGVsaXNtX3RocmVhZHMYBSABKAUSHwoXdXNlX3Blcl9zZXNzaW9uX3Ro", - "cmVhZHMYCSABKAgSRwocc2Vzc2lvbl9pbnRlcl9vcF90aHJlYWRfcG9vbBgM", - "IAMoCzIhLnRlbnNvcmZsb3cuVGhyZWFkUG9vbE9wdGlvblByb3RvEhgKEHBs", - "YWNlbWVudF9wZXJpb2QYAyABKAUSFgoOZGV2aWNlX2ZpbHRlcnMYBCADKAkS", - "KwoLZ3B1X29wdGlvbnMYBiABKAsyFi50ZW5zb3JmbG93LkdQVU9wdGlvbnMS", - "HAoUYWxsb3dfc29mdF9wbGFjZW1lbnQYByABKAgSHAoUbG9nX2RldmljZV9w", - "bGFjZW1lbnQYCCABKAgSLwoNZ3JhcGhfb3B0aW9ucxgKIAEoCzIYLnRlbnNv", - 
"cmZsb3cuR3JhcGhPcHRpb25zEh8KF29wZXJhdGlvbl90aW1lb3V0X2luX21z", - "GAsgASgDEisKC3JwY19vcHRpb25zGA0gASgLMhYudGVuc29yZmxvdy5SUENP", - "cHRpb25zEisKC2NsdXN0ZXJfZGVmGA4gASgLMhYudGVuc29yZmxvdy5DbHVz", - "dGVyRGVmEh0KFWlzb2xhdGVfc2Vzc2lvbl9zdGF0ZRgPIAEoCBI6CgxleHBl", - "cmltZW50YWwYECABKAsyJC50ZW5zb3JmbG93LkNvbmZpZ1Byb3RvLkV4cGVy", - "aW1lbnRhbBoyChBEZXZpY2VDb3VudEVudHJ5EgsKA2tleRgBIAEoCRINCgV2", - "YWx1ZRgCIAEoBToCOAEagwEKDEV4cGVyaW1lbnRhbBIfChdjb2xsZWN0aXZl", - "X2dyb3VwX2xlYWRlchgBIAEoCRIVCg1leGVjdXRvcl90eXBlGAMgASgJEhoK", - "EnJlY3ZfYnVmX21heF9jaHVuaxgEIAEoBRIZChF1c2VfbnVtYV9hZmZpbml0", - "eRgFIAEoCEoECAIQAyLYAwoKUnVuT3B0aW9ucxI2Cgt0cmFjZV9sZXZlbBgB", - "IAEoDjIhLnRlbnNvcmZsb3cuUnVuT3B0aW9ucy5UcmFjZUxldmVsEhUKDXRp", - "bWVvdXRfaW5fbXMYAiABKAMSHAoUaW50ZXJfb3BfdGhyZWFkX3Bvb2wYAyAB", - "KAUSHwoXb3V0cHV0X3BhcnRpdGlvbl9ncmFwaHMYBSABKAgSLwoNZGVidWdf", - "b3B0aW9ucxgGIAEoCzIYLnRlbnNvcmZsb3cuRGVidWdPcHRpb25zEioKInJl", - "cG9ydF90ZW5zb3JfYWxsb2NhdGlvbnNfdXBvbl9vb20YByABKAgSOQoMZXhw", - "ZXJpbWVudGFsGAggASgLMiMudGVuc29yZmxvdy5SdW5PcHRpb25zLkV4cGVy", - "aW1lbnRhbBpKCgxFeHBlcmltZW50YWwSHAoUY29sbGVjdGl2ZV9ncmFwaF9r", - "ZXkYASABKAMSHAoUdXNlX3J1bl9oYW5kbGVyX3Bvb2wYAiABKAgiUgoKVHJh", - "Y2VMZXZlbBIMCghOT19UUkFDRRAAEhIKDlNPRlRXQVJFX1RSQUNFEAESEgoO", - "SEFSRFdBUkVfVFJBQ0UQAhIOCgpGVUxMX1RSQUNFEANKBAgEEAUilgEKC1J1", - "bk1ldGFkYXRhEikKCnN0ZXBfc3RhdHMYASABKAsyFS50ZW5zb3JmbG93LlN0", - "ZXBTdGF0cxIsCgpjb3N0X2dyYXBoGAIgASgLMhgudGVuc29yZmxvdy5Db3N0", - "R3JhcGhEZWYSLgoQcGFydGl0aW9uX2dyYXBocxgDIAMoCzIULnRlbnNvcmZs", - "b3cuR3JhcGhEZWYiOgoQVGVuc29yQ29ubmVjdGlvbhITCgtmcm9tX3RlbnNv", - "chgBIAEoCRIRCgl0b190ZW5zb3IYAiABKAkisAMKD0NhbGxhYmxlT3B0aW9u", - "cxIMCgRmZWVkGAEgAygJEg0KBWZldGNoGAIgAygJEg4KBnRhcmdldBgDIAMo", - "CRIrCgtydW5fb3B0aW9ucxgEIAEoCzIWLnRlbnNvcmZsb3cuUnVuT3B0aW9u", - "cxI3ChF0ZW5zb3JfY29ubmVjdGlvbhgFIAMoCzIcLnRlbnNvcmZsb3cuVGVu", - "c29yQ29ubmVjdGlvbhJCCgxmZWVkX2RldmljZXMYBiADKAsyLC50ZW5zb3Jm", - "bG93LkNhbGxhYmxlT3B0aW9ucy5GZWVkRGV2aWNlc0VudHJ5EkQKDWZldGNo", - 
"X2RldmljZXMYByADKAsyLS50ZW5zb3JmbG93LkNhbGxhYmxlT3B0aW9ucy5G", - "ZXRjaERldmljZXNFbnRyeRIXCg9mZXRjaF9za2lwX3N5bmMYCCABKAgaMgoQ", - "RmVlZERldmljZXNFbnRyeRILCgNrZXkYASABKAkSDQoFdmFsdWUYAiABKAk6", - "AjgBGjMKEUZldGNoRGV2aWNlc0VudHJ5EgsKA2tleRgBIAEoCRINCgV2YWx1", - "ZRgCIAEoCToCOAFCLQoYb3JnLnRlbnNvcmZsb3cuZnJhbWV3b3JrQgxDb25m", - "aWdQcm90b3NQAfgBAWIGcHJvdG8z")); + "b3JkZXIYBCABKAkSHQoVdGltZXN0YW1wZWRfYWxsb2NhdG9yGAUgASgIEiMK", + "G2tlcm5lbF90cmFja2VyX21heF9pbnRlcnZhbBgHIAEoBRIgChhrZXJuZWxf", + "dHJhY2tlcl9tYXhfYnl0ZXMYCCABKAUSIgoaa2VybmVsX3RyYWNrZXJfbWF4", + "X3BlbmRpbmcYCSABKAUaKQoOVmlydHVhbERldmljZXMSFwoPbWVtb3J5X2xp", + "bWl0X21iGAEgAygCIoUDChBPcHRpbWl6ZXJPcHRpb25zEisKI2RvX2NvbW1v", + "bl9zdWJleHByZXNzaW9uX2VsaW1pbmF0aW9uGAEgASgIEhsKE2RvX2NvbnN0", + "YW50X2ZvbGRpbmcYAiABKAgSJAocbWF4X2ZvbGRlZF9jb25zdGFudF9pbl9i", + "eXRlcxgGIAEoAxIcChRkb19mdW5jdGlvbl9pbmxpbmluZxgEIAEoCBI1Cglv", + "cHRfbGV2ZWwYAyABKA4yIi50ZW5zb3JmbG93Lk9wdGltaXplck9wdGlvbnMu", + "TGV2ZWwSRQoQZ2xvYmFsX2ppdF9sZXZlbBgFIAEoDjIrLnRlbnNvcmZsb3cu", + "T3B0aW1pemVyT3B0aW9ucy5HbG9iYWxKaXRMZXZlbCIgCgVMZXZlbBIGCgJM", + "MRAAEg8KAkwwEP///////////wEiQwoOR2xvYmFsSml0TGV2ZWwSCwoHREVG", + "QVVMVBAAEhAKA09GRhD///////////8BEggKBE9OXzEQARIICgRPTl8yEAIi", + "7gIKDEdyYXBoT3B0aW9ucxIeChZlbmFibGVfcmVjdl9zY2hlZHVsaW5nGAIg", + "ASgIEjcKEW9wdGltaXplcl9vcHRpb25zGAMgASgLMhwudGVuc29yZmxvdy5P", + "cHRpbWl6ZXJPcHRpb25zEhgKEGJ1aWxkX2Nvc3RfbW9kZWwYBCABKAMSHgoW", + "YnVpbGRfY29zdF9tb2RlbF9hZnRlchgJIAEoAxIUCgxpbmZlcl9zaGFwZXMY", + "BSABKAgSGgoScGxhY2VfcHJ1bmVkX2dyYXBoGAYgASgIEiAKGGVuYWJsZV9i", + "ZmxvYXQxNl9zZW5kcmVjdhgHIAEoCBIVCg10aW1lbGluZV9zdGVwGAggASgF", + "EjMKD3Jld3JpdGVfb3B0aW9ucxgKIAEoCzIaLnRlbnNvcmZsb3cuUmV3cml0", + "ZXJDb25maWdKBAgBEAJSJXNraXBfY29tbW9uX3N1YmV4cHJlc3Npb25fZWxp", + "bWluYXRpb24iQQoVVGhyZWFkUG9vbE9wdGlvblByb3RvEhMKC251bV90aHJl", + "YWRzGAEgASgFEhMKC2dsb2JhbF9uYW1lGAIgASgJImwKClJQQ09wdGlvbnMS", + "JAocdXNlX3JwY19mb3JfaW5wcm9jZXNzX21hc3RlchgBIAEoCBIdChVjb21w", + 
"cmVzc2lvbl9hbGdvcml0aG0YAiABKAkSGQoRY29tcHJlc3Npb25fbGV2ZWwY", + "AyABKAUisggKC0NvbmZpZ1Byb3RvEj4KDGRldmljZV9jb3VudBgBIAMoCzIo", + "LnRlbnNvcmZsb3cuQ29uZmlnUHJvdG8uRGV2aWNlQ291bnRFbnRyeRIkChxp", + "bnRyYV9vcF9wYXJhbGxlbGlzbV90aHJlYWRzGAIgASgFEiQKHGludGVyX29w", + "X3BhcmFsbGVsaXNtX3RocmVhZHMYBSABKAUSHwoXdXNlX3Blcl9zZXNzaW9u", + "X3RocmVhZHMYCSABKAgSRwocc2Vzc2lvbl9pbnRlcl9vcF90aHJlYWRfcG9v", + "bBgMIAMoCzIhLnRlbnNvcmZsb3cuVGhyZWFkUG9vbE9wdGlvblByb3RvEhgK", + "EHBsYWNlbWVudF9wZXJpb2QYAyABKAUSFgoOZGV2aWNlX2ZpbHRlcnMYBCAD", + "KAkSKwoLZ3B1X29wdGlvbnMYBiABKAsyFi50ZW5zb3JmbG93LkdQVU9wdGlv", + "bnMSHAoUYWxsb3dfc29mdF9wbGFjZW1lbnQYByABKAgSHAoUbG9nX2Rldmlj", + "ZV9wbGFjZW1lbnQYCCABKAgSLwoNZ3JhcGhfb3B0aW9ucxgKIAEoCzIYLnRl", + "bnNvcmZsb3cuR3JhcGhPcHRpb25zEh8KF29wZXJhdGlvbl90aW1lb3V0X2lu", + "X21zGAsgASgDEisKC3JwY19vcHRpb25zGA0gASgLMhYudGVuc29yZmxvdy5S", + "UENPcHRpb25zEisKC2NsdXN0ZXJfZGVmGA4gASgLMhYudGVuc29yZmxvdy5D", + "bHVzdGVyRGVmEh0KFWlzb2xhdGVfc2Vzc2lvbl9zdGF0ZRgPIAEoCBI6Cgxl", + "eHBlcmltZW50YWwYECABKAsyJC50ZW5zb3JmbG93LkNvbmZpZ1Byb3RvLkV4", + "cGVyaW1lbnRhbBoyChBEZXZpY2VDb3VudEVudHJ5EgsKA2tleRgBIAEoCRIN", + "CgV2YWx1ZRgCIAEoBToCOAEa1gIKDEV4cGVyaW1lbnRhbBIfChdjb2xsZWN0", + "aXZlX2dyb3VwX2xlYWRlchgBIAEoCRIVCg1leGVjdXRvcl90eXBlGAMgASgJ", + "EhoKEnJlY3ZfYnVmX21heF9jaHVuaxgEIAEoBRIZChF1c2VfbnVtYV9hZmZp", + "bml0eRgFIAEoCBI1Ci1jb2xsZWN0aXZlX2RldGVybWluaXN0aWNfc2VxdWVu", + "dGlhbF9leGVjdXRpb24YBiABKAgSFwoPY29sbGVjdGl2ZV9uY2NsGAcgASgI", + "EjYKLnNoYXJlX3Nlc3Npb25fc3RhdGVfaW5fY2x1c3RlcnNwZWNfcHJvcGFn", + "YXRpb24YCCABKAgSHwoXZGlzYWJsZV90aHJlYWRfc3Bpbm5pbmcYCSABKAgS", + "KAogc2hhcmVfY2x1c3Rlcl9kZXZpY2VzX2luX3Nlc3Npb24YCiABKAhKBAgC", + "EAMi2AMKClJ1bk9wdGlvbnMSNgoLdHJhY2VfbGV2ZWwYASABKA4yIS50ZW5z", + "b3JmbG93LlJ1bk9wdGlvbnMuVHJhY2VMZXZlbBIVCg10aW1lb3V0X2luX21z", + "GAIgASgDEhwKFGludGVyX29wX3RocmVhZF9wb29sGAMgASgFEh8KF291dHB1", + "dF9wYXJ0aXRpb25fZ3JhcGhzGAUgASgIEi8KDWRlYnVnX29wdGlvbnMYBiAB", + "KAsyGC50ZW5zb3JmbG93LkRlYnVnT3B0aW9ucxIqCiJyZXBvcnRfdGVuc29y", + 
"X2FsbG9jYXRpb25zX3Vwb25fb29tGAcgASgIEjkKDGV4cGVyaW1lbnRhbBgI", + "IAEoCzIjLnRlbnNvcmZsb3cuUnVuT3B0aW9ucy5FeHBlcmltZW50YWwaSgoM", + "RXhwZXJpbWVudGFsEhwKFGNvbGxlY3RpdmVfZ3JhcGhfa2V5GAEgASgDEhwK", + "FHVzZV9ydW5faGFuZGxlcl9wb29sGAIgASgIIlIKClRyYWNlTGV2ZWwSDAoI", + "Tk9fVFJBQ0UQABISCg5TT0ZUV0FSRV9UUkFDRRABEhIKDkhBUkRXQVJFX1RS", + "QUNFEAISDgoKRlVMTF9UUkFDRRADSgQIBBAFIocDCgtSdW5NZXRhZGF0YRIp", + "CgpzdGVwX3N0YXRzGAEgASgLMhUudGVuc29yZmxvdy5TdGVwU3RhdHMSLAoK", + "Y29zdF9ncmFwaBgCIAEoCzIYLnRlbnNvcmZsb3cuQ29zdEdyYXBoRGVmEi4K", + "EHBhcnRpdGlvbl9ncmFwaHMYAyADKAsyFC50ZW5zb3JmbG93LkdyYXBoRGVm", + "Ej8KD2Z1bmN0aW9uX2dyYXBocxgEIAMoCzImLnRlbnNvcmZsb3cuUnVuTWV0", + "YWRhdGEuRnVuY3Rpb25HcmFwaHMarQEKDkZ1bmN0aW9uR3JhcGhzEi4KEHBh", + "cnRpdGlvbl9ncmFwaHMYASADKAsyFC50ZW5zb3JmbG93LkdyYXBoRGVmEjQK", + "FnByZV9vcHRpbWl6YXRpb25fZ3JhcGgYAiABKAsyFC50ZW5zb3JmbG93Lkdy", + "YXBoRGVmEjUKF3Bvc3Rfb3B0aW1pemF0aW9uX2dyYXBoGAMgASgLMhQudGVu", + "c29yZmxvdy5HcmFwaERlZiI6ChBUZW5zb3JDb25uZWN0aW9uEhMKC2Zyb21f", + "dGVuc29yGAEgASgJEhEKCXRvX3RlbnNvchgCIAEoCSKwAwoPQ2FsbGFibGVP", + "cHRpb25zEgwKBGZlZWQYASADKAkSDQoFZmV0Y2gYAiADKAkSDgoGdGFyZ2V0", + "GAMgAygJEisKC3J1bl9vcHRpb25zGAQgASgLMhYudGVuc29yZmxvdy5SdW5P", + "cHRpb25zEjcKEXRlbnNvcl9jb25uZWN0aW9uGAUgAygLMhwudGVuc29yZmxv", + "dy5UZW5zb3JDb25uZWN0aW9uEkIKDGZlZWRfZGV2aWNlcxgGIAMoCzIsLnRl", + "bnNvcmZsb3cuQ2FsbGFibGVPcHRpb25zLkZlZWREZXZpY2VzRW50cnkSRAoN", + "ZmV0Y2hfZGV2aWNlcxgHIAMoCzItLnRlbnNvcmZsb3cuQ2FsbGFibGVPcHRp", + "b25zLkZldGNoRGV2aWNlc0VudHJ5EhcKD2ZldGNoX3NraXBfc3luYxgIIAEo", + "CBoyChBGZWVkRGV2aWNlc0VudHJ5EgsKA2tleRgBIAEoCRINCgV2YWx1ZRgC", + "IAEoCToCOAEaMwoRRmV0Y2hEZXZpY2VzRW50cnkSCwoDa2V5GAEgASgJEg0K", + "BXZhbHVlGAIgASgJOgI4AUItChhvcmcudGVuc29yZmxvdy5mcmFtZXdvcmtC", + "DENvbmZpZ1Byb3Rvc1AB+AEBYgZwcm90bzM=")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, - new pbr::FileDescriptor[] { global::Tensorflow.CostGraphReflection.Descriptor, global::Tensorflow.GraphReflection.Descriptor, 
global::Tensorflow.StepStatsReflection.Descriptor, global::Tensorflow.DebugReflection.Descriptor, global::Tensorflow.ClusterReflection.Descriptor, global::Tensorflow.RewriterConfigReflection.Descriptor, }, + new pbr::FileDescriptor[] { global::Tensorflow.CostGraphReflection.Descriptor, global::Tensorflow.GraphReflection.Descriptor, global::Tensorflow.StepStatsReflection.Descriptor, global::Tensorflow.ClusterReflection.Descriptor, global::Tensorflow.DebugReflection.Descriptor, global::Tensorflow.RewriterConfigReflection.Descriptor, }, new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { - new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GPUOptions), global::Tensorflow.GPUOptions.Parser, new[]{ "PerProcessGpuMemoryFraction", "AllowGrowth", "AllocatorType", "DeferredDeletionBytes", "VisibleDeviceList", "PollingActiveDelayUsecs", "PollingInactiveDelayMsecs", "ForceGpuCompatible", "Experimental" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GPUOptions.Types.Experimental), global::Tensorflow.GPUOptions.Types.Experimental.Parser, new[]{ "VirtualDevices", "UseUnifiedMemory", "NumDevToDevCopyStreams", "CollectiveRingOrder" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GPUOptions.Types.Experimental.Types.VirtualDevices), global::Tensorflow.GPUOptions.Types.Experimental.Types.VirtualDevices.Parser, new[]{ "MemoryLimitMb" }, null, null, null)})}), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GPUOptions), global::Tensorflow.GPUOptions.Parser, new[]{ "PerProcessGpuMemoryFraction", "AllowGrowth", "AllocatorType", "DeferredDeletionBytes", "VisibleDeviceList", "PollingActiveDelayUsecs", "PollingInactiveDelayMsecs", "ForceGpuCompatible", "Experimental" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GPUOptions.Types.Experimental), 
global::Tensorflow.GPUOptions.Types.Experimental.Parser, new[]{ "VirtualDevices", "UseUnifiedMemory", "NumDevToDevCopyStreams", "CollectiveRingOrder", "TimestampedAllocator", "KernelTrackerMaxInterval", "KernelTrackerMaxBytes", "KernelTrackerMaxPending" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GPUOptions.Types.Experimental.Types.VirtualDevices), global::Tensorflow.GPUOptions.Types.Experimental.Types.VirtualDevices.Parser, new[]{ "MemoryLimitMb" }, null, null, null)})}), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.OptimizerOptions), global::Tensorflow.OptimizerOptions.Parser, new[]{ "DoCommonSubexpressionElimination", "DoConstantFolding", "MaxFoldedConstantInBytes", "DoFunctionInlining", "OptLevel", "GlobalJitLevel" }, null, new[]{ typeof(global::Tensorflow.OptimizerOptions.Types.Level), typeof(global::Tensorflow.OptimizerOptions.Types.GlobalJitLevel) }, null), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GraphOptions), global::Tensorflow.GraphOptions.Parser, new[]{ "EnableRecvScheduling", "OptimizerOptions", "BuildCostModel", "BuildCostModelAfter", "InferShapes", "PlacePrunedGraph", "EnableBfloat16Sendrecv", "TimelineStep", "RewriteOptions" }, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ThreadPoolOptionProto), global::Tensorflow.ThreadPoolOptionProto.Parser, new[]{ "NumThreads", "GlobalName" }, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RPCOptions), global::Tensorflow.RPCOptions.Parser, new[]{ "UseRpcForInprocessMaster", "CompressionAlgorithm", "CompressionLevel" }, null, null, null), - new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ConfigProto), global::Tensorflow.ConfigProto.Parser, new[]{ "DeviceCount", "IntraOpParallelismThreads", "InterOpParallelismThreads", "UsePerSessionThreads", "SessionInterOpThreadPool", "PlacementPeriod", "DeviceFilters", "GpuOptions", "AllowSoftPlacement", "LogDevicePlacement", 
"GraphOptions", "OperationTimeoutInMs", "RpcOptions", "ClusterDef", "IsolateSessionState", "Experimental" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ConfigProto.Types.Experimental), global::Tensorflow.ConfigProto.Types.Experimental.Parser, new[]{ "CollectiveGroupLeader", "ExecutorType", "RecvBufMaxChunk", "UseNumaAffinity" }, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ConfigProto), global::Tensorflow.ConfigProto.Parser, new[]{ "DeviceCount", "IntraOpParallelismThreads", "InterOpParallelismThreads", "UsePerSessionThreads", "SessionInterOpThreadPool", "PlacementPeriod", "DeviceFilters", "GpuOptions", "AllowSoftPlacement", "LogDevicePlacement", "GraphOptions", "OperationTimeoutInMs", "RpcOptions", "ClusterDef", "IsolateSessionState", "Experimental" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ConfigProto.Types.Experimental), global::Tensorflow.ConfigProto.Types.Experimental.Parser, new[]{ "CollectiveGroupLeader", "ExecutorType", "RecvBufMaxChunk", "UseNumaAffinity", "CollectiveDeterministicSequentialExecution", "CollectiveNccl", "ShareSessionStateInClusterspecPropagation", "DisableThreadSpinning", "ShareClusterDevicesInSession" }, null, null, null)}), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RunOptions), global::Tensorflow.RunOptions.Parser, new[]{ "TraceLevel", "TimeoutInMs", "InterOpThreadPool", "OutputPartitionGraphs", "DebugOptions", "ReportTensorAllocationsUponOom", "Experimental" }, null, new[]{ typeof(global::Tensorflow.RunOptions.Types.TraceLevel) }, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RunOptions.Types.Experimental), global::Tensorflow.RunOptions.Types.Experimental.Parser, new[]{ "CollectiveGraphKey", "UseRunHandlerPool" }, null, null, null)}), - new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RunMetadata), 
global::Tensorflow.RunMetadata.Parser, new[]{ "StepStats", "CostGraph", "PartitionGraphs" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RunMetadata), global::Tensorflow.RunMetadata.Parser, new[]{ "StepStats", "CostGraph", "PartitionGraphs", "FunctionGraphs" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RunMetadata.Types.FunctionGraphs), global::Tensorflow.RunMetadata.Types.FunctionGraphs.Parser, new[]{ "PartitionGraphs", "PreOptimizationGraph", "PostOptimizationGraph" }, null, null, null)}), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TensorConnection), global::Tensorflow.TensorConnection.Parser, new[]{ "FromTensor", "ToTensor" }, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CallableOptions), global::Tensorflow.CallableOptions.Parser, new[]{ "Feed", "Fetch", "Target", "RunOptions", "TensorConnection", "FeedDevices", "FetchDevices", "FetchSkipSync" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, null, }) })); @@ -605,6 +618,10 @@ namespace Tensorflow { useUnifiedMemory_ = other.useUnifiedMemory_; numDevToDevCopyStreams_ = other.numDevToDevCopyStreams_; collectiveRingOrder_ = other.collectiveRingOrder_; + timestampedAllocator_ = other.timestampedAllocator_; + kernelTrackerMaxInterval_ = other.kernelTrackerMaxInterval_; + kernelTrackerMaxBytes_ = other.kernelTrackerMaxBytes_; + kernelTrackerMaxPending_ = other.kernelTrackerMaxPending_; _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); } @@ -703,6 +720,77 @@ namespace Tensorflow { } } + /// Field number for the "timestamped_allocator" field. 
+ public const int TimestampedAllocatorFieldNumber = 5; + private bool timestampedAllocator_; + /// + /// If true then extra work is done by GPUDevice and GPUBFCAllocator to + /// keep track of when GPU memory is freed and when kernels actually + /// complete so that we can know when a nominally free memory chunk + /// is really not subject to pending use. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool TimestampedAllocator { + get { return timestampedAllocator_; } + set { + timestampedAllocator_ = value; + } + } + + /// Field number for the "kernel_tracker_max_interval" field. + public const int KernelTrackerMaxIntervalFieldNumber = 7; + private int kernelTrackerMaxInterval_; + /// + /// Parameters for GPUKernelTracker. By default no kernel tracking is done. + /// Note that timestamped_allocator is only effective if some tracking is + /// specified. + /// + /// If kernel_tracker_max_interval = n > 0, then a tracking event + /// is inserted after every n kernels without an event. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int KernelTrackerMaxInterval { + get { return kernelTrackerMaxInterval_; } + set { + kernelTrackerMaxInterval_ = value; + } + } + + /// Field number for the "kernel_tracker_max_bytes" field. + public const int KernelTrackerMaxBytesFieldNumber = 8; + private int kernelTrackerMaxBytes_; + /// + /// If kernel_tracker_max_bytes = n > 0, then a tracking event is + /// inserted after every series of kernels allocating a sum of + /// memory >= n. If one kernel allocates b * n bytes, then one + /// event will be inserted after it, but it will count as b against + /// the pending limit. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int KernelTrackerMaxBytes { + get { return kernelTrackerMaxBytes_; } + set { + kernelTrackerMaxBytes_ = value; + } + } + + /// Field number for the "kernel_tracker_max_pending" field. 
+ public const int KernelTrackerMaxPendingFieldNumber = 9; + private int kernelTrackerMaxPending_; + /// + /// If kernel_tracker_max_pending > 0 then no more than this many + /// tracking events can be outstanding at a time. An attempt to + /// launch an additional kernel will stall until an event + /// completes. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int KernelTrackerMaxPending { + get { return kernelTrackerMaxPending_; } + set { + kernelTrackerMaxPending_ = value; + } + } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public override bool Equals(object other) { return Equals(other as Experimental); @@ -720,6 +808,10 @@ namespace Tensorflow { if (UseUnifiedMemory != other.UseUnifiedMemory) return false; if (NumDevToDevCopyStreams != other.NumDevToDevCopyStreams) return false; if (CollectiveRingOrder != other.CollectiveRingOrder) return false; + if (TimestampedAllocator != other.TimestampedAllocator) return false; + if (KernelTrackerMaxInterval != other.KernelTrackerMaxInterval) return false; + if (KernelTrackerMaxBytes != other.KernelTrackerMaxBytes) return false; + if (KernelTrackerMaxPending != other.KernelTrackerMaxPending) return false; return Equals(_unknownFields, other._unknownFields); } @@ -730,6 +822,10 @@ namespace Tensorflow { if (UseUnifiedMemory != false) hash ^= UseUnifiedMemory.GetHashCode(); if (NumDevToDevCopyStreams != 0) hash ^= NumDevToDevCopyStreams.GetHashCode(); if (CollectiveRingOrder.Length != 0) hash ^= CollectiveRingOrder.GetHashCode(); + if (TimestampedAllocator != false) hash ^= TimestampedAllocator.GetHashCode(); + if (KernelTrackerMaxInterval != 0) hash ^= KernelTrackerMaxInterval.GetHashCode(); + if (KernelTrackerMaxBytes != 0) hash ^= KernelTrackerMaxBytes.GetHashCode(); + if (KernelTrackerMaxPending != 0) hash ^= KernelTrackerMaxPending.GetHashCode(); if (_unknownFields != null) { hash ^= _unknownFields.GetHashCode(); } @@ -756,6 +852,22 @@ namespace Tensorflow { 
output.WriteRawTag(34); output.WriteString(CollectiveRingOrder); } + if (TimestampedAllocator != false) { + output.WriteRawTag(40); + output.WriteBool(TimestampedAllocator); + } + if (KernelTrackerMaxInterval != 0) { + output.WriteRawTag(56); + output.WriteInt32(KernelTrackerMaxInterval); + } + if (KernelTrackerMaxBytes != 0) { + output.WriteRawTag(64); + output.WriteInt32(KernelTrackerMaxBytes); + } + if (KernelTrackerMaxPending != 0) { + output.WriteRawTag(72); + output.WriteInt32(KernelTrackerMaxPending); + } if (_unknownFields != null) { _unknownFields.WriteTo(output); } @@ -774,6 +886,18 @@ namespace Tensorflow { if (CollectiveRingOrder.Length != 0) { size += 1 + pb::CodedOutputStream.ComputeStringSize(CollectiveRingOrder); } + if (TimestampedAllocator != false) { + size += 1 + 1; + } + if (KernelTrackerMaxInterval != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(KernelTrackerMaxInterval); + } + if (KernelTrackerMaxBytes != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(KernelTrackerMaxBytes); + } + if (KernelTrackerMaxPending != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(KernelTrackerMaxPending); + } if (_unknownFields != null) { size += _unknownFields.CalculateSize(); } @@ -795,6 +919,18 @@ namespace Tensorflow { if (other.CollectiveRingOrder.Length != 0) { CollectiveRingOrder = other.CollectiveRingOrder; } + if (other.TimestampedAllocator != false) { + TimestampedAllocator = other.TimestampedAllocator; + } + if (other.KernelTrackerMaxInterval != 0) { + KernelTrackerMaxInterval = other.KernelTrackerMaxInterval; + } + if (other.KernelTrackerMaxBytes != 0) { + KernelTrackerMaxBytes = other.KernelTrackerMaxBytes; + } + if (other.KernelTrackerMaxPending != 0) { + KernelTrackerMaxPending = other.KernelTrackerMaxPending; + } _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); } @@ -822,6 +958,22 @@ namespace Tensorflow { CollectiveRingOrder = input.ReadString(); break; } + case 40: { + 
TimestampedAllocator = input.ReadBool(); + break; + } + case 56: { + KernelTrackerMaxInterval = input.ReadInt32(); + break; + } + case 64: { + KernelTrackerMaxBytes = input.ReadInt32(); + break; + } + case 72: { + KernelTrackerMaxPending = input.ReadInt32(); + break; + } } } } @@ -2189,6 +2341,7 @@ namespace Tensorflow { /// inter_op_parallelism_threads available in each process. /// /// 0 means the system picks an appropriate number. + /// Negative means all operations are performed in caller's thread. /// /// Note that the first Session created in the process sets the /// number of threads for all future sessions unless use_per_session_threads is @@ -2397,7 +2550,8 @@ namespace Tensorflow { private bool isolateSessionState_; /// /// If true, any resources such as Variables used in the session will not be - /// shared with other sessions. + /// shared with other sessions. However, when clusterspec propagation is + /// enabled, this field is ignored and sessions are always isolated. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public bool IsolateSessionState { @@ -2787,6 +2941,11 @@ namespace Tensorflow { executorType_ = other.executorType_; recvBufMaxChunk_ = other.recvBufMaxChunk_; useNumaAffinity_ = other.useNumaAffinity_; + collectiveDeterministicSequentialExecution_ = other.collectiveDeterministicSequentialExecution_; + collectiveNccl_ = other.collectiveNccl_; + shareSessionStateInClusterspecPropagation_ = other.shareSessionStateInClusterspecPropagation_; + disableThreadSpinning_ = other.disableThreadSpinning_; + shareClusterDevicesInSession_ = other.shareClusterDevicesInSession_; _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); } @@ -2856,6 +3015,103 @@ namespace Tensorflow { } } + /// Field number for the "collective_deterministic_sequential_execution" field. 
+ public const int CollectiveDeterministicSequentialExecutionFieldNumber = 6; + private bool collectiveDeterministicSequentialExecution_; + /// + /// If true, make collective op execution order sequential and deterministic + /// for potentially concurrent collective instances. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool CollectiveDeterministicSequentialExecution { + get { return collectiveDeterministicSequentialExecution_; } + set { + collectiveDeterministicSequentialExecution_ = value; + } + } + + /// Field number for the "collective_nccl" field. + public const int CollectiveNcclFieldNumber = 7; + private bool collectiveNccl_; + /// + /// If true, use NCCL for CollectiveOps. This feature is highly + /// experimental. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool CollectiveNccl { + get { return collectiveNccl_; } + set { + collectiveNccl_ = value; + } + } + + /// Field number for the "share_session_state_in_clusterspec_propagation" field. + public const int ShareSessionStateInClusterspecPropagationFieldNumber = 8; + private bool shareSessionStateInClusterspecPropagation_; + /// + /// In the following, session state means the value of a variable, elements + /// in a hash table, or any other resource, accessible by worker sessions + /// held by a TF server. + /// + /// When ClusterSpec propagation is enabled, the value of + /// isolate_session_state is ignored when deciding whether to share session + /// states in a TF server (for backwards compatibility reasons). + /// - If share_session_state_in_clusterspec_propagation is true, the session + /// states are shared. + /// - If share_session_state_in_clusterspec_propagation is false, session + /// states are isolated. + /// + /// When clusterspec propagation is not used, the value of + /// share_session_state_in_clusterspec_propagation is ignored when deciding + /// whether to share session states in a TF server. 
+ /// - If isolate_session_state is true, session states are isolated. + /// - If isolate_session_state is false, session states are shared. + /// + /// TODO(b/129330037): Add a single API that consistently treats + /// isolate_session_state and ClusterSpec propagation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool ShareSessionStateInClusterspecPropagation { + get { return shareSessionStateInClusterspecPropagation_; } + set { + shareSessionStateInClusterspecPropagation_ = value; + } + } + + /// Field number for the "disable_thread_spinning" field. + public const int DisableThreadSpinningFieldNumber = 9; + private bool disableThreadSpinning_; + /// + /// If using a direct session, disable spinning while waiting for work in + /// the thread pool. This may result in higher latency for completing ops, + /// but in the case where there is a lot of spinning may result in lower + /// CPU usage. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool DisableThreadSpinning { + get { return disableThreadSpinning_; } + set { + disableThreadSpinning_ = value; + } + } + + /// Field number for the "share_cluster_devices_in_session" field. + public const int ShareClusterDevicesInSessionFieldNumber = 10; + private bool shareClusterDevicesInSession_; + /// + /// When true, WorkerSessions are created with device attributes from the + /// full cluster. + /// This is helpful when a worker wants to partition a graph + /// (for example during a PartitionedCallOp). 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool ShareClusterDevicesInSession { + get { return shareClusterDevicesInSession_; } + set { + shareClusterDevicesInSession_ = value; + } + } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public override bool Equals(object other) { return Equals(other as Experimental); @@ -2873,6 +3129,11 @@ namespace Tensorflow { if (ExecutorType != other.ExecutorType) return false; if (RecvBufMaxChunk != other.RecvBufMaxChunk) return false; if (UseNumaAffinity != other.UseNumaAffinity) return false; + if (CollectiveDeterministicSequentialExecution != other.CollectiveDeterministicSequentialExecution) return false; + if (CollectiveNccl != other.CollectiveNccl) return false; + if (ShareSessionStateInClusterspecPropagation != other.ShareSessionStateInClusterspecPropagation) return false; + if (DisableThreadSpinning != other.DisableThreadSpinning) return false; + if (ShareClusterDevicesInSession != other.ShareClusterDevicesInSession) return false; return Equals(_unknownFields, other._unknownFields); } @@ -2883,6 +3144,11 @@ namespace Tensorflow { if (ExecutorType.Length != 0) hash ^= ExecutorType.GetHashCode(); if (RecvBufMaxChunk != 0) hash ^= RecvBufMaxChunk.GetHashCode(); if (UseNumaAffinity != false) hash ^= UseNumaAffinity.GetHashCode(); + if (CollectiveDeterministicSequentialExecution != false) hash ^= CollectiveDeterministicSequentialExecution.GetHashCode(); + if (CollectiveNccl != false) hash ^= CollectiveNccl.GetHashCode(); + if (ShareSessionStateInClusterspecPropagation != false) hash ^= ShareSessionStateInClusterspecPropagation.GetHashCode(); + if (DisableThreadSpinning != false) hash ^= DisableThreadSpinning.GetHashCode(); + if (ShareClusterDevicesInSession != false) hash ^= ShareClusterDevicesInSession.GetHashCode(); if (_unknownFields != null) { hash ^= _unknownFields.GetHashCode(); } @@ -2912,6 +3178,26 @@ namespace Tensorflow { output.WriteRawTag(40); 
output.WriteBool(UseNumaAffinity); } + if (CollectiveDeterministicSequentialExecution != false) { + output.WriteRawTag(48); + output.WriteBool(CollectiveDeterministicSequentialExecution); + } + if (CollectiveNccl != false) { + output.WriteRawTag(56); + output.WriteBool(CollectiveNccl); + } + if (ShareSessionStateInClusterspecPropagation != false) { + output.WriteRawTag(64); + output.WriteBool(ShareSessionStateInClusterspecPropagation); + } + if (DisableThreadSpinning != false) { + output.WriteRawTag(72); + output.WriteBool(DisableThreadSpinning); + } + if (ShareClusterDevicesInSession != false) { + output.WriteRawTag(80); + output.WriteBool(ShareClusterDevicesInSession); + } if (_unknownFields != null) { _unknownFields.WriteTo(output); } @@ -2932,6 +3218,21 @@ namespace Tensorflow { if (UseNumaAffinity != false) { size += 1 + 1; } + if (CollectiveDeterministicSequentialExecution != false) { + size += 1 + 1; + } + if (CollectiveNccl != false) { + size += 1 + 1; + } + if (ShareSessionStateInClusterspecPropagation != false) { + size += 1 + 1; + } + if (DisableThreadSpinning != false) { + size += 1 + 1; + } + if (ShareClusterDevicesInSession != false) { + size += 1 + 1; + } if (_unknownFields != null) { size += _unknownFields.CalculateSize(); } @@ -2955,6 +3256,21 @@ namespace Tensorflow { if (other.UseNumaAffinity != false) { UseNumaAffinity = other.UseNumaAffinity; } + if (other.CollectiveDeterministicSequentialExecution != false) { + CollectiveDeterministicSequentialExecution = other.CollectiveDeterministicSequentialExecution; + } + if (other.CollectiveNccl != false) { + CollectiveNccl = other.CollectiveNccl; + } + if (other.ShareSessionStateInClusterspecPropagation != false) { + ShareSessionStateInClusterspecPropagation = other.ShareSessionStateInClusterspecPropagation; + } + if (other.DisableThreadSpinning != false) { + DisableThreadSpinning = other.DisableThreadSpinning; + } + if (other.ShareClusterDevicesInSession != false) { + ShareClusterDevicesInSession = 
other.ShareClusterDevicesInSession; + } _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); } @@ -2982,6 +3298,26 @@ namespace Tensorflow { UseNumaAffinity = input.ReadBool(); break; } + case 48: { + CollectiveDeterministicSequentialExecution = input.ReadBool(); + break; + } + case 56: { + CollectiveNccl = input.ReadBool(); + break; + } + case 64: { + ShareSessionStateInClusterspecPropagation = input.ReadBool(); + break; + } + case 72: { + DisableThreadSpinning = input.ReadBool(); + break; + } + case 80: { + ShareClusterDevicesInSession = input.ReadBool(); + break; + } } } } @@ -3553,6 +3889,7 @@ namespace Tensorflow { stepStats_ = other.stepStats_ != null ? other.stepStats_.Clone() : null; costGraph_ = other.costGraph_ != null ? other.costGraph_.Clone() : null; partitionGraphs_ = other.partitionGraphs_.Clone(); + functionGraphs_ = other.functionGraphs_.Clone(); _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); } @@ -3604,6 +3941,28 @@ namespace Tensorflow { get { return partitionGraphs_; } } + /// Field number for the "function_graphs" field. + public const int FunctionGraphsFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_functionGraphs_codec + = pb::FieldCodec.ForMessage(34, global::Tensorflow.RunMetadata.Types.FunctionGraphs.Parser); + private readonly pbc::RepeatedField functionGraphs_ = new pbc::RepeatedField(); + /// + /// This is only populated for graphs that are run as functions in TensorFlow + /// V2. There will be an entry below for each function that is traced. + /// The main use cases of the post_optimization_graph and the partition_graphs + /// is to give the caller insight into the graphs that were actually run by the + /// runtime. Additional information (such as those in step_stats) will match + /// these graphs. 
+ /// We also include the pre_optimization_graph since it is usually easier to + /// read, and is helpful in situations where the caller wants to get a high + /// level idea of what the built graph looks like (since the various graph + /// optimization passes might change the structure of the graph significantly). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField FunctionGraphs { + get { return functionGraphs_; } + } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public override bool Equals(object other) { return Equals(other as RunMetadata); @@ -3620,6 +3979,7 @@ namespace Tensorflow { if (!object.Equals(StepStats, other.StepStats)) return false; if (!object.Equals(CostGraph, other.CostGraph)) return false; if(!partitionGraphs_.Equals(other.partitionGraphs_)) return false; + if(!functionGraphs_.Equals(other.functionGraphs_)) return false; return Equals(_unknownFields, other._unknownFields); } @@ -3629,6 +3989,7 @@ namespace Tensorflow { if (stepStats_ != null) hash ^= StepStats.GetHashCode(); if (costGraph_ != null) hash ^= CostGraph.GetHashCode(); hash ^= partitionGraphs_.GetHashCode(); + hash ^= functionGraphs_.GetHashCode(); if (_unknownFields != null) { hash ^= _unknownFields.GetHashCode(); } @@ -3651,6 +4012,7 @@ namespace Tensorflow { output.WriteMessage(CostGraph); } partitionGraphs_.WriteTo(output, _repeated_partitionGraphs_codec); + functionGraphs_.WriteTo(output, _repeated_functionGraphs_codec); if (_unknownFields != null) { _unknownFields.WriteTo(output); } @@ -3666,6 +4028,7 @@ namespace Tensorflow { size += 1 + pb::CodedOutputStream.ComputeMessageSize(CostGraph); } size += partitionGraphs_.CalculateSize(_repeated_partitionGraphs_codec); + size += functionGraphs_.CalculateSize(_repeated_functionGraphs_codec); if (_unknownFields != null) { size += _unknownFields.CalculateSize(); } @@ -3690,6 +4053,7 @@ namespace Tensorflow { CostGraph.MergeFrom(other.CostGraph); } 
partitionGraphs_.Add(other.partitionGraphs_); + functionGraphs_.Add(other.functionGraphs_); _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); } @@ -3719,9 +4083,212 @@ namespace Tensorflow { partitionGraphs_.AddEntriesFrom(input, _repeated_partitionGraphs_codec); break; } + case 34: { + functionGraphs_.AddEntriesFrom(input, _repeated_functionGraphs_codec); + break; + } + } + } + } + + #region Nested types + /// Container for nested types declared in the RunMetadata message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static partial class Types { + public sealed partial class FunctionGraphs : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new FunctionGraphs()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.RunMetadata.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FunctionGraphs() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FunctionGraphs(FunctionGraphs other) : this() { + partitionGraphs_ = other.partitionGraphs_.Clone(); + preOptimizationGraph_ = other.preOptimizationGraph_ != null ? other.preOptimizationGraph_.Clone() : null; + postOptimizationGraph_ = other.postOptimizationGraph_ != null ? 
other.postOptimizationGraph_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FunctionGraphs Clone() { + return new FunctionGraphs(this); + } + + /// Field number for the "partition_graphs" field. + public const int PartitionGraphsFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_partitionGraphs_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.GraphDef.Parser); + private readonly pbc::RepeatedField partitionGraphs_ = new pbc::RepeatedField(); + /// + /// TODO(nareshmodi): Include some sort of function/cache-key identifier? + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField PartitionGraphs { + get { return partitionGraphs_; } + } + + /// Field number for the "pre_optimization_graph" field. + public const int PreOptimizationGraphFieldNumber = 2; + private global::Tensorflow.GraphDef preOptimizationGraph_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.GraphDef PreOptimizationGraph { + get { return preOptimizationGraph_; } + set { + preOptimizationGraph_ = value; + } + } + + /// Field number for the "post_optimization_graph" field. 
+ public const int PostOptimizationGraphFieldNumber = 3; + private global::Tensorflow.GraphDef postOptimizationGraph_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.GraphDef PostOptimizationGraph { + get { return postOptimizationGraph_; } + set { + postOptimizationGraph_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as FunctionGraphs); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(FunctionGraphs other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!partitionGraphs_.Equals(other.partitionGraphs_)) return false; + if (!object.Equals(PreOptimizationGraph, other.PreOptimizationGraph)) return false; + if (!object.Equals(PostOptimizationGraph, other.PostOptimizationGraph)) return false; + return Equals(_unknownFields, other._unknownFields); } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= partitionGraphs_.GetHashCode(); + if (preOptimizationGraph_ != null) hash ^= PreOptimizationGraph.GetHashCode(); + if (postOptimizationGraph_ != null) hash ^= PostOptimizationGraph.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + partitionGraphs_.WriteTo(output, _repeated_partitionGraphs_codec); + if (preOptimizationGraph_ != null) { + output.WriteRawTag(18); + output.WriteMessage(PreOptimizationGraph); + } + if (postOptimizationGraph_ != null) { + output.WriteRawTag(26); + output.WriteMessage(PostOptimizationGraph); + } + 
if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += partitionGraphs_.CalculateSize(_repeated_partitionGraphs_codec); + if (preOptimizationGraph_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(PreOptimizationGraph); + } + if (postOptimizationGraph_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(PostOptimizationGraph); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(FunctionGraphs other) { + if (other == null) { + return; + } + partitionGraphs_.Add(other.partitionGraphs_); + if (other.preOptimizationGraph_ != null) { + if (preOptimizationGraph_ == null) { + preOptimizationGraph_ = new global::Tensorflow.GraphDef(); + } + PreOptimizationGraph.MergeFrom(other.PreOptimizationGraph); + } + if (other.postOptimizationGraph_ != null) { + if (postOptimizationGraph_ == null) { + postOptimizationGraph_ = new global::Tensorflow.GraphDef(); + } + PostOptimizationGraph.MergeFrom(other.PostOptimizationGraph); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + partitionGraphs_.AddEntriesFrom(input, _repeated_partitionGraphs_codec); + break; + } + case 18: { + if (preOptimizationGraph_ == null) { + preOptimizationGraph_ = new global::Tensorflow.GraphDef(); + } + input.ReadMessage(preOptimizationGraph_); + break; + } + case 26: { + if (postOptimizationGraph_ == null) { + postOptimizationGraph_ = new global::Tensorflow.GraphDef(); + } + 
input.ReadMessage(postOptimizationGraph_); + break; + } + } + } + } + } + } + #endregion } From ed9a8c88a5af3ff9e0d11cb366fc743f492e1d3e Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Mon, 21 Oct 2019 06:04:48 -0500 Subject: [PATCH 07/41] change WhileContext maximum_iterations to Tensor. --- .../Operations/ControlFlows/WhileContext.cs | 167 ++++++++++-------- 1 file changed, 93 insertions(+), 74 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs index 1faaa647..c00fc2c7 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs @@ -42,7 +42,7 @@ namespace Tensorflow.Operations public override GradLoopState grad_state => _grad_state; public override bool back_prop => _back_prop; - public WhileContext(int? maximum_iterations = null, + public WhileContext(Tensor maximum_iterations = null, int parallel_iterations = 10, bool back_prop = true, bool swap_memory = false, @@ -64,7 +64,7 @@ namespace Tensorflow.Operations _grad_state = grad_state; } - private void _init_from_args(int? maximum_iterations, + private void _init_from_args(Tensor maximum_iterations, int parallel_iterations, bool back_prop, bool swap_memory, @@ -107,9 +107,9 @@ namespace Tensorflow.Operations /// /// Add the loop termination condition and body to the graph. 
/// - public Tensor[] BuildLoop(Func pred, - Func body, - Tensor[] loop_vars, + internal Tensor[] BuildLoop(Func pred, + Func> body, + TItem loop_vars, TensorShape shape_invariants, bool return_same_structure) { @@ -131,88 +131,107 @@ namespace Tensorflow.Operations return packed_exit_vars as Tensor[]; } - private (Tensor[], Tensor[]) _BuildLoop(Func pred, - Func body, - Tensor[] original_loop_vars, - Tensor[] loop_vars, + private Tensor _convert_tensorarray_to_flow(TItem tensor_or_tensor_array) + { + if (tensor_or_tensor_array is TensorArray tensor_array) + return tensor_array.flow; + else if (tensor_or_tensor_array is Tensor tensor) + return tensor; + + throw new NotImplementedException("_convert_tensorarray_to_flow"); + } + + private (Tensor[], Tensor[]) _BuildLoop(Func pred, + Func> body, + TItem original_loop_vars, + TItem loop_vars, TensorShape shape_invariants) { var flat_loop_vars = original_loop_vars; + // Convert TensorArrays to their flow variables + var loop_vars_tensor = nest.map_structure( + _convert_tensorarray_to_flow, + nest.flatten(loop_vars)); + // Let the context know the loop variables so the loop variables // would be added in the outer contexts properly. 
- _InitializeValues(loop_vars); - var real_vars = loop_vars; - Tensor[] enter_vars = null; - tf_with(ops.control_dependencies(null), delegate + if (loop_vars is Tensor[] real_vars) { - enter_vars = real_vars.Select(x => _Enter(x, - _name, - is_constant: false, - parallel_iterations: _parallel_iterations, - use_input_shape: shape_invariants == null)) - .ToArray(); - - foreach(var x in enter_vars) + _InitializeValues(real_vars); + Tensor[] enter_vars = null; + tf_with(ops.control_dependencies(null), delegate + { + enter_vars = real_vars.Select(x => _Enter(x, + _name, + is_constant: false, + parallel_iterations: _parallel_iterations, + use_input_shape: shape_invariants == null)) + .ToArray(); + + foreach (var x in enter_vars) + { + x.graph.prevent_feeding(x); + if (_outer_context != null) + _outer_context.AddInnerOp(x.op); + } + }); + + // Finds the closest enclosing non-None control pivot. + var outer_context = _outer_context; + while (outer_context != null) { - x.graph.prevent_feeding(x); - if (_outer_context != null) - _outer_context.AddInnerOp(x.op); + } - }); - // Finds the closest enclosing non-None control pivot. - var outer_context = _outer_context; - while (outer_context != null) - { + _SetShapeInvariants(real_vars, enter_vars, shape_invariants); + + // Fix the control inputs and control flow context of these enter ops. + _FixControlInputsAndContext(enter_vars); + _InitializeValues(enter_vars); + _loop_enters = enter_vars.ToList(); + + var merge_vars = enter_vars + .Select(x => merge(new[] { x, x })) + .ToArray(); + + _pivot_for_pred = merge_vars[0]; + + // Build the graph for pred. 
+ var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars); + // var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays); + var c = ops.convert_to_tensor(pred(merge_vars_with_tensor_arrays[0], default(TItem))); + _pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond"); + var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot)) + .ToArray(); + // Build the graph for body. + var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray(); + // Convert TensorArray flow variables inside the context back into + // their associated TensorArrays for calling the body. + var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); + /*var body_result = body(packed_vars_for_body[0]); + var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); + + // Store body_result to keep track of TensorArrays returned by body + var original_body_result = new[] { body_result }; + // Convert TensorArrays returned by body into their flow variables + var result = new[] { body_result }; + + var next_vars = new List(); + foreach (var (m, v) in zip(merge_vars, result)) + next_vars.Add(_AddNextAndBackEdge(m, v)); + + // Add the exit ops. + var exit_vars = switch_vars.Select(x => exit(x[0])).ToList(); + _loop_exits = exit_vars; + + // Exit the loop. + // ExitResult(exit_vars); + return (original_body_result, exit_vars.ToArray());*/ } - _SetShapeInvariants(real_vars, enter_vars, shape_invariants); - - // Fix the control inputs and control flow context of these enter ops. - _FixControlInputsAndContext(enter_vars); - _InitializeValues(enter_vars); - _loop_enters = enter_vars.ToList(); - - var merge_vars = enter_vars - .Select(x => merge(new[] { x, x })) - .ToArray(); - - _pivot_for_pred = merge_vars[0]; - - // Build the graph for pred. 
- var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars); - // var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays); - var c = ops.convert_to_tensor(pred(merge_vars_with_tensor_arrays[0])); - _pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond"); - var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot)) - .ToArray(); - - // Build the graph for body. - var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray(); - // Convert TensorArray flow variables inside the context back into - // their associated TensorArrays for calling the body. - var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); - var body_result = body(packed_vars_for_body[0]); - var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); - - // Store body_result to keep track of TensorArrays returned by body - var original_body_result = new[] { body_result }; - // Convert TensorArrays returned by body into their flow variables - var result = new[] { body_result }; - - var next_vars = new List(); - foreach (var (m, v) in zip(merge_vars, result)) - next_vars.Add(_AddNextAndBackEdge(m, v)); - - // Add the exit ops. - var exit_vars = switch_vars.Select(x => exit(x[0])).ToList(); - _loop_exits = exit_vars; - - // Exit the loop. 
- // ExitResult(exit_vars); - return (original_body_result, exit_vars.ToArray()); + throw new NotImplementedException(""); } private void _FixControlInputsAndContext(Tensor[] enters) From 07f70f9425a5756345498ef79f2e048e5fdf6d05 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Mon, 21 Oct 2019 06:05:17 -0500 Subject: [PATCH 08/41] WhileContext BuildLoop --- .../APIs/tf.control_flow.cs | 4 +-- .../Operations/ControlFlows/LoopVar.cs | 25 +++++++++++++++ .../NnOps/BodyItemInRnnWhileLoop.cs | 32 +++++++++++++++++++ .../Operations/NnOps/rnn.cs | 22 +++++++++---- .../Operations/TensorArray.cs | 2 +- .../Operations/_GraphTensorArray.cs | 6 ++-- .../Operations/control_flow_ops.cs | 32 ++++++++++++++++--- .../Operations/gen_data_flow_ops.cs | 7 ++-- src/TensorFlowNET.Core/Util/nest.py.cs | 1 - src/TensorFlowNET.Core/tensorflow.cs | 8 ++--- test/TensorFlowNET.UnitTest/CSession.cs | 5 ++- .../WhileContextTestCase.cs | 13 ++------ 12 files changed, 117 insertions(+), 40 deletions(-) create mode 100644 src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs create mode 100644 src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.control_flow.cs b/src/TensorFlowNET.Core/APIs/tf.control_flow.cs index 6ed475a9..b2b5574a 100644 --- a/src/TensorFlowNET.Core/APIs/tf.control_flow.cs +++ b/src/TensorFlowNET.Core/APIs/tf.control_flow.cs @@ -37,7 +37,7 @@ namespace Tensorflow public Operation group(T[] inputs, string name = null) where T : ITensorOrOperation => control_flow_ops.group(inputs, name: name); - public Tensor while_loop(Func cond, Func body, Tensor[] loop_vars, + /*public Tensor while_loop(Func cond, Func body, Tensor[] loop_vars, TensorShape shape_invariants = null, int parallel_iterations = 10, bool back_prop = true, @@ -52,7 +52,7 @@ namespace Tensorflow swap_memory: swap_memory, name: name, maximum_iterations: maximum_iterations, - return_same_structure: return_same_structure); + return_same_structure: 
return_same_structure);*/ public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) => ops.control_dependencies(control_inputs); diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs new file mode 100644 index 00000000..fa2fe9d3 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs @@ -0,0 +1,25 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Operations +{ + internal class LoopVar + { + public Tensor Counter { get; } + public TItem[] Items { get; } + public TItem Item { get; } + + public LoopVar(Tensor counter, TItem[] items) + { + Counter = counter; + Items = items; + } + + public LoopVar(Tensor counter, TItem item) + { + Counter = counter; + Item = item; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs b/src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs new file mode 100644 index 00000000..f0086793 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs @@ -0,0 +1,32 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Operations +{ + internal class BodyItemInRnnWhileLoop + { + /// + /// int32 scalar Tensor. + /// + public Tensor time { get; set; } + /// + /// List of `TensorArray`s that represent the output. + /// + public TensorArray[] output_ta_t { get; set; } + /// + /// nested tuple of vector tensors that represent the state. 
+ /// + public Tensor state { get; set; } + + public BodyItemInRnnWhileLoop(Tensor time, TensorArray[] output_ta_t, Tensor state) + { + this.time = time; + this.output_ta_t = output_ta_t; + this.state = state; + } + + public static implicit operator (Tensor, TensorArray[], Tensor)(BodyItemInRnnWhileLoop item) + => (item.time, item.output_ta_t, item.state); + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs index 8e7425e5..e058c077 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs @@ -145,7 +145,7 @@ namespace Tensorflow.Operations { var ta = new TensorArray(dtype: dtype_, size: time_steps, - element_shape: new[] { element_shape }, + element_shape: element_shape, tensor_array_name: base_name + name); return ta; }; @@ -178,19 +178,29 @@ namespace Tensorflow.Operations // Make sure that we run at least 1 step, if necessary, to ensure // the TensorArrays pick up the dynamic shape. - Tensor loop_bound; + Tensor loop_bound = null; if (in_graph_mode) loop_bound = math_ops.minimum( time_steps, math_ops.maximum(1, max_sequence_length)); - /*Func cond = (ctime) => + Func cond = (item) => { - return null; + return time < loop_bound; }; - control_flow_ops.while_loop( + // Take a time step of the dynamic RNN. 
+ Func _time_step = (item) => + { + return item; + }; + + control_flow_ops.while_loop( cond: cond, - body = );*/ + body: _time_step, + loop_vars: new BodyItemInRnnWhileLoop(time, output_ta.ToArray(), state), + parallel_iterations: parallel_iterations, + maximum_iterations: time_steps, + swap_memory: swap_memory); throw new NotImplementedException(""); } diff --git a/src/TensorFlowNET.Core/Operations/TensorArray.cs b/src/TensorFlowNET.Core/Operations/TensorArray.cs index 7251bf85..60e1bde5 100644 --- a/src/TensorFlowNET.Core/Operations/TensorArray.cs +++ b/src/TensorFlowNET.Core/Operations/TensorArray.cs @@ -39,7 +39,7 @@ namespace Tensorflow.Operations public TensorArray(TF_DataType dtype, Tensor size = default, bool? clear_after_read = null, bool? dynamic_size = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, - bool infer_shape = true, TensorShape[] element_shape = null, + bool infer_shape = true, TensorShape element_shape = null, bool colocate_with_first_write_call = true, string name = null) { _implementation = new _GraphTensorArray(dtype, diff --git a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs index bd919ad8..5a667560 100644 --- a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs +++ b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs @@ -44,7 +44,7 @@ namespace Tensorflow.Operations public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null, bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, - bool infer_shape = true, TensorShape[] element_shape = null, + bool infer_shape = true, TensorShape element_shape = null, bool colocate_with_first_write_call = true, string name = null) { clear_after_read = clear_after_read ?? 
true; @@ -68,7 +68,7 @@ namespace Tensorflow.Operations else { _infer_shape = true; - _element_shape = new List { }; + _element_shape = new List { element_shape }; } tf_with(ops.name_scope(name, "TensorArray", new { handle, size, flow }), scope => @@ -135,7 +135,7 @@ namespace Tensorflow.Operations var ta = new TensorArray(_dtype, infer_shape:_infer_shape, - element_shape: _element_shape.ToArray(), + element_shape: _element_shape[0], dynamic_size: _dynamic_size, handle: _handle, flow: flow_out, diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs index e8b5f0eb..27e43153 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs @@ -485,7 +485,7 @@ namespace Tensorflow }); } - public static Tensor[] _convert_flows_to_tensorarrays(T[] tensors_or_tensorarrays, Tensor[] tensors_or_flows) + public static Tensor[] _convert_flows_to_tensorarrays(T tensors_or_tensorarrays, Tensor[] tensors_or_flows) { // zip(tensors_or_tensorarrays, tensors_or_flows).Select((ta, t_or_flow) => ta).ToArray(); return tensors_or_flows; @@ -591,18 +591,18 @@ namespace Tensorflow /// /// /// - public static Tensor while_loop(Func cond, Func body, Tensor[] loop_vars, + public static Tensor while_loop(Func cond, Func body, TItem loop_vars, TensorShape shape_invariants = null, int parallel_iterations = 10, bool back_prop = true, bool swap_memory = false, string name = null, - int? 
maximum_iterations = null, + Tensor maximum_iterations = null, bool return_same_structure = false) { tf_with(ops.name_scope(name, "while", loop_vars), scope => { - if (loop_vars == null || loop_vars.Length == 0) + if (loop_vars == null) throw new ValueError("No loop variables provided"); if (cond == null) throw new ValueError("cond must be callable."); @@ -611,6 +611,28 @@ namespace Tensorflow if (parallel_iterations < 1) throw new ValueError("parallel_iterations must be a positive integer."); + var try_to_pack = loop_vars is Tensor && !return_same_structure; + var counter = constant_op.constant(0, dtype: maximum_iterations.dtype, name: "iteration_counter"); + var orig_cond = cond; + var orig_body = body; + + LoopVar loop_vars_1 = null; + Func> body_buildloop = null; + Func cond_buildloop = null; + + if (try_to_pack) + { + + } + else + { + loop_vars_1 = new LoopVar(counter, loop_vars); + cond_buildloop = (i, lv) => + math_ops.logical_and(i < maximum_iterations, orig_cond(lv)); + body_buildloop = (i, lv) => new LoopVar(i + 1, orig_body(lv)); + } + try_to_pack = false; + var loop_context = new WhileContext( maximum_iterations: maximum_iterations, parallel_iterations: parallel_iterations, @@ -620,7 +642,7 @@ namespace Tensorflow if (loop_context.outer_context == null) ops.add_to_collection(tf.GraphKeys.WHILE_CONTEXT, loop_context); - var results = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants, + var results = loop_context.BuildLoop(cond_buildloop, body_buildloop, loop_vars, shape_invariants, return_same_structure); if (maximum_iterations != null) diff --git a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs index fa194934..71e9bbab 100644 --- a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs @@ -28,12 +28,9 @@ namespace Tensorflow } public static (Tensor, Tensor) tensor_array_v3(T size, TF_DataType dtype = 
TF_DataType.DtInvalid, - TensorShape[] element_shape = null, bool dynamic_size = false, bool clear_after_read = true, - bool identical_element_shapes = false, string tensor_array_name = "tensor_array_name", string name = null) + TensorShape element_shape = null, bool dynamic_size = false, bool clear_after_read = true, + bool identical_element_shapes = false, string tensor_array_name = "", string name = null) { - if (tensor_array_name == null) - tensor_array_name = string.Empty; - var _op = _op_def_lib._apply_op_helper("TensorArrayV3", name, new { size, diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs index b3ae594f..54ff358a 100644 --- a/src/TensorFlowNET.Core/Util/nest.py.cs +++ b/src/TensorFlowNET.Core/Util/nest.py.cs @@ -223,7 +223,6 @@ namespace Tensorflow.Util private static void _flatten_recursive(T obj, List list) { - switch(obj) { case IDictionary dict: diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs index 39fd2ac9..a512663e 100644 --- a/src/TensorFlowNET.Core/tensorflow.cs +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -93,14 +93,14 @@ namespace Tensorflow return new Session().as_default(); } - public Session Session(Graph graph, SessionOptions opts = null) + public Session Session(Graph graph, ConfigProto config = null) { - return new Session(graph, opts: opts).as_default(); + return new Session(graph, config: config).as_default(); } - public Session Session(SessionOptions opts) + public Session Session(ConfigProto config) { - return new Session(null, opts).as_default(); + return new Session(null, config).as_default(); } public void __init__() diff --git a/test/TensorFlowNET.UnitTest/CSession.cs b/test/TensorFlowNET.UnitTest/CSession.cs index fa293288..e9ed7784 100644 --- a/test/TensorFlowNET.UnitTest/CSession.cs +++ b/test/TensorFlowNET.UnitTest/CSession.cs @@ -25,9 +25,8 @@ namespace TensorFlowNET.UnitTest { lock (Locks.ProcessWide) { - var opts = new SessionOptions(); - 
opts.SetConfig(new ConfigProto {InterOpParallelismThreads = 4}); - session_ = new Session(graph, opts, s); + var config = new ConfigProto {InterOpParallelismThreads = 4}; + session_ = new Session(graph, config, s); } } diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs b/test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs index 72dd83ea..80ff71db 100644 --- a/test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs +++ b/test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs @@ -18,10 +18,10 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test var i = constant_op.constant(0, name: "i"); var c = new Func(x => tf.less(x, 10, name: "c")); var b = new Func(x => tf.add(x, 1, name: "c")); - var r = control_flow_ops.while_loop(c, b, new[] { i }); + var r = control_flow_ops.while_loop(c, b, i); } - private void _testWhileContextHelper(int? maximum_iterations = null) + private void _testWhileContextHelper(int maximum_iterations) { // TODO: implement missing code dependencies using (var sess = this.cached_session()) @@ -30,7 +30,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test var c = new Func(x => gen_math_ops.less(x, 10, name: "c")); var b = new Func(x => gen_math_ops.add(x, 1, name: "c")); control_flow_ops.while_loop( - c, b, new[] { i }, maximum_iterations: maximum_iterations); + c, b, i , maximum_iterations: tf.constant(maximum_iterations)); foreach (Operation op in sess.graph.get_operations()) { var control_flow_context = op._get_control_flow_context(); @@ -42,13 +42,6 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test } } - [Ignore("TODO")] - [TestMethod] - public void testWhileContext() - { - _testWhileContextHelper(); - } - [Ignore("TODO")] [TestMethod] public void testWhileContextWithMaximumIterations() From f7cee4bd08ade3cf018e89ab3412a04cb19fa581 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Mon, 21 Oct 2019 16:51:46 -0500 Subject: [PATCH 09/41] 
make Tensor.buffer be public. --- src/TensorFlowNET.Core/TensorFlowNET.Core.csproj | 6 +++--- src/TensorFlowNET.Core/Tensors/Tensor.cs | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index f7b7d424..598969b7 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -5,7 +5,7 @@ TensorFlow.NET Tensorflow 1.14.0 - 0.11.8 + 0.11.8.1 Haiping Chen, Meinrad Recheis, Eli Belash SciSharp STACK true @@ -17,7 +17,7 @@ TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C# Google's TensorFlow full binding in .NET Standard. Docs: https://tensorflownet.readthedocs.io - 0.11.8.0 + 0.11.8.1 Changes since v0.10.0: 1. Upgrade NumSharp to v0.20.3. 2. Add DisposableObject class to manage object lifetime. @@ -34,7 +34,7 @@ Docs: https://tensorflownet.readthedocs.io 13. Return VariableV1 instead of RefVariable. 14. Add Tensor overload to GradientDescentOptimizer. 7.3 - 0.11.8.0 + 0.11.8.1 LICENSE true true diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 01290644..161696a1 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -97,7 +97,7 @@ namespace Tensorflow [JsonIgnore] #endif public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize; - private IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); + public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 
0 : c_api.TF_OperationOutputNumConsumers(oper_out); #if SERIALIZABLE [JsonIgnore] From 547c4e6bf48fceb8266d0cb88d3a13ec4171606a Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Mon, 21 Oct 2019 21:08:16 -0500 Subject: [PATCH 10/41] tf.while_loop, add ICanBeFlattened #348 --- .../Operations/ControlFlows/LoopVar.cs | 17 +++++++----- .../Operations/ControlFlows/WhileContext.cs | 8 +++--- src/TensorFlowNET.Core/Operations/IFlatten.cs | 11 ++++++++ .../NnOps/BodyItemInRnnWhileLoop.cs | 10 ++++++- .../Operations/control_flow_ops.cs | 2 +- .../TensorFlowNET.Core.csproj | 26 +++++-------------- src/TensorFlowNET.Core/Util/nest.py.cs | 6 +++++ 7 files changed, 48 insertions(+), 32 deletions(-) create mode 100644 src/TensorFlowNET.Core/Operations/IFlatten.cs diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs index fa2fe9d3..c313739b 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs @@ -4,22 +4,25 @@ using System.Text; namespace Tensorflow.Operations { - internal class LoopVar + internal class LoopVar : ICanBeFlattened { public Tensor Counter { get; } - public TItem[] Items { get; } public TItem Item { get; } - public LoopVar(Tensor counter, TItem[] items) + public LoopVar(Tensor counter, TItem item) { Counter = counter; - Items = items; + Item = item; } - public LoopVar(Tensor counter, TItem item) + public object[] Flatten() { - Counter = counter; - Item = item; + var elements = new List { Counter }; + if (typeof(TItem).GetInterface(typeof(ICanBeFlattened).Name) != null) + elements.AddRange((Item as ICanBeFlattened).Flatten()); + else + elements.Add(Item); + return elements.ToArray(); } } } diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs index c00fc2c7..462aca25 100644 --- 
a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs @@ -109,7 +109,7 @@ namespace Tensorflow.Operations /// internal Tensor[] BuildLoop(Func pred, Func> body, - TItem loop_vars, + LoopVar loop_vars, TensorShape shape_invariants, bool return_same_structure) { @@ -143,8 +143,8 @@ namespace Tensorflow.Operations private (Tensor[], Tensor[]) _BuildLoop(Func pred, Func> body, - TItem original_loop_vars, - TItem loop_vars, + LoopVar original_loop_vars, + LoopVar loop_vars, TensorShape shape_invariants) { var flat_loop_vars = original_loop_vars; @@ -152,7 +152,7 @@ namespace Tensorflow.Operations // Convert TensorArrays to their flow variables var loop_vars_tensor = nest.map_structure( _convert_tensorarray_to_flow, - nest.flatten(loop_vars)); + nest.flatten2(loop_vars)); // Let the context know the loop variables so the loop variables // would be added in the outer contexts properly. diff --git a/src/TensorFlowNET.Core/Operations/IFlatten.cs b/src/TensorFlowNET.Core/Operations/IFlatten.cs new file mode 100644 index 00000000..305dc72e --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/IFlatten.cs @@ -0,0 +1,11 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Operations +{ + public interface ICanBeFlattened + { + object[] Flatten(); + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs b/src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs index f0086793..9ffea25c 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs @@ -4,7 +4,7 @@ using System.Text; namespace Tensorflow.Operations { - internal class BodyItemInRnnWhileLoop + internal class BodyItemInRnnWhileLoop : ICanBeFlattened { /// /// int32 scalar Tensor. 
@@ -28,5 +28,13 @@ namespace Tensorflow.Operations public static implicit operator (Tensor, TensorArray[], Tensor)(BodyItemInRnnWhileLoop item) => (item.time, item.output_ta_t, item.state); + + public object[] Flatten() + { + var elements = new List { time }; + elements.AddRange(output_ta_t); + elements.Add(state); + return elements.ToArray(); + } } } diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs index 27e43153..181b7e71 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs @@ -642,7 +642,7 @@ namespace Tensorflow if (loop_context.outer_context == null) ops.add_to_collection(tf.GraphKeys.WHILE_CONTEXT, loop_context); - var results = loop_context.BuildLoop(cond_buildloop, body_buildloop, loop_vars, shape_invariants, + var results = loop_context.BuildLoop(cond_buildloop, body_buildloop, loop_vars_1, shape_invariants, return_same_structure); if (maximum_iterations != null) diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 598969b7..33bba3dc 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -5,7 +5,7 @@ TensorFlow.NET Tensorflow 1.14.0 - 0.11.8.1 + 0.12.0 Haiping Chen, Meinrad Recheis, Eli Belash SciSharp STACK true @@ -16,25 +16,13 @@ https://avatars3.githubusercontent.com/u/44989469?s=200&v=4 TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C# Google's TensorFlow full binding in .NET Standard. -Docs: https://tensorflownet.readthedocs.io - 0.11.8.1 - Changes since v0.10.0: -1. Upgrade NumSharp to v0.20.3. -2. Add DisposableObject class to manage object lifetime. -3. Add tf.no_op, tf.nn.in_top_k, tf.GraphKeys and tf.trainable_variables. -4. Change tensorflow to non-static class in order to execute some initialization process. -5. 
Overload session.run(), make syntax simpler. -6. Add Local Response Normalization. -7. Add tf.image related APIs. -8. Add tf.random_normal, tf.constant, tf.pad, tf.shape, tf.image.resize_nearest_neighbor. -9. MultiThread is safe. -10. Support n-dim indexing for tensor. -11. Add RegisterNoGradients -12. Add CumsumGrad, BroadcastToGrad. -13. Return VariableV1 instead of RefVariable. -14. Add Tensor overload to GradientDescentOptimizer. +Building, training and infering deep learning models. +https://tensorflownet.readthedocs.io + 0.12.0.0 + Changes since v0.11.0: + 7.3 - 0.11.8.1 + 0.12.0.0 LICENSE true true diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs index 54ff358a..9b0af4f6 100644 --- a/src/TensorFlowNET.Core/Util/nest.py.cs +++ b/src/TensorFlowNET.Core/Util/nest.py.cs @@ -19,6 +19,7 @@ using System.Collections; using System.Collections.Generic; using System.Linq; using NumSharp; +using Tensorflow.Operations; namespace Tensorflow.Util { @@ -221,6 +222,11 @@ namespace Tensorflow.Util return list; } + public static object[] flatten2(ICanBeFlattened structure) + { + return structure.Flatten(); + } + private static void _flatten_recursive(T obj, List list) { switch(obj) From 4193ac662bdb32e977b9c38e0c408ea54fc61d0d Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Thu, 24 Oct 2019 13:40:50 -0500 Subject: [PATCH 11/41] move interfaces to same folder --- .../{ => Interfaces}/IObjectLife.cs | 0 .../{ => Interfaces}/ITensorOrOperation.cs | 0 .../Interfaces/ITensorOrTensorArray.cs | 27 +++++++++++++++++++ 3 files changed, 27 insertions(+) rename src/TensorFlowNET.Core/{ => Interfaces}/IObjectLife.cs (100%) rename src/TensorFlowNET.Core/{ => Interfaces}/ITensorOrOperation.cs (100%) create mode 100644 src/TensorFlowNET.Core/Interfaces/ITensorOrTensorArray.cs diff --git a/src/TensorFlowNET.Core/IObjectLife.cs b/src/TensorFlowNET.Core/Interfaces/IObjectLife.cs similarity index 100% rename from src/TensorFlowNET.Core/IObjectLife.cs 
rename to src/TensorFlowNET.Core/Interfaces/IObjectLife.cs diff --git a/src/TensorFlowNET.Core/ITensorOrOperation.cs b/src/TensorFlowNET.Core/Interfaces/ITensorOrOperation.cs similarity index 100% rename from src/TensorFlowNET.Core/ITensorOrOperation.cs rename to src/TensorFlowNET.Core/Interfaces/ITensorOrOperation.cs diff --git a/src/TensorFlowNET.Core/Interfaces/ITensorOrTensorArray.cs b/src/TensorFlowNET.Core/Interfaces/ITensorOrTensorArray.cs new file mode 100644 index 00000000..a6f30ceb --- /dev/null +++ b/src/TensorFlowNET.Core/Interfaces/ITensorOrTensorArray.cs @@ -0,0 +1,27 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+******************************************************************************/ + +namespace Tensorflow +{ + /// + /// in order to limit function return value + /// is Tensor or TensorArray + /// + public interface ITensorOrTensorArray + { + + } +} From 381bf18cb406bc763fba5019ac0e01e67196bc9d Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Thu, 24 Oct 2019 13:41:18 -0500 Subject: [PATCH 12/41] move TensorArray to Tensors folder --- .../{Operations => Tensors}/TensorArray.cs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) rename src/TensorFlowNET.Core/{Operations => Tensors}/TensorArray.cs (94%) diff --git a/src/TensorFlowNET.Core/Operations/TensorArray.cs b/src/TensorFlowNET.Core/Tensors/TensorArray.cs similarity index 94% rename from src/TensorFlowNET.Core/Operations/TensorArray.cs rename to src/TensorFlowNET.Core/Tensors/TensorArray.cs index 60e1bde5..f84072a8 100644 --- a/src/TensorFlowNET.Core/Operations/TensorArray.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorArray.cs @@ -17,8 +17,9 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Operations; -namespace Tensorflow.Operations +namespace Tensorflow { /// /// TensorArray is designed to hide an underlying implementation object @@ -29,9 +30,9 @@ namespace Tensorflow.Operations /// `while_loop` and `map_fn`. It supports gradient back-propagation via special /// "flow" control flow dependencies. 
/// - public class TensorArray + public class TensorArray : ITensorOrTensorArray { - _GraphTensorArray _implementation; + internal _GraphTensorArray _implementation; public TF_DataType dtype => _implementation._dtype; public Tensor handle => _implementation._handle; From 61bc941ffcf55f16e1f1c2c6ddfe26bd108ca9ee Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Thu, 24 Oct 2019 13:43:46 -0500 Subject: [PATCH 13/41] add IHaveFlatten --- src/TensorFlowNET.Core/{Operations => Util}/IFlatten.cs | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/TensorFlowNET.Core/{Operations => Util}/IFlatten.cs (100%) diff --git a/src/TensorFlowNET.Core/Operations/IFlatten.cs b/src/TensorFlowNET.Core/Util/IFlatten.cs similarity index 100% rename from src/TensorFlowNET.Core/Operations/IFlatten.cs rename to src/TensorFlowNET.Core/Util/IFlatten.cs From 3425aa4a292da4ebc0a148c6fe8518e195203142 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Thu, 24 Oct 2019 13:44:04 -0500 Subject: [PATCH 14/41] LoopVar --- .../ControlFlows/ControlFlowContext.cs | 27 ++- .../Operations/ControlFlows/LoopVar.cs | 5 + .../Operations/ControlFlows/WhileContext.cs | 211 +++++++++++------- 3 files changed, 160 insertions(+), 83 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs index c076cbc7..8a624df2 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs @@ -19,6 +19,7 @@ using System.Collections.Generic; using System.Linq; using Tensorflow.Operations.ControlFlows; using static Tensorflow.ControlFlowContextDef; +using static Tensorflow.Binding; namespace Tensorflow.Operations { @@ -72,6 +73,7 @@ namespace Tensorflow.Operations public ControlFlowContext() { _context_stack = new Stack(); + _external_values = new Dictionary(); } public string name { get => _name; } @@ -180,6 +182,11 @@ 
namespace Tensorflow.Operations public virtual bool back_prop => throw new NotImplementedException("abstract method"); + /// + /// Add `val` to the current context and its outer context recursively. + /// + /// + /// public virtual Tensor AddValue(Tensor val) { // to be overridden @@ -203,7 +210,25 @@ namespace Tensorflow.Operations /// protected virtual void _AddOpInternal(Operation op) { - + if (op.name == "rnn/while/Less") + { + + } + + if(op == null) + { + throw new NotImplementedException(""); + } + else + { + foreach(var index in range(len(op.inputs))) + { + var x = op.inputs[index]; + var real_x = AddValue(x); + if (real_x != x) + op._update_input(index, real_x); + } + } } protected bool OpInContext(Operation op) diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs index c313739b..d49d5abf 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs @@ -24,5 +24,10 @@ namespace Tensorflow.Operations elements.Add(Item); return elements.ToArray(); } + + public static implicit operator (Tensor, TItem)(LoopVar loopVar) + { + return (loopVar.Counter, loopVar.Item); + } } } diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs index 462aca25..b40dae11 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs @@ -71,6 +71,8 @@ namespace Tensorflow.Operations string name) { _name = ops.get_default_graph().unique_name(name); + _maximum_iterations = maximum_iterations; + _parallel_iterations = parallel_iterations; _back_prop = back_prop; _swap_memory = swap_memory; _loop_exits = new List(); @@ -107,18 +109,27 @@ namespace Tensorflow.Operations /// /// Add the loop termination condition and body to the graph. 
/// - internal Tensor[] BuildLoop(Func pred, - Func> body, + internal Tensor[] BuildLoop(Func, Tensor> pred, + Func, LoopVar> body, LoopVar loop_vars, - TensorShape shape_invariants, + TensorShape[] shape_invariants, bool return_same_structure) { // Keep original_loop_vars to identify which are TensorArrays var original_loop_vars = loop_vars; // Convert TensorArrays to their flow variables + var loop_vars_tensors = nest.flatten2(loop_vars) + .Select(x => _convert_tensorarray_to_flow(x)) + .ToArray(); + + if (shape_invariants == null) + shape_invariants = loop_vars_tensors + .Select(x => _get_shape_invariant(x as Tensor)) + .ToArray(); + Enter(); var(original_body_result, exit_vars) = _BuildLoop( - pred, body, original_loop_vars, loop_vars, shape_invariants); + pred, body, original_loop_vars, loop_vars_tensors, shape_invariants); Exit(); var flat_result = original_body_result; @@ -131,7 +142,7 @@ namespace Tensorflow.Operations return packed_exit_vars as Tensor[]; } - private Tensor _convert_tensorarray_to_flow(TItem tensor_or_tensor_array) + private Tensor _convert_tensorarray_to_flow(object tensor_or_tensor_array) { if (tensor_or_tensor_array is TensorArray tensor_array) return tensor_array.flow; @@ -141,97 +152,116 @@ namespace Tensorflow.Operations throw new NotImplementedException("_convert_tensorarray_to_flow"); } - private (Tensor[], Tensor[]) _BuildLoop(Func pred, - Func> body, - LoopVar original_loop_vars, - LoopVar loop_vars, - TensorShape shape_invariants) + private TensorShape _get_shape_invariant(Tensor var, int[] shape = null) { - var flat_loop_vars = original_loop_vars; + return var.TensorShape; + } - // Convert TensorArrays to their flow variables - var loop_vars_tensor = nest.map_structure( - _convert_tensorarray_to_flow, - nest.flatten2(loop_vars)); + /// + /// Add the loop termination condition and body to the graph. 
+ /// + /// + /// + /// + /// + /// + /// + /// + private (Tensor[], Tensor[]) _BuildLoop(Func, Tensor> pred, + Func, LoopVar> body, + LoopVar original_loop_vars, + Tensor[] loop_vars, + TensorShape[] shape_invariants) + { + var flat_loop_vars = nest.flatten2(original_loop_vars) + .Select(x => (ITensorOrTensorArray)x) + .ToArray(); // Let the context know the loop variables so the loop variables // would be added in the outer contexts properly. - if (loop_vars is Tensor[] real_vars) + _InitializeValues(loop_vars); + var real_vars = loop_vars; + Tensor[] enter_vars = null; + tf_with(ops.control_dependencies(null), delegate { - _InitializeValues(real_vars); - Tensor[] enter_vars = null; - tf_with(ops.control_dependencies(null), delegate - { - enter_vars = real_vars.Select(x => _Enter(x, - _name, - is_constant: false, - parallel_iterations: _parallel_iterations, - use_input_shape: shape_invariants == null)) - .ToArray(); - - foreach (var x in enter_vars) - { - x.graph.prevent_feeding(x); - if (_outer_context != null) - _outer_context.AddInnerOp(x.op); - } - }); - - // Finds the closest enclosing non-None control pivot. - var outer_context = _outer_context; - while (outer_context != null) + enter_vars = real_vars.Select(x => _Enter(x, + _name, + is_constant: false, + parallel_iterations: _parallel_iterations, + use_input_shape: shape_invariants == null)) + .ToArray(); + + foreach (var x in enter_vars) { - + x.graph.prevent_feeding(x); + if (_outer_context != null) + _outer_context.AddInnerOp(x.op); } + }); - _SetShapeInvariants(real_vars, enter_vars, shape_invariants); - - // Fix the control inputs and control flow context of these enter ops. - _FixControlInputsAndContext(enter_vars); - _InitializeValues(enter_vars); - _loop_enters = enter_vars.ToList(); - - var merge_vars = enter_vars - .Select(x => merge(new[] { x, x })) - .ToArray(); + // Finds the closest enclosing non-None control pivot. 
+ var outer_context = _outer_context; + object control_pivot = null; + while (outer_context != null && control_pivot == null) + { - _pivot_for_pred = merge_vars[0]; + } - // Build the graph for pred. - var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars); - // var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays); - var c = ops.convert_to_tensor(pred(merge_vars_with_tensor_arrays[0], default(TItem))); - _pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond"); - var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot)) - .ToArray(); + if (control_pivot != null) + { - // Build the graph for body. - var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray(); - // Convert TensorArray flow variables inside the context back into - // their associated TensorArrays for calling the body. - var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); - /*var body_result = body(packed_vars_for_body[0]); - var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); - - // Store body_result to keep track of TensorArrays returned by body - var original_body_result = new[] { body_result }; - // Convert TensorArrays returned by body into their flow variables - var result = new[] { body_result }; - - var next_vars = new List(); - foreach (var (m, v) in zip(merge_vars, result)) - next_vars.Add(_AddNextAndBackEdge(m, v)); - - // Add the exit ops. - var exit_vars = switch_vars.Select(x => exit(x[0])).ToList(); - _loop_exits = exit_vars; - - // Exit the loop. - // ExitResult(exit_vars); - return (original_body_result, exit_vars.ToArray());*/ } - throw new NotImplementedException(""); + _SetShapeInvariants(real_vars, enter_vars, shape_invariants); + + // Fix the control inputs and control flow context of these enter ops. 
+ _FixControlInputsAndContext(enter_vars); + _InitializeValues(enter_vars); + _loop_enters = enter_vars.ToList(); + + var merge_vars = enter_vars + .Select(x => merge(new[] { x, x })) + .ToArray(); + + _pivot_for_pred = merge_vars[0]; + + // Build the graph for pred. + var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars); + //var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays, expand_composites: true); + var packed_vars = new LoopVar((Tensor)merge_vars_with_tensor_arrays[0], + (TItem)(object)new BodyItemInRnnWhileLoop((Tensor)merge_vars_with_tensor_arrays[1], + new[] { (TensorArray)merge_vars_with_tensor_arrays[2] }, + (Tensor)merge_vars_with_tensor_arrays[3])); + var pp = pred(packed_vars); + var c = ops.convert_to_tensor(pp); + _pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond"); + var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot)) + .ToArray(); + + // Build the graph for body. + var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray(); + // Convert TensorArray flow variables inside the context back into + // their associated TensorArrays for calling the body. + var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); + var body_result = body(original_loop_vars); + var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); + + // Store body_result to keep track of TensorArrays returned by body + var original_body_result = new[] { body_result }; + // Convert TensorArrays returned by body into their flow variables + var result = new[] { body_result }; + + var next_vars = new List(); + //foreach (var (m, v) in zip(merge_vars, result)) + //next_vars.Add(_AddNextAndBackEdge(m, v)); + + // Add the exit ops. + var exit_vars = switch_vars.Select(x => exit(x[0])).ToList(); + _loop_exits = exit_vars; + + // Exit the loop. 
+ // ExitResult(exit_vars); + return (null, exit_vars.ToArray()); } private void _FixControlInputsAndContext(Tensor[] enters) @@ -258,6 +288,23 @@ namespace Tensorflow.Operations _values.Add(x.name); } + public override Tensor AddValue(Tensor val) + { + var result = val; + var new_value = _values.Contains(val.name); + new_value &= val.op._get_control_flow_context() != this; + if (new_value) + throw new NotImplementedException(""); + else + { + var actual_val = _external_values.ContainsKey(val.name) ? _external_values[val.name] : null; + if (actual_val != null) + result = actual_val as Tensor; + } + + return result; + } + public override WhileContext GetWhileContext() { return this; From 8e054afd8db31a70173e2a7ab5f3b712cbcb2839 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Thu, 24 Oct 2019 13:44:33 -0500 Subject: [PATCH 15/41] add zip() --- src/TensorFlowNET.Core/Binding.Util.cs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index 150fa89a..c70af1fd 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -165,6 +165,12 @@ namespace Tensorflow yield return (t1[i], t2[i]); } + public static IEnumerable<(T1, T2, T3)> zip(IList t1, IList t2, IList t3) + { + for (int i = 0; i < t1.Count; i++) + yield return (t1[i], t2[i], t3[i]); + } + public static IEnumerable<(T1, T2)> zip(NDArray t1, NDArray t2) where T1: unmanaged where T2: unmanaged @@ -203,6 +209,7 @@ namespace Tensorflow yield return (i, values[i]); } + [DebuggerStepThrough] public static Dictionary ConvertToDict(object dyn) { var dictionary = new Dictionary(); From a65d881213b7ab11a3223fef3996a10005ec3627 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Thu, 24 Oct 2019 13:45:02 -0500 Subject: [PATCH 16/41] CheckInputFromValidContext --- .../Operations/NnOps/rnn.cs | 9 ++-- .../Operations/Operation.Control.cs | 5 +- .../Operations/Operation.cs | 1 + 
.../Operations/_GraphTensorArray.cs | 10 ++-- .../Operations/control_flow_ops.cs | 52 +++++++++++++------ .../Operations/control_flow_util.py.cs | 22 ++++++++ .../Operations/gen_math_ops.cs | 2 + .../Operations/tensor_array_ops.cs | 33 ++++++++++++ .../TensorFlowNET.Core.csproj | 3 +- src/TensorFlowNET.Core/Tensors/Tensor.cs | 4 +- src/TensorFlowNET.Core/Util/nest.py.cs | 9 ++-- 11 files changed, 117 insertions(+), 33 deletions(-) create mode 100644 src/TensorFlowNET.Core/Operations/tensor_array_ops.cs diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs index e058c077..475dd0ff 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs @@ -172,7 +172,8 @@ namespace Tensorflow.Operations for (int i = 0; i < input_ta.Count; i++) { - var (ta, input_) = (input_ta[0], flat_input[0]); + var (ta, input_) = (input_ta[i], flat_input[i]); + ta.unstack(input_); } } @@ -185,16 +186,16 @@ namespace Tensorflow.Operations Func cond = (item) => { - return time < loop_bound; + return item.time < loop_bound; }; // Take a time step of the dynamic RNN. 
Func _time_step = (item) => { - return item; + throw new NotImplementedException(""); }; - control_flow_ops.while_loop( + control_flow_ops.while_loop( cond: cond, body: _time_step, loop_vars: new BodyItemInRnnWhileLoop(time, output_ta.ToArray(), state), diff --git a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs index 2f61f954..5e93cfd0 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Control.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs @@ -30,10 +30,9 @@ namespace Tensorflow /// public void _control_flow_post_processing() { - foreach(var input_tensor in inputs) + foreach(Tensor input_tensor in inputs) { - //TODO: implement below code dependency - //control_flow_util.CheckInputFromValidContext(this, input_tensor.op); + control_flow_util.CheckInputFromValidContext(this, input_tensor.op); } if (_control_flow_context != null) diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 3b40c95a..db001e51 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -23,6 +23,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; using Tensorflow.Util; +using static Tensorflow.Binding; namespace Tensorflow { diff --git a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs index 5a667560..ebc88230 100644 --- a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs +++ b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs @@ -25,6 +25,7 @@ namespace Tensorflow.Operations internal class _GraphTensorArray { internal TF_DataType _dtype; + public TF_DataType dtype => _dtype; /// /// Used to keep track of what tensors the TensorArray should be @@ -32,14 +33,17 @@ namespace Tensorflow.Operations /// first tensor written to it. 
/// bool _colocate_with_first_write_call; + public bool colocate_with_first_write_call => _colocate_with_first_write_call; bool _infer_shape; - bool _dynamic_size; - List _element_shape; + public bool infer_shape => _infer_shape; + public bool _dynamic_size; + public List _element_shape; - List _colocate_with; + public List _colocate_with; internal Tensor _handle; + public Tensor handle => _handle; internal Tensor _flow; public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null, diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs index 181b7e71..6c286fc1 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs @@ -21,6 +21,7 @@ using Tensorflow.Operations; using Tensorflow.Operations.ControlFlows; using util = Tensorflow.control_flow_util; using static Tensorflow.Binding; +using Tensorflow.Util; namespace Tensorflow { @@ -251,12 +252,16 @@ namespace Tensorflow return gen_array_ops.identity(data, name: name); } - public static void _SetShapeInvariants(Tensor[] input_vars, Tensor[] enter_vars, TensorShape shapes = null) + public static void _SetShapeInvariants(Tensor[] input_vars, Tensor[] enter_vars, TensorShape[] shapes = null) { if (shapes == null) return; - throw new NotImplementedException("_SetShapeInvariants"); + var flat_shapes = nest.flatten2(shapes); + foreach (var (inp, var, shape) in zip(input_vars, enter_vars, flat_shapes)) + { + var.set_shape(shape); + } } /// @@ -428,12 +433,12 @@ namespace Tensorflow .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })) .ToArray(); - merges = _convert_flows_to_tensorarrays(new Tensor[] { (Tensor)orig_res_t }, merges); + var merges2 = _convert_flows_to_tensorarrays(new ITensorOrTensorArray[] { (Tensor)orig_res_t }, merges); ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); - 
return merges[0]; + return new Tensor(IntPtr.Zero); }); } @@ -473,22 +478,28 @@ namespace Tensorflow var res_f_flat = res_f; var merges = zip(res_f_flat, res_t_flat) - .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })) + .Select(pair => merge(new [] { pair.Item1, pair.Item2 })) .ToArray(); - merges = _convert_flows_to_tensorarrays(orig_res_t, merges); + var merges2 = _convert_flows_to_tensorarrays(orig_res_t.Select(x => (ITensorOrTensorArray)x).ToArray(), merges); ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); - return merges; + return new[] { new Tensor(IntPtr.Zero) }; }); } - public static Tensor[] _convert_flows_to_tensorarrays(T tensors_or_tensorarrays, Tensor[] tensors_or_flows) + public static ITensorOrTensorArray[] _convert_flows_to_tensorarrays(ITensorOrTensorArray[] tensors_or_tensorarrays, Tensor[] tensors_or_flows) { - // zip(tensors_or_tensorarrays, tensors_or_flows).Select((ta, t_or_flow) => ta).ToArray(); - return tensors_or_flows; + return zip(tensors_or_tensorarrays, tensors_or_flows).Select(x => + { + var (ta, t_or_flow) = (x.Item1, x.Item2); + if (ta is TensorArray ta_1) + return tensor_array_ops.build_ta_with_new_flow(ta_1, t_or_flow) as ITensorOrTensorArray; + else + return t_or_flow as ITensorOrTensorArray; + }).ToArray(); } /// @@ -592,7 +603,7 @@ namespace Tensorflow /// /// public static Tensor while_loop(Func cond, Func body, TItem loop_vars, - TensorShape shape_invariants = null, + TensorShape[] shape_invariants = null, int parallel_iterations = 10, bool back_prop = true, bool swap_memory = false, @@ -617,8 +628,8 @@ namespace Tensorflow var orig_body = body; LoopVar loop_vars_1 = null; - Func> body_buildloop = null; - Func cond_buildloop = null; + Func, LoopVar> body_buildloop = null; + Func, Tensor> cond_buildloop = null; if (try_to_pack) { @@ -627,9 +638,18 @@ namespace Tensorflow else { loop_vars_1 = new LoopVar(counter, loop_vars); - cond_buildloop 
= (i, lv) => - math_ops.logical_and(i < maximum_iterations, orig_cond(lv)); - body_buildloop = (i, lv) => new LoopVar(i + 1, orig_body(lv)); + cond_buildloop = (item) => + { + var (i, lv) = (item.Counter, item.Item); + var oc = orig_cond(lv); + return math_ops.logical_and(i < maximum_iterations, oc); + }; + + body_buildloop = (item) => + { + var (i, lv) = (item.Counter, item.Item); + return new LoopVar(i + 1, orig_body(lv)); + }; } try_to_pack = false; diff --git a/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs index 4ae03e42..5377eb5b 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs @@ -14,7 +14,9 @@ limitations under the License. ******************************************************************************/ +using System; using Tensorflow.Operations; +using static Tensorflow.Binding; namespace Tensorflow { @@ -53,5 +55,25 @@ namespace Tensorflow ctxt = ctxt.outer_context; return ctxt; } + + public static void CheckInputFromValidContext(Operation op, Operation input_op) + { + var op_ctxt = op._get_control_flow_context(); + var input_ctxt = GetOutputContext(input_op); + var valid = false; + if (input_ctxt == null) + valid = true; + else if (op_ctxt == input_ctxt) + valid = true; + else + { + throw new NotImplementedException(""); + } + + if (!valid) + { + throw new NotImplementedException(""); + } + } } } diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 7192dc57..e1225cc9 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -14,6 +14,8 @@ limitations under the License. 
******************************************************************************/ +using static Tensorflow.Binding; + namespace Tensorflow { public static class gen_math_ops diff --git a/src/TensorFlowNET.Core/Operations/tensor_array_ops.cs b/src/TensorFlowNET.Core/Operations/tensor_array_ops.cs new file mode 100644 index 00000000..8ce3b5c7 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/tensor_array_ops.cs @@ -0,0 +1,33 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class tensor_array_ops + { + /// + /// Builds a TensorArray with a new `flow` tensor. + /// + /// + /// + /// + public static TensorArray build_ta_with_new_flow(TensorArray old_ta, Tensor flow) + { + var impl = old_ta._implementation; + + var new_ta = new TensorArray( + dtype: impl.dtype, + handle: impl.handle, + flow: flow, + infer_shape: impl.infer_shape, + colocate_with_first_write_call: impl.colocate_with_first_write_call); + + var new_impl = new_ta._implementation; + new_impl._dynamic_size = impl._dynamic_size; + new_impl._colocate_with = impl._colocate_with; + new_impl._element_shape = impl._element_shape; + return new_ta; + } + } +} diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 33bba3dc..fbad178e 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -20,7 +20,8 @@ Building, training and infering deep learning models. https://tensorflownet.readthedocs.io 0.12.0.0 Changes since v0.11.0: - +1: Add ICanBeFlattened for nest.flatten2. 
+2: 7.3 0.12.0.0 LICENSE diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 161696a1..943edaaf 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -39,7 +39,7 @@ namespace Tensorflow /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. /// [SuppressMessage("ReSharper", "ConvertToAutoProperty")] - public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike + public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike, ITensorOrTensorArray { private readonly int _id; private readonly Operation _op; @@ -178,7 +178,7 @@ namespace Tensorflow /// public void set_shape(TensorShape shape) { - this.shape = shape.rank > 0 ? shape.dims : null; + this.shape = shape.rank >= 0 ? shape.dims : null; } /// diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs index 9b0af4f6..28f9ba03 100644 --- a/src/TensorFlowNET.Core/Util/nest.py.cs +++ b/src/TensorFlowNET.Core/Util/nest.py.cs @@ -223,9 +223,10 @@ namespace Tensorflow.Util } public static object[] flatten2(ICanBeFlattened structure) - { - return structure.Flatten(); - } + => structure.Flatten(); + + public static T[] flatten2(T[] structure) + => structure; private static void _flatten_recursive(T obj, List list) { @@ -423,7 +424,7 @@ namespace Tensorflow.Util /// `flat_sequence` converted to have the same recursive structure as /// `structure`. 
/// - public static object pack_sequence_as(object structure, IEnumerable flat_sequence) + public static object pack_sequence_as(object structure, IEnumerable flat_sequence, bool expand_composites = false) { List flat = null; if (flat_sequence is List) From 8243807ede9c5b40d924bda015cf8084f7a500e7 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Thu, 24 Oct 2019 23:25:53 -0500 Subject: [PATCH 17/41] IsLoopConstantEnter --- .../{Util => Interfaces}/IFlatten.cs | 2 +- .../Interfaces/IPackable.cs | 11 ++ .../ControlFlows/ControlFlowContext.cs | 36 ++++- .../Operations/ControlFlows/LoopVar.cs | 14 +- .../Operations/ControlFlows/WhileContext.cs | 149 ++++++++++++++++-- .../NnOps/BodyItemInRnnWhileLoop.cs | 9 +- .../Operations/NnOps/rnn.cs | 7 +- .../Operations/_GraphTensorArray.cs | 15 ++ .../Operations/control_flow_ops.cs | 3 +- .../Operations/control_flow_util.py.cs | 20 +++ .../Operations/gen_data_flow_ops.cs | 22 +++ src/TensorFlowNET.Core/Tensors/TensorArray.cs | 3 + src/TensorFlowNET.Core/Util/nest.py.cs | 7 + 13 files changed, 272 insertions(+), 26 deletions(-) rename src/TensorFlowNET.Core/{Util => Interfaces}/IFlatten.cs (82%) create mode 100644 src/TensorFlowNET.Core/Interfaces/IPackable.cs diff --git a/src/TensorFlowNET.Core/Util/IFlatten.cs b/src/TensorFlowNET.Core/Interfaces/IFlatten.cs similarity index 82% rename from src/TensorFlowNET.Core/Util/IFlatten.cs rename to src/TensorFlowNET.Core/Interfaces/IFlatten.cs index 305dc72e..e7b076e9 100644 --- a/src/TensorFlowNET.Core/Util/IFlatten.cs +++ b/src/TensorFlowNET.Core/Interfaces/IFlatten.cs @@ -2,7 +2,7 @@ using System.Collections.Generic; using System.Text; -namespace Tensorflow.Operations +namespace Tensorflow { public interface ICanBeFlattened { diff --git a/src/TensorFlowNET.Core/Interfaces/IPackable.cs b/src/TensorFlowNET.Core/Interfaces/IPackable.cs new file mode 100644 index 00000000..86ceabc7 --- /dev/null +++ b/src/TensorFlowNET.Core/Interfaces/IPackable.cs @@ -0,0 +1,11 @@ +using System; +using 
System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public interface IPackable + { + void Pack(object[] sequences); + } +} diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs index 8a624df2..953fd6c7 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs @@ -170,7 +170,7 @@ namespace Tensorflow.Operations /// /// Add `op` to the current context. /// - public void AddOp(Operation op) + public virtual void AddOp(Operation op) { _AddOpInternal(op); } @@ -210,11 +210,6 @@ namespace Tensorflow.Operations /// protected virtual void _AddOpInternal(Operation op) { - if (op.name == "rnn/while/Less") - { - - } - if(op == null) { throw new NotImplementedException(""); @@ -255,9 +250,34 @@ namespace Tensorflow.Operations throw new NotImplementedException("_IsInOuterContext"); } - protected virtual void _RemoveExternalControlEdges(Operation op) + /// + /// Remove any external control dependency on this op. + /// + /// + protected virtual (Operation[], Operation[]) _RemoveExternalControlEdges(Operation op) { - var internal_control_inputs = op.control_inputs; + var while_ctxt = GetWhileContext(); + + var internal_control_inputs = new List(); + // A control input of `op` is internal if it is in the same while + // loop context as the enclosing while loop context of self. 
+ if (while_ctxt == null) + { + internal_control_inputs = op.control_inputs.ToList(); + } + else + { + foreach(Tensor x in op.control_inputs) + { + throw new NotImplementedException(""); + } + } + + var external_control_inputs = new List(); + if (len(internal_control_inputs) != len(op.control_inputs)) + throw new NotImplementedException(""); + + return (internal_control_inputs.ToArray(), external_control_inputs.ToArray()); } /// diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs index d49d5abf..845ff494 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs @@ -1,13 +1,14 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; namespace Tensorflow.Operations { - internal class LoopVar : ICanBeFlattened + internal class LoopVar : ICanBeFlattened, IPackable { - public Tensor Counter { get; } - public TItem Item { get; } + public Tensor Counter { get; set; } + public TItem Item { get; set; } public LoopVar(Tensor counter, TItem item) { @@ -25,6 +26,13 @@ namespace Tensorflow.Operations return elements.ToArray(); } + public void Pack(object[] sequences) + { + Counter = sequences[0] as Tensor; + if (typeof(TItem).GetInterface(typeof(IPackable).Name) != null) + (Item as IPackable).Pack(sequences.Skip(1).ToArray()); + } + public static implicit operator (Tensor, TItem)(LoopVar loopVar) { return (loopVar.Counter, loopVar.Item); diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs index b40dae11..55802fe7 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs @@ -240,10 +240,13 @@ namespace Tensorflow.Operations // Build the graph for body. 
var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray(); + _pivot_for_body = vars_for_body[0]; // Convert TensorArray flow variables inside the context back into // their associated TensorArrays for calling the body. - var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); - var body_result = body(original_loop_vars); + var vars_for_body_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); + var packed_vars_for_body = nest.pack_sequence_as2(original_loop_vars, vars_for_body_with_tensor_arrays); + var pre_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); + var body_result = body(packed_vars_for_body); var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); // Store body_result to keep track of TensorArrays returned by body @@ -267,17 +270,27 @@ namespace Tensorflow.Operations private void _FixControlInputsAndContext(Tensor[] enters) { var graph = ops.get_default_graph(); - foreach(var e in enters) + foreach(var x in enters) { - var inp_op = e.op.inputs[0].op; + var inp_op = x.op.inputs[0].op; var control_inputs = graph._control_dependencies_for_inputs(new[] { inp_op }); + var outer_control_inputs = new List(); + foreach(Operation op in control_inputs) + { + // We need to keep control inputs that are in any ancestor + // ControlFlowContext, and within outer WhileContext. 
+ var keep_as_control_input = true; + var op_ctxt = control_flow_util.GetOutputContext(op); + var outer_ctxt = outer_context; + throw new NotImplementedException(""); + } // op for op in control_inputs if self._IsInOuterContext(op) - var outer_control_inputs = control_inputs.Where(x => _IsInOuterContext(x.op)) + /*var outer_control_inputs = control_inputs.Where(x => _IsInOuterContext(x.op)) .Select(x => x.op) - .ToArray(); - e.op._set_control_flow_context(this); - e.op._add_control_inputs(outer_control_inputs); - graph._record_op_seen_by_control_dependencies(e.op); + .ToArray();*/ + x.op._set_control_flow_context(this); + x.op._add_control_inputs(outer_control_inputs.ToArray()); + graph._record_op_seen_by_control_dependencies(x.op); } } @@ -288,13 +301,127 @@ namespace Tensorflow.Operations _values.Add(x.name); } + protected override void _AddOpInternal(Operation op) + { + Operation[] external_inputs = new Operation[0]; + if (op == null) + { + throw new NotImplementedException(""); + } + else + { + foreach (var index in range(len(op.inputs))) + { + var x = op.inputs[index]; + var real_x = AddValue(x); + if (real_x != x) + op._update_input(index, real_x); + } + + // Remove any external control dependency on this op. + (_, external_inputs) = _RemoveExternalControlEdges(op); + // Add a control dependency to prevent loop invariants from + // enabling ops that should not be executed. + _MaybeAddControlDependency(op); + foreach (Tensor x in op.outputs) + _values.Add(x.name); + } + + if (external_inputs.Length > 0) + { + throw new NotImplementedException("external_inputs.Length > 0"); + } + + if (_outer_context != null || !IsLoopExit(op)) + foreach (Tensor x in op.outputs) + op.graph.prevent_feeding(x); + + if (_outer_context != null) + _outer_context.AddInnerOp(op); + } + + protected void _MaybeAddControlDependency(Operation op) + { + // Determines if `op` needs a control dependency. 
+ Func _IsOpFree = (op1) => + { + if (op1.control_inputs.Length > 0) + return false; + + if (op1.type == "SymbolicGradient") + return true; + + foreach (Tensor x in op1.inputs) + if (!control_flow_util.IsLoopConstantEnter(x.op)) + return false; + + return true; + }; + + if (_IsOpFree(op)) + op._add_control_input(GetControlPivot().op); + } + + private Tensor GetControlPivot() + { + if (_pivot_for_body != null) + return _pivot_for_body; + return _pivot_for_pred; + } + + public override void AddOp(Operation op) + { + _AddOpInternal(op); + } + public override Tensor AddValue(Tensor val) { var result = val; - var new_value = _values.Contains(val.name); + var new_value = !_values.Contains(val.name); new_value &= val.op._get_control_flow_context() != this; if (new_value) - throw new NotImplementedException(""); + { + _values.Add(val.name); + + // If we are in a grad context and val is from its forward context, + // use GetRealValue(), which adds the logic to save the history of + // val in forward. + var grad_ctxt = ops.get_default_graph()._get_control_flow_context(); + if(grad_ctxt != null) + { + grad_ctxt = grad_ctxt.GetWhileContext(); + if (grad_ctxt.grad_state != null) + { + throw new NotImplementedException(""); + } + } + + if (_outer_context != null) + { + result = _outer_context.AddValue(val); + } + + // Create an Enter to make `result` known to this loop context. + Tensor enter = null; + tf_with(ops.control_dependencies(new ITensorOrOperation[0]), delegate + { + enter = _Enter( + result, + _name, + is_constant: true, + parallel_iterations: _parallel_iterations); + enter.graph.prevent_feeding(enter); + if (_outer_context != null) + _outer_context.AddInnerOp(enter.op); + }); + + // Fix the control inputs and control flow context of these enter ops. + _FixControlInputsAndContext(new[] { enter }); + // Add `enter` in this context. 
+ _values.Add(enter.name); + _external_values[val.name] = enter; + result = enter; + } else { var actual_val = _external_values.ContainsKey(val.name) ? _external_values[val.name] : null; diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs b/src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs index 9ffea25c..acc40a2d 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs @@ -4,7 +4,7 @@ using System.Text; namespace Tensorflow.Operations { - internal class BodyItemInRnnWhileLoop : ICanBeFlattened + internal class BodyItemInRnnWhileLoop : ICanBeFlattened, IPackable { /// /// int32 scalar Tensor. @@ -36,5 +36,12 @@ namespace Tensorflow.Operations elements.Add(state); return elements.ToArray(); } + + public void Pack(object[] sequences) + { + time = sequences[0] as Tensor; + output_ta_t = new[] { sequences[1] as TensorArray }; + state = sequences[2] as Tensor; + } } } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs index 475dd0ff..41516bb8 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs @@ -192,7 +192,12 @@ namespace Tensorflow.Operations // Take a time step of the dynamic RNN. 
Func _time_step = (item) => { - throw new NotImplementedException(""); + if (in_graph_mode) + { + input_ta.Select(ta => ta.read(time)).ToArray(); + } + + return item; }; control_flow_ops.while_loop( diff --git a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs index ebc88230..56ac277e 100644 --- a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs +++ b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs @@ -159,5 +159,20 @@ namespace Tensorflow.Operations { _colocate_with.Add(value); } + + public Tensor read(Tensor index, string name = null) + { + var value = gen_data_flow_ops.tensor_array_read_v3( + handle: _handle, + index: index, + flow_in: _flow, + dtype: _dtype, + name: name); + + if (_element_shape != null) + value.set_shape(_element_shape[0].dims); + + return value; + } } } diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs index 6c286fc1..1229c6b7 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs @@ -648,7 +648,8 @@ namespace Tensorflow body_buildloop = (item) => { var (i, lv) = (item.Counter, item.Item); - return new LoopVar(i + 1, orig_body(lv)); + var ob = orig_body(lv); + return new LoopVar(i + 1, ob); }; } try_to_pack = false; diff --git a/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs index 5377eb5b..9dcfb2e1 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs @@ -30,6 +30,26 @@ namespace Tensorflow public static bool IsLoopExit(Operation op) { return op.type == "Exit" || op.type == "RefExit"; + } + + /// + /// Returns true if `op` is an Enter. 
+ /// + /// + /// + public static bool IsLoopEnter(Operation op) + { + return op.type == "Enter" || op.type == "RefEnter"; + } + + /// + /// Return true iff op is a loop invariant. + /// + /// + /// + public static bool IsLoopConstantEnter(Operation op) + { + return IsLoopEnter(op) && op.get_attr("is_constant"); } /// diff --git a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs index 71e9bbab..1d3cc047 100644 --- a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs @@ -198,5 +198,27 @@ namespace Tensorflow return _op.outputs; } + + /// + /// Read an element from the TensorArray into output `value`. + /// + /// + /// + /// + /// + /// + /// + public static Tensor tensor_array_read_v3(Tensor handle, Tensor index, Tensor flow_in, TF_DataType dtype, string name = null) + { + var _op = _op_def_lib._apply_op_helper("TensorArrayReadV3", name, new + { + handle, + index, + flow_in, + dtype + }); + + return _op.output; + } } } diff --git a/src/TensorFlowNET.Core/Tensors/TensorArray.cs b/src/TensorFlowNET.Core/Tensors/TensorArray.cs index f84072a8..fe9e2d6d 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorArray.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorArray.cs @@ -58,5 +58,8 @@ namespace Tensorflow public TensorArray unstack(Tensor value, string name = null) => _implementation.unstack(value, name: name); + + public Tensor read(Tensor index, string name = null) + => _implementation.read(index, name: name); } } diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs index 28f9ba03..97980203 100644 --- a/src/TensorFlowNET.Core/Util/nest.py.cs +++ b/src/TensorFlowNET.Core/Util/nest.py.cs @@ -401,6 +401,13 @@ namespace Tensorflow.Util private static int len(IEnumerable x) => x.Count(); + public static T pack_sequence_as2(T structure, object[] flat_sequence, bool expand_composites = false) + where T : IPackable + 
{ + structure.Pack(flat_sequence); + return structure; + } + /// /// Returns a given flattened sequence packed into a given structure. /// If `structure` is a scalar, `flat_sequence` must be a single-element list; From d1e1e05546f883f245be3c869cad383753aac790 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 27 Oct 2019 10:46:49 -0500 Subject: [PATCH 18/41] inputs for rnn/while/TensorArrayReadV3 are incorrect #433 --- .../Operations/NnOps/rnn.cs | 33 +++++++++++++++---- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs index 41516bb8..41a4622a 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using NumSharp; using System; using System.Collections.Generic; using System.Linq; @@ -24,7 +25,7 @@ namespace Tensorflow.Operations { internal class rnn { - public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs_tensor, + public static (Tensor, Tensor) dynamic_rnn(RnnCell cell, Tensor inputs_tensor, Tensor sequence_length = null, Tensor initial_state = null, TF_DataType dtype = TF_DataType.DtInvalid, int? 
parallel_iterations = null, bool swap_memory = false, bool time_major = false) @@ -79,7 +80,7 @@ namespace Tensorflow.Operations /// /// /// - private static (Tensor, Tensor) _dynamic_rnn_loop(RNNCell cell, Tensor inputs, Tensor initial_state, + private static (Tensor, Tensor) _dynamic_rnn_loop(RnnCell cell, Tensor inputs, Tensor initial_state, int parallel_iterations, bool swap_memory, Tensor sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid) { var state = initial_state; @@ -170,11 +171,11 @@ namespace Tensorflow.Operations flat_input_i.dtype)); } - for (int i = 0; i < input_ta.Count; i++) + input_ta = zip(input_ta, flat_input).Select(x => { - var (ta, input_) = (input_ta[i], flat_input[i]); - ta.unstack(input_); - } + var (ta, input_) = (x.Item1, x.Item2); + return ta.unstack(input_); + }).ToList(); } // Make sure that we run at least 1 step, if necessary, to ensure @@ -192,11 +193,29 @@ namespace Tensorflow.Operations // Take a time step of the dynamic RNN. Func _time_step = (item) => { + Tensor[] input_t = null; + var (time1, output_ta_t, state1) = (item.time, item.output_ta_t, item.state); if (in_graph_mode) { - input_ta.Select(ta => ta.read(time)).ToArray(); + input_t = input_ta.Select(ta => ta.read(time1)).ToArray(); + // Restore some shape information + foreach (var (input_, shape) in zip(input_t, inputs_got_shape)) + input_.set_shape(shape[new Slice(1)]); + } + else + { + // input_t = tuple(ta[time.numpy()] for ta in input_ta) } + var input_t_t = nest.pack_sequence_as2(structure: inputs, flat_sequence: input_t); + // Keras RNN cells only accept state as list, even if it's a single tensor. 
+ // var is_keras_rnn_cell = _is_keras_rnn_cell(cell); + (Tensor, Tensor) a = (null, null); + if (sequence_length != null) + throw new NotImplementedException("sequence_length != null"); + else + a = cell.__call__(input_t_t, state1); + return item; }; From 47ca86a9adea215cea8efe4bd5e7aafae4922df9 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 27 Oct 2019 10:47:55 -0500 Subject: [PATCH 19/41] rename RNNCell to RnnCell --- src/TensorFlowNET.Core/Operations/RNNCell.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/RNNCell.cs b/src/TensorFlowNET.Core/Operations/RNNCell.cs index 442115c0..9902cd41 100644 --- a/src/TensorFlowNET.Core/Operations/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/RNNCell.cs @@ -42,7 +42,7 @@ namespace Tensorflow /// matching structure of Tensors having shape `[batch_size].concatenate(s)` /// for each `s` in `self.batch_size`. /// - public abstract class RNNCell : Layers.Layer + public abstract class RnnCell : Layers.Layer { /// /// Attribute that indicates whether the cell is a TF RNN cell, due the slight @@ -53,7 +53,7 @@ namespace Tensorflow public virtual int output_size { get; } - public RNNCell(bool trainable = true, + public RnnCell(bool trainable = true, string name = null, TF_DataType dtype = TF_DataType.DtInvalid, bool? 
_reuse = null) : base(trainable: trainable, From 78fe9370b9dc442ecf5506c6d07747f2509f5417 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 27 Oct 2019 10:48:12 -0500 Subject: [PATCH 20/41] IPackable --- .../Operations/NnOps/BodyItemInRnnWhileLoop.cs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs b/src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs index acc40a2d..1a21326d 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs @@ -4,7 +4,7 @@ using System.Text; namespace Tensorflow.Operations { - internal class BodyItemInRnnWhileLoop : ICanBeFlattened, IPackable + internal class BodyItemInRnnWhileLoop : ICanBeFlattened, IPackable { /// /// int32 scalar Tensor. @@ -37,11 +37,13 @@ namespace Tensorflow.Operations return elements.ToArray(); } - public void Pack(object[] sequences) + public BodyItemInRnnWhileLoop Pack(object[] sequences) { time = sequences[0] as Tensor; output_ta_t = new[] { sequences[1] as TensorArray }; state = sequences[2] as Tensor; + + return new BodyItemInRnnWhileLoop(time, output_ta_t, state); } } } From 7bc249f1bbefab9f67309300edbbc2199eaef347 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 27 Oct 2019 10:48:33 -0500 Subject: [PATCH 21/41] Make Tensor packable. 
--- src/TensorFlowNET.Core/Tensors/Tensor.Pack.cs | 15 +++++++++++++++ src/TensorFlowNET.Core/Tensors/Tensor.cs | 2 +- src/TensorFlowNET.Core/Tensors/TensorShape.cs | 3 +++ 3 files changed, 19 insertions(+), 1 deletion(-) create mode 100644 src/TensorFlowNET.Core/Tensors/Tensor.Pack.cs diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Pack.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Pack.cs new file mode 100644 index 00000000..b37612c8 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Pack.cs @@ -0,0 +1,15 @@ +using NumSharp; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public partial class Tensor + { + public Tensor Pack(object[] sequences) + { + return sequences[0] as Tensor; + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 943edaaf..99a373c4 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -39,7 +39,7 @@ namespace Tensorflow /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. 
/// [SuppressMessage("ReSharper", "ConvertToAutoProperty")] - public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike, ITensorOrTensorArray + public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike, ITensorOrTensorArray, IPackable { private readonly int _id; private readonly Operation _op; diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs index bfd90a75..b3099799 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs @@ -125,6 +125,9 @@ namespace Tensorflow { get { + if (!slice.Stop.HasValue) + slice.Stop = dims.Length - slice.Start + 1; + if (slice.Start.HasValue == false || slice.Length.HasValue == false) throw new ArgumentException("Slice must has Start and Length."); From ded16ea82f492de35892d308f1accce735ad05e3 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 27 Oct 2019 13:34:59 -0500 Subject: [PATCH 22/41] BatchNormalization return tuple for call --- src/TensorFlowNET.Core/APIs/tf.layers.cs | 8 ++--- src/TensorFlowNET.Core/APIs/tf.nn.cs | 2 +- .../Graphs/Graph.Control.cs | 1 + src/TensorFlowNET.Core/Graphs/Graph.cs | 14 ++++---- .../Graphs/_ControlDependenciesController.cs | 1 + .../Interfaces/IPackable.cs | 4 +-- .../Keras/Layers/BatchNormalization.cs | 4 +-- src/TensorFlowNET.Core/Keras/Layers/Conv.cs | 6 ++-- src/TensorFlowNET.Core/Keras/Layers/Dense.cs | 6 ++-- .../Keras/Layers/Embedding.cs | 4 +-- src/TensorFlowNET.Core/Keras/Layers/Layer.cs | 10 +++--- .../Keras/Layers/Pooling2D.cs | 4 +-- src/TensorFlowNET.Core/Layers/Layer.cs | 6 ++-- .../Operations/BasicRNNCell.cs | 33 +++++++++++++++++-- .../ControlFlows/ControlFlowContext.cs | 21 ------------ .../Operations/ControlFlows/LoopVar.cs | 12 ++++--- .../Operations/ControlFlows/WhileContext.cs | 23 ++++++++++--- .../Operations/LayerRNNCell.cs | 4 +-- .../Operations/NnOps/rnn_cell_impl.cs | 4 +-- .../Operations/OpDefLibrary.cs | 14 ++++++++ 
.../Operations/Operation.Control.cs | 1 + .../Operations/Operation.Input.cs | 10 +++--- .../Operations/Operation.cs | 6 ++-- .../Operations/control_flow_ops.cs | 31 +++++++++++++++++ .../Operations/gen_math_ops.cs | 2 +- src/TensorFlowNET.Core/Operations/math_ops.cs | 17 ++++++++++ src/TensorFlowNET.Core/Util/nest.py.cs | 7 ++-- src/TensorFlowNET.Core/ops.cs | 2 ++ src/TensorFlowNET.Core/ops.name_scope.cs | 3 ++ 29 files changed, 176 insertions(+), 84 deletions(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.layers.cs b/src/TensorFlowNET.Core/APIs/tf.layers.cs index 9f989bc5..25448441 100644 --- a/src/TensorFlowNET.Core/APIs/tf.layers.cs +++ b/src/TensorFlowNET.Core/APIs/tf.layers.cs @@ -63,7 +63,7 @@ namespace Tensorflow trainable: trainable, name: name); - return layer.apply(inputs); + return layer.apply(inputs).Item1; } /// @@ -117,7 +117,7 @@ namespace Tensorflow trainable: trainable, name: name); - return layer.apply(inputs, training: training); + return layer.apply(inputs, training: training).Item1; } /// @@ -143,7 +143,7 @@ namespace Tensorflow data_format: data_format, name: name); - return layer.apply(inputs); + return layer.apply(inputs).Item1; } /// @@ -179,7 +179,7 @@ namespace Tensorflow kernel_initializer: kernel_initializer, trainable: trainable); - return layer.apply(inputs); + return layer.apply(inputs).Item1; } /// diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs index e9805010..5b5786d1 100644 --- a/src/TensorFlowNET.Core/APIs/tf.nn.cs +++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs @@ -76,7 +76,7 @@ namespace Tensorflow /// /// /// A pair (outputs, state) - public (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs, + public (Tensor, Tensor) dynamic_rnn(RnnCell cell, Tensor inputs, Tensor sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid, int? 
parallel_iterations = null, bool swap_memory = false, bool time_major = false) => rnn.dynamic_rnn(cell, inputs, sequence_length: sequence_length, dtype: dtype, diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs index c97e1b6f..c6a5dee0 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs @@ -18,6 +18,7 @@ using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Linq; using Tensorflow.Operations; +using static Tensorflow.Binding; namespace Tensorflow { diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 87a1424f..c9ad6402 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -262,15 +262,11 @@ namespace Tensorflow if (string.IsNullOrEmpty(name)) name = op_type; + // If a names ends with a '/' it is a "name scope" and we use it as-is, // after removing the trailing '/'. name = name.EndsWith("/") ? ops.name_from_scope_name(name) : unique_name(name); var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); - - if (name.Contains("define_loss/bigger_box_loss/mul_13")) - { - - } var input_ops = inputs.Select(x => x.op).ToArray(); var control_inputs = _control_dependencies_for_inputs(input_ops); @@ -377,7 +373,11 @@ namespace Tensorflow /// A string to be passed to `create_op()` that will be used /// to name the operation being created. public string unique_name(string name, bool mark_as_used = true) - { + { + if (name.EndsWith("basic_r_n_n_cell")) + { + + } if (!String.IsNullOrEmpty(_name_stack)) name = _name_stack + "/" + name; // For the sake of checking for names in use, we treat names as case @@ -405,7 +405,7 @@ namespace Tensorflow // Return the new name with the original capitalization of the given name. 
name = $"{name}_{i-1}"; - } + } return name; } diff --git a/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs index 6a75c982..55e321df 100644 --- a/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs +++ b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs @@ -16,6 +16,7 @@ using System.Collections.Generic; using Tensorflow.Operations; +using static Tensorflow.Binding; namespace Tensorflow { diff --git a/src/TensorFlowNET.Core/Interfaces/IPackable.cs b/src/TensorFlowNET.Core/Interfaces/IPackable.cs index 86ceabc7..94e31ece 100644 --- a/src/TensorFlowNET.Core/Interfaces/IPackable.cs +++ b/src/TensorFlowNET.Core/Interfaces/IPackable.cs @@ -4,8 +4,8 @@ using System.Text; namespace Tensorflow { - public interface IPackable + public interface IPackable { - void Pack(object[] sequences); + T Pack(object[] sequences); } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs index 0428b2ad..57311e8b 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs @@ -139,14 +139,14 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensor call(Tensor inputs, Tensor training = null) + protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null) { Tensor outputs = null; if (fused) { outputs = _fused_batch_norm(inputs, training: training); - return outputs; + return (outputs, outputs); } throw new NotImplementedException("BatchNormalization call"); diff --git a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs index dc40ae8c..6a7c58cc 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs @@ -108,7 +108,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensor call(Tensor 
inputs, Tensor training = null) + protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null) { var outputs = _convolution_op.__call__(inputs, kernel); if (use_bias) @@ -124,9 +124,9 @@ namespace Tensorflow.Keras.Layers } if (activation != null) - return activation.Activate(outputs); + outputs = activation.Activate(outputs); - return outputs; + return (outputs, outputs); } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs index 2564da6d..212035cb 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs @@ -72,7 +72,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensor call(Tensor inputs, Tensor training = null) + protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null) { Tensor outputs = null; var rank = inputs.rank; @@ -88,9 +88,9 @@ namespace Tensorflow.Keras.Layers if (use_bias) outputs = tf.nn.bias_add(outputs, bias); if (activation != null) - return activation.Activate(outputs); + outputs = activation.Activate(outputs); - return outputs; + return (outputs, outputs); } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs index f10499c4..f15c01b8 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs @@ -50,14 +50,14 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensor call(Tensor inputs, Tensor training = null) + protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null) { var dtype = inputs.dtype; if (dtype != tf.int32 && dtype != tf.int64) inputs = math_ops.cast(inputs, tf.int32); var @out = embedding_ops.embedding_lookup(embeddings, inputs); - return @out; + return (@out, @out); } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs index 
25161721..46d45862 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs @@ -101,7 +101,7 @@ namespace Tensorflow.Keras.Layers _inbound_nodes = new List(); } - public Tensor __call__(Tensor[] inputs, + public (Tensor, Tensor) __call__(Tensor[] inputs, Tensor training = null, VariableScope scope = null) { @@ -139,14 +139,14 @@ namespace Tensorflow.Keras.Layers // overridden). _maybe_build(inputs[0]); - outputs = call(inputs[0], training: training); + (input, outputs) = call(inputs[0], training: training); (input, outputs) = _set_connectivity_metadata_(input, outputs); _handle_activity_regularization(inputs[0], outputs); _set_mask_metadata(inputs[0], outputs, null); }); } - return outputs; + return (input, outputs); } private (Tensor, Tensor) _set_connectivity_metadata_(Tensor inputs, Tensor outputs) @@ -173,9 +173,9 @@ namespace Tensorflow.Keras.Layers return null; } - protected virtual Tensor call(Tensor inputs, Tensor training = null) + protected virtual (Tensor, Tensor) call(Tensor inputs, Tensor training = null) { - return inputs; + return (inputs, inputs); } protected virtual string _name_scope() diff --git a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs index 9774750a..e9008543 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs @@ -43,7 +43,7 @@ namespace Tensorflow.Keras.Layers this.input_spec = new InputSpec(ndim: 4); } - protected override Tensor call(Tensor inputs, Tensor training = null) + protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null) { int[] pool_shape; if (data_format == "channels_last") @@ -64,7 +64,7 @@ namespace Tensorflow.Keras.Layers padding: padding.ToUpper(), data_format: conv_utils.convert_data_format(data_format, 4)); - return outputs; + return (outputs, outputs); } } } diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs 
b/src/TensorFlowNET.Core/Layers/Layer.cs index 138f0fc7..2ea427c3 100644 --- a/src/TensorFlowNET.Core/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Layers/Layer.cs @@ -47,12 +47,12 @@ namespace Tensorflow.Layers _keras_style = false; } - public virtual Tensor apply(Tensor inputs, Tensor training = null) + public virtual (Tensor, Tensor) apply(Tensor inputs, Tensor training = null) { return __call__(inputs, training: training); } - public Tensor __call__(Tensor inputs, + public (Tensor, Tensor) __call__(Tensor inputs, Tensor training = null, VariableScope scope = null) { @@ -71,7 +71,7 @@ namespace Tensorflow.Layers auxiliary_name_scope: false); } - Tensor outputs = null; + (Tensor, Tensor) outputs = (null, null); tf_with(scope_context_manager, scope2 => { _current_scope = scope2; diff --git a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs b/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs index 554e9f1a..9911212b 100644 --- a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs @@ -16,18 +16,23 @@ using System; using Tensorflow.Keras.Engine; +using static Tensorflow.Binding; namespace Tensorflow { - public class BasicRNNCell : LayerRNNCell + public class BasicRnnCell : LayerRnnCell { int _num_units; Func _activation; public override int state_size => _num_units; public override int output_size => _num_units; + public VariableV1 _kernel; + string _WEIGHTS_VARIABLE_NAME = "kernel"; + public VariableV1 _bias; + string _BIAS_VARIABLE_NAME = "bias"; - public BasicRNNCell(int num_units, + public BasicRnnCell(int num_units, Func activation = null, bool? 
reuse = null, string name = null, @@ -44,5 +49,29 @@ namespace Tensorflow else _activation = activation; } + + protected override void build(TensorShape inputs_shape) + { + var input_depth = inputs_shape.dims[inputs_shape.ndim - 1]; + + _kernel = add_weight( + _WEIGHTS_VARIABLE_NAME, + shape: new[] { input_depth + _num_units, _num_units }); + + _bias = add_weight( + _BIAS_VARIABLE_NAME, + shape: new[] { _num_units }, + initializer: tf.zeros_initializer); + + built = true; + } + + protected override (Tensor, Tensor) call(Tensor inputs, Tensor state = null) + { + // Most basic RNN: output = new_state = act(W * input + U * state + B). + var concat = array_ops.concat(new[] { inputs, state }, 1); + var gate_inputs = math_ops.matmul(concat, _kernel as RefVariable); + return (inputs, inputs); + } } } diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs index 953fd6c7..97f244c4 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs @@ -136,27 +136,6 @@ namespace Tensorflow.Operations graph._set_control_flow_context(this); } - protected virtual Tensor _Enter(Tensor data, string frame_name, - bool is_constant = false, - int parallel_iterations = 10, - bool use_ref = true, - bool use_input_shape = true, - string name = null) - { - Tensor result; - data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref: true); - if (data.dtype.is_ref_dtype() && use_ref) - throw new NotImplementedException("_Enter"); - else - result = gen_control_flow_ops.enter( - data, frame_name, is_constant, parallel_iterations, name: name); - - if (use_input_shape) - result.set_shape(data.TensorShape); - - return result; - } - /// /// Exit this control flow context. 
/// diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs index 845ff494..5359190c 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs @@ -5,7 +5,7 @@ using System.Text; namespace Tensorflow.Operations { - internal class LoopVar : ICanBeFlattened, IPackable + internal class LoopVar : ICanBeFlattened, IPackable> { public Tensor Counter { get; set; } public TItem Item { get; set; } @@ -26,11 +26,13 @@ namespace Tensorflow.Operations return elements.ToArray(); } - public void Pack(object[] sequences) + public LoopVar Pack(object[] sequences) { - Counter = sequences[0] as Tensor; - if (typeof(TItem).GetInterface(typeof(IPackable).Name) != null) - (Item as IPackable).Pack(sequences.Skip(1).ToArray()); + var counter = sequences[0] as Tensor; + var item = default(TItem); + if (typeof(TItem).GetInterface(typeof(IPackable).Name) != null) + item = (Item as IPackable).Pack(sequences.Skip(1).ToArray()); + return new LoopVar(counter, item); } public static implicit operator (Tensor, TItem)(LoopVar loopVar) diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs index 55802fe7..1d071856 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs @@ -184,7 +184,7 @@ namespace Tensorflow.Operations Tensor[] enter_vars = null; tf_with(ops.control_dependencies(null), delegate { - enter_vars = real_vars.Select(x => _Enter(x, + enter_vars = real_vars.Select(x => control_flow_ops._Enter(x, _name, is_constant: false, parallel_iterations: _parallel_iterations, @@ -294,6 +294,10 @@ namespace Tensorflow.Operations } } + /// + /// Makes the values known to this context. 
+ /// + /// private void _InitializeValues(Tensor[] values) { _values = new HashSet(); @@ -303,8 +307,14 @@ namespace Tensorflow.Operations protected override void _AddOpInternal(Operation op) { + if(op.name == "rnn/while/basic_rnn_cell/MatMul" || + op.name == "rnn/while/TensorArrayReadV3") + { + + } + Operation[] external_inputs = new Operation[0]; - if (op == null) + if (op.inputs.Length == 0) { throw new NotImplementedException(""); } @@ -374,6 +384,11 @@ namespace Tensorflow.Operations _AddOpInternal(op); } + /// + /// Add `val` to the current context and its outer context recursively. + /// + /// + /// public override Tensor AddValue(Tensor val) { var result = val; @@ -403,9 +418,9 @@ namespace Tensorflow.Operations // Create an Enter to make `result` known to this loop context. Tensor enter = null; - tf_with(ops.control_dependencies(new ITensorOrOperation[0]), delegate + tf_with(ops.control_dependencies(null), delegate { - enter = _Enter( + enter = control_flow_ops._Enter( result, _name, is_constant: true, diff --git a/src/TensorFlowNET.Core/Operations/LayerRNNCell.cs b/src/TensorFlowNET.Core/Operations/LayerRNNCell.cs index ca9c31bb..16aa147c 100644 --- a/src/TensorFlowNET.Core/Operations/LayerRNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/LayerRNNCell.cs @@ -16,9 +16,9 @@ namespace Tensorflow { - public class LayerRNNCell : RNNCell + public class LayerRnnCell : RnnCell { - public LayerRNNCell(bool? _reuse = null, + public LayerRnnCell(bool? 
_reuse = null, string name = null, TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: _reuse, name: name, diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs index c3d9cbdf..3164ba14 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs @@ -20,8 +20,8 @@ namespace Tensorflow.Operations { public class rnn_cell_impl { - public BasicRNNCell BasicRNNCell(int num_units) - => new BasicRNNCell(num_units); + public BasicRnnCell BasicRNNCell(int num_units) + => new BasicRnnCell(num_units); public static Tensor _concat(Tensor prefix, int suffix, bool @static = false) { diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index 89ddebdb..5700ccdd 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -228,6 +228,15 @@ namespace Tensorflow output_types.AddRange(types); } + // We add an explicit colocation constraint between + // the newly created op and any of its reference-typed inputs. 
+ var must_colocate_inputs = zip(op_def.InputArg, inputs) + .Where(x => x.Item1.IsRef) + .Select(x => x.Item2) + .ToArray(); + + _MaybeColocateWith(must_colocate_inputs); + // Add Op to graph var op = g.create_op(op_type_name, inputs.ToArray(), @@ -241,6 +250,11 @@ namespace Tensorflow }); } + private void _MaybeColocateWith(ITensorOrOperation[] inputs) + { + + } + private void SetAttrs(string op_type_name, ArgDef input_arg, OpDef op_def, diff --git a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs index 5e93cfd0..8e660797 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Control.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs @@ -15,6 +15,7 @@ ******************************************************************************/ using Tensorflow.Operations; +using static Tensorflow.Binding; namespace Tensorflow { diff --git a/src/TensorFlowNET.Core/Operations/Operation.Input.cs b/src/TensorFlowNET.Core/Operations/Operation.Input.cs index 57ac8271..af3c57b2 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Input.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Input.cs @@ -44,14 +44,14 @@ namespace Tensorflow [JsonIgnore] #endif public int NumInputs => c_api.TF_OperationNumInputs(_handle); - private TF_DataType[] _input_types => _inputs._inputs.Select(x => x.dtype).ToArray(); + private TF_DataType[] _input_types => _inputs_val._inputs.Select(x => x.dtype).ToArray(); - private InputList _inputs; + private InputList _inputs_val; public InputList inputs { get { - if (_inputs == null) + if (_inputs_val == null) { var retval = new Tensor[NumInputs]; @@ -62,10 +62,10 @@ namespace Tensorflow retval[i] = op.outputs[tf_output.index]; } - _inputs = new InputList(retval); + _inputs_val = new InputList(retval); } - return _inputs; + return _inputs_val; } } diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index db001e51..8bdaaa7b 
100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -175,8 +175,8 @@ namespace Tensorflow // Dict mapping op name to file and line information for op colocation // context managers. - _control_flow_context = graph._get_control_flow_context(); - + _control_flow_context = graph._get_control_flow_context(); + // This will be set by self.inputs. if (op_def == null) op_def = g.GetOpDef(node_def.Op); @@ -305,7 +305,7 @@ namespace Tensorflow var output = tensor._as_tf_output(); // Reset cached inputs. - _inputs = null; + _inputs_val = null; // after the c_api call next time _inputs is accessed // the updated inputs are reloaded from the c_api lock (Locks.ProcessWide) diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs index 1229c6b7..13182dfd 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs @@ -675,5 +675,36 @@ namespace Tensorflow throw new NotImplementedException("while_loop"); } + /// + /// Creates or finds a child frame, and makes `data` available to it. 
+ /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor _Enter(Tensor data, string frame_name, + bool is_constant = false, + int parallel_iterations = 10, + bool use_ref = true, + bool use_input_shape = true, + string name = null) + { + Tensor result; + data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref: true); + if (data.dtype.is_ref_dtype() && use_ref) + throw new NotImplementedException("_Enter"); + else + result = gen_control_flow_ops.enter( + data, frame_name, is_constant, parallel_iterations, name: name); + + if (use_input_shape) + result.set_shape(data.TensorShape); + + return result; + } } } diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index e1225cc9..08431089 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -568,7 +568,7 @@ namespace Tensorflow { var _op = _op_def_lib._apply_op_helper("MatMul", name, args: new { a, b, transpose_a, transpose_b }); - return _op.outputs[0]; + return _op.output; } /// diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index d4dfc12b..17cd8a99 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -543,6 +543,23 @@ namespace Tensorflow public static Tensor maximum(Tx x, Ty y, string name = null) => gen_math_ops.maximum(x, y, name: name); + /// + /// Multiplies matrix `a` by matrix `b`, producing `a` * `b`. + /// + /// + /// + /// If `True`, `a` is transposed before multiplication. + /// If `True`, `b` is transposed before multiplication. + /// If `True`, `a` is conjugated and transposed before multiplication. + /// If `True`, `b` is conjugated and transposed before multiplication. + /// If `True`, `a` is treated as a sparse matrix. + /// If `True`, `b` is treated as a sparse matrix. + /// Name for the operation (optional). 
+ /// + /// A `Tensor` of the same type as `a` and `b` where each inner-most matrix is + /// the product of the corresponding matrices in `a` and `b`, e.g. if all + /// transpose or adjoint attributes are `False`: + /// public static Tensor matmul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false, bool adjoint_a = false, bool adjoint_b = false, diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs index 97980203..54149fe1 100644 --- a/src/TensorFlowNET.Core/Util/nest.py.cs +++ b/src/TensorFlowNET.Core/Util/nest.py.cs @@ -402,11 +402,8 @@ namespace Tensorflow.Util private static int len(IEnumerable x) => x.Count(); public static T pack_sequence_as2(T structure, object[] flat_sequence, bool expand_composites = false) - where T : IPackable - { - structure.Pack(flat_sequence); - return structure; - } + where T : IPackable + => structure.Pack(flat_sequence); /// /// Returns a given flattened sequence packed into a given structure. diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index d1e423c9..3549b07e 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -508,6 +508,8 @@ namespace Tensorflow return null; case TensorShape ts: return constant_op.constant(ts.dims, dtype: dtype, name: name); + case int[] dims: + return constant_op.constant(dims, dtype: dtype, name: name); case object[] objects: return array_ops._autopacking_conversion_function(objects, dtype: dtype, name: name); default: diff --git a/src/TensorFlowNET.Core/ops.name_scope.cs b/src/TensorFlowNET.Core/ops.name_scope.cs index bd98f2ca..80397667 100644 --- a/src/TensorFlowNET.Core/ops.name_scope.cs +++ b/src/TensorFlowNET.Core/ops.name_scope.cs @@ -45,7 +45,10 @@ namespace Tensorflow public void __enter__() { _name = _name ?? 
_default_name; + if (_name.EndsWith("basic_r_n_n_cell")) + { + } Graph g = null; if (_values is List vList) From 5ee46e494a48af03832df2ce03c41692a836bee6 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 27 Oct 2019 22:03:35 -0500 Subject: [PATCH 23/41] tf.while_loop #348 --- src/TensorFlowNET.Core/Device/c_api.device.cs | 32 ++++++ .../Graphs/_ControlDependenciesController.cs | 4 +- .../Keras/Layers/BatchNormalization.cs | 2 +- src/TensorFlowNET.Core/Keras/Layers/Conv.cs | 2 +- src/TensorFlowNET.Core/Keras/Layers/Dense.cs | 2 +- .../Keras/Layers/Embedding.cs | 2 +- src/TensorFlowNET.Core/Keras/Layers/Layer.cs | 14 ++- .../Keras/Layers/Pooling2D.cs | 2 +- src/TensorFlowNET.Core/Layers/Layer.cs | 11 ++- .../Operations/BasicRNNCell.cs | 2 +- .../Operations/ControlFlows/WhileContext.cs | 12 +-- .../Initializers/VarianceScaling.cs | 1 + .../Operations/NnOps/rnn.cs | 2 +- .../Operations/Operation.Control.cs | 2 - .../Operations/Operation.cs | 5 + .../Operations/random_ops.py.cs | 2 +- src/TensorFlowNET.Core/Tensors/tensor_util.cs | 5 + .../Variables/RefVariable.cs | 99 ++++++++++--------- src/TensorFlowNET.Core/ops.cs | 29 ++++-- 19 files changed, 150 insertions(+), 80 deletions(-) create mode 100644 src/TensorFlowNET.Core/Device/c_api.device.cs diff --git a/src/TensorFlowNET.Core/Device/c_api.device.cs b/src/TensorFlowNET.Core/Device/c_api.device.cs new file mode 100644 index 00000000..2ce79a3e --- /dev/null +++ b/src/TensorFlowNET.Core/Device/c_api.device.cs @@ -0,0 +1,32 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Runtime.InteropServices; + +namespace Tensorflow +{ + public partial class c_api + { + /// + /// Specify the device for `desc`. Defaults to empty, meaning unconstrained. + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern void TF_SetDevice(IntPtr desc, string device); + } +} diff --git a/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs index 55e321df..63285bae 100644 --- a/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs +++ b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs @@ -69,7 +69,9 @@ namespace Tensorflow _new_stack = false; } - _seen_nodes = new List(); + _seen_nodes = new List(); + _old_stack = null; + _old_control_flow_context = null; } public void add_op(ITensorOrOperation op) diff --git a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs index 57311e8b..9b42eaaa 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs @@ -139,7 +139,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null) + protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) { Tensor outputs = null; diff --git a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs 
b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs index 6a7c58cc..ad233d6b 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs @@ -108,7 +108,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null) + protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) { var outputs = _convolution_op.__call__(inputs, kernel); if (use_bias) diff --git a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs index 212035cb..74778873 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs @@ -72,7 +72,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null) + protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) { Tensor outputs = null; var rank = inputs.rank; diff --git a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs index f15c01b8..95544d36 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs @@ -50,7 +50,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null) + protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) { var dtype = inputs.dtype; if (dtype != tf.int32 && dtype != tf.int64) diff --git a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs index 46d45862..d7d7e31a 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs @@ -52,6 +52,7 @@ namespace Tensorflow.Keras.Layers protected InputSpec input_spec; protected bool supports_masking; protected List 
_trainable_weights; + protected List _non_trainable_weights; private string _name; public string name => _name; protected string _base_name; @@ -84,6 +85,7 @@ namespace Tensorflow.Keras.Layers _init_set_name(name); _trainable_weights = new List(); + _non_trainable_weights = new List(); _compute_previous_mask = false; _updates = new List(); @@ -103,6 +105,7 @@ namespace Tensorflow.Keras.Layers public (Tensor, Tensor) __call__(Tensor[] inputs, Tensor training = null, + Tensor state = null, VariableScope scope = null) { var input_list = inputs; @@ -139,7 +142,9 @@ namespace Tensorflow.Keras.Layers // overridden). _maybe_build(inputs[0]); - (input, outputs) = call(inputs[0], training: training); + (input, outputs) = call(inputs[0], + training: training, + state: state); (input, outputs) = _set_connectivity_metadata_(input, outputs); _handle_activity_regularization(inputs[0], outputs); _set_mask_metadata(inputs[0], outputs, null); @@ -173,7 +178,7 @@ namespace Tensorflow.Keras.Layers return null; } - protected virtual (Tensor, Tensor) call(Tensor inputs, Tensor training = null) + protected virtual (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) { return (inputs, inputs); } @@ -233,7 +238,10 @@ namespace Tensorflow.Keras.Layers initializer: initializer, trainable: trainable.Value); //backend.track_variable(variable); - _trainable_weights.Add(variable); + if (trainable == true) + _trainable_weights.Add(variable); + else + _non_trainable_weights.Add(variable); return variable; } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs index e9008543..81d57abe 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs @@ -43,7 +43,7 @@ namespace Tensorflow.Keras.Layers this.input_spec = new InputSpec(ndim: 4); } - protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null) + protected override (Tensor, 
Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) { int[] pool_shape; if (data_format == "channels_last") diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs index 2ea427c3..d7cda786 100644 --- a/src/TensorFlowNET.Core/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Layers/Layer.cs @@ -43,6 +43,7 @@ namespace Tensorflow.Layers // Avoid an incorrect lint error _trainable_weights = new List(); + _non_trainable_weights = new List(); this.built = false; _keras_style = false; } @@ -54,6 +55,7 @@ namespace Tensorflow.Layers public (Tensor, Tensor) __call__(Tensor inputs, Tensor training = null, + Tensor state = null, VariableScope scope = null) { _set_scope(scope); @@ -76,7 +78,9 @@ namespace Tensorflow.Layers { _current_scope = scope2; // Actually call layer - outputs = base.__call__(new Tensor[] { inputs }, training: training); + outputs = base.__call__(new Tensor[] { inputs }, + training: training, + state: state); }); @@ -121,6 +125,11 @@ namespace Tensorflow.Layers Graph init_graph = null; VariableV1[] existing_variables = null; + if (synchronization == VariableSynchronization.OnRead) + trainable = false; + else if (!trainable.HasValue) + trainable = true; + if (default_graph.building_function) { throw new NotImplementedException("add_weight"); diff --git a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs b/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs index 9911212b..fdcc03ea 100644 --- a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs @@ -66,7 +66,7 @@ namespace Tensorflow built = true; } - protected override (Tensor, Tensor) call(Tensor inputs, Tensor state = null) + protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) { // Most basic RNN: output = new_state = act(W * input + U * state + B). 
var concat = array_ops.concat(new[] { inputs, state }, 1); diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs index 1d071856..715c68c6 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs @@ -307,12 +307,6 @@ namespace Tensorflow.Operations protected override void _AddOpInternal(Operation op) { - if(op.name == "rnn/while/basic_rnn_cell/MatMul" || - op.name == "rnn/while/TensorArrayReadV3") - { - - } - Operation[] external_inputs = new Operation[0]; if (op.inputs.Length == 0) { @@ -412,10 +406,12 @@ namespace Tensorflow.Operations } if (_outer_context != null) - { result = _outer_context.AddValue(val); - } + if (tf.get_default_graph()._nodes_by_name.Count >= 83) + { + + } // Create an Enter to make `result` known to this loop context. Tensor enter = null; tf_with(ops.control_dependencies(null), delegate diff --git a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs index e1ac0204..636b1451 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs @@ -16,6 +16,7 @@ using System; using System.Linq; +using static Tensorflow.Binding; namespace Tensorflow.Operations.Initializers { diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs index 41a4622a..a8a0e0b9 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs @@ -214,7 +214,7 @@ namespace Tensorflow.Operations if (sequence_length != null) throw new NotImplementedException("sequence_length != null"); else - a = cell.__call__(input_t_t, state1); + a = cell.__call__(input_t_t, state: state1); return item; }; diff --git 
a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs index 8e660797..9f0cb9a5 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Control.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs @@ -32,9 +32,7 @@ namespace Tensorflow public void _control_flow_post_processing() { foreach(Tensor input_tensor in inputs) - { control_flow_util.CheckInputFromValidContext(this, input_tensor.op); - } if (_control_flow_context != null) _control_flow_context.AddOp(this); diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 8bdaaa7b..d5068f2e 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -78,6 +78,7 @@ namespace Tensorflow #if SERIALIZABLE [JsonIgnore] #endif + bool _is_stateful; public NodeDef node_def { get @@ -173,6 +174,8 @@ namespace Tensorflow } } + _id_value = _graph._next_id(); + // Dict mapping op name to file and line information for op colocation // context managers. _control_flow_context = graph._get_control_flow_context(); @@ -184,6 +187,8 @@ namespace Tensorflow var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); + _is_stateful = op_def.IsStateful; + // Initialize self._outputs. 
output_types = new TF_DataType[NumOutputs]; for (int i = 0; i < NumOutputs; i++) diff --git a/src/TensorFlowNET.Core/Operations/random_ops.py.cs b/src/TensorFlowNET.Core/Operations/random_ops.py.cs index 9251f867..be4aef55 100644 --- a/src/TensorFlowNET.Core/Operations/random_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/random_ops.py.cs @@ -71,7 +71,7 @@ namespace Tensorflow return tf_with(ops.name_scope(name, "random_uniform", new { shape, minval, maxval }), scope => { name = scope; - var tensorShape = _ShapeTensor(shape); + var tensorShape = tensor_util.shape_tensor(shape); var minTensor = ops.convert_to_tensor(minval, dtype: dtype, name: "min"); var maxTensor = ops.convert_to_tensor(maxval, dtype: dtype, name: "max"); var rnd = gen_random_ops.random_uniform(tensorShape, dtype); diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 142afe06..0989db4f 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -335,5 +335,10 @@ namespace Tensorflow return shape; } + + public static Tensor shape_tensor(int[] shape) + { + return ops.convert_to_tensor(shape, dtype: TF_DataType.TF_INT32, name: "shape"); + } } } diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 4b0a35fb..c79c5b7f 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -133,66 +133,69 @@ namespace Tensorflow if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES)) collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES); - ops.init_scope(); - var values = init_from_fn ? new object[0] : new object[] { initial_value }; - tf_with(ops.name_scope(name, "Variable", values), scope => + tf_with(ops.init_scope2(), delegate { - name = scope; - if (init_from_fn) + var values = init_from_fn ? 
new object[0] : new object[] { initial_value }; + tf_with(ops.name_scope(name, "Variable", values), scope => { - // Use attr_scope and device(None) to simulate the behavior of - // colocate_with when the variable we want to colocate with doesn't - // yet exist. - string true_name = ops.name_from_scope_name(name); - var attr = new AttrValue + name = scope; + + if (init_from_fn) { - List = new AttrValue.Types.ListValue() - }; - attr.List.S.Add(ByteString.CopyFromUtf8($"loc:{true_name}")); - tf_with(ops.name_scope("Initializer"), scope2 => + // Use attr_scope and device(None) to simulate the behavior of + // colocate_with when the variable we want to colocate with doesn't + // yet exist. + string true_name = ops.name_from_scope_name(name); + var attr = new AttrValue + { + List = new AttrValue.Types.ListValue() + }; + attr.List.S.Add(ByteString.CopyFromUtf8($"loc:{true_name}")); + tf_with(ops.name_scope("Initializer"), scope2 => + { + _initial_value = (initial_value as Func)(); + _initial_value = ops.convert_to_tensor(_initial_value, name: "initial_value", dtype: dtype); + }); + _variable = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name); + } + // Or get the initial value from a Tensor or Python object. + else { - _initial_value = (initial_value as Func)(); - _initial_value = ops.convert_to_tensor(_initial_value, name: "initial_value", dtype: dtype); - }); - _variable = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name); - } - // Or get the initial value from a Tensor or Python object. 
- else - { - _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value", dtype: dtype); + _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value", dtype: dtype); - var shape = _initial_value.shape; - dtype = _initial_value.dtype; - _variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), scope); - } + var shape = _initial_value.shape; + dtype = _initial_value.dtype; + _variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), scope); + } - // Manually overrides the variable's shape with the initial value's. - if (validate_shape) - { - var initial_value_shape = _initial_value.TensorShape; - if (!initial_value_shape.is_fully_defined()) - throw new ValueError($"initial_value must have a shape specified: {_initial_value}"); - } + // Manually overrides the variable's shape with the initial value's. + if (validate_shape) + { + var initial_value_shape = _initial_value.TensorShape; + if (!initial_value_shape.is_fully_defined()) + throw new ValueError($"initial_value must have a shape specified: {_initial_value}"); + } - // If 'initial_value' makes use of other variables, make sure we don't - // have an issue if these other variables aren't initialized first by - // using their initialized_value() method. - var _initial_value2 = _try_guard_against_uninitialized_dependencies(name, _initial_value); + // If 'initial_value' makes use of other variables, make sure we don't + // have an issue if these other variables aren't initialized first by + // using their initialized_value() method. 
+ var _initial_value2 = _try_guard_against_uninitialized_dependencies(name, _initial_value); - _initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op; + _initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op; - if (!String.IsNullOrEmpty(caching_device)) - { + if (!String.IsNullOrEmpty(caching_device)) + { - } - else - { - ops.colocate_with(_initializer_op); + } + else + { + ops.colocate_with(_initializer_op); - _snapshot = gen_array_ops.identity(_variable, name = "read"); - } + _snapshot = gen_array_ops.identity(_variable, name = "read"); + } - ops.add_to_collections(collections, this as VariableV1); + ops.add_to_collections(collections, this as VariableV1); + }); }); } diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 3549b07e..02417594 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -186,12 +186,7 @@ namespace Tensorflow /// operations constructed within the context. /// public static _ControlDependenciesController control_dependencies(object[] control_inputs) - { - return get_default_graph().control_dependencies(control_inputs); - } - - public static _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) - => control_dependencies(control_inputs == null ? null : control_inputs.OfType().ToArray()); + => get_default_graph().control_dependencies(control_inputs); /// /// Creates a TF_Operation. 
@@ -212,9 +207,9 @@ namespace Tensorflow { var op_desc = graph.NewOperation(node_def.Op, node_def.Name); - //TODO: Implement TF_SetDevice - //if node_def.device: - // c_api.TF_SetDevice(op_desc, compat.as_str(node_def.device)) + if (!string.IsNullOrEmpty(node_def.Device)) + c_api.TF_SetDevice(op_desc, node_def.Device); + // Add inputs foreach (var op_input in inputs) { @@ -310,6 +305,22 @@ namespace Tensorflow }); } + public static IObjectLife init_scope2() + { + // Retrieve the active name scope: entering an `init_scope` preserves + // the name scope of the current context. + var default_graph = get_default_graph(); + var scope = default_graph.get_name_scope(); + if (!String.IsNullOrEmpty(scope) && !scope.EndsWith("/")) + // Names that end with trailing slashes are treated by `name_scope` as + // absolute. + scope += "/"; + // inner_device_stack = default_graph._device_function_stack + // var outer_context = default_graph.as_default; + + return ops.control_dependencies(null); + } + private static int uid_number = 0; /// From e2190c94fc94533f90208027b07825bc1be2f1ea Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 2 Nov 2019 08:01:43 -0500 Subject: [PATCH 24/41] ControlFlow MergeOutput --- .../ControlFlows/ControlFlowContext.cs | 28 +- .../ControlFlows/ControlFlowState.cs | 266 ++++++++++------- .../Operations/ControlFlows/GradLoopState.cs | 270 ++++++------------ .../Operations/ControlFlows/MergeOutput.cs | 36 +++ .../Operations/ControlFlows/WhileContext.cs | 168 +++++++++-- 5 files changed, 459 insertions(+), 309 deletions(-) create mode 100644 src/TensorFlowNET.Core/Operations/ControlFlows/MergeOutput.cs diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs index 97f244c4..00c395a3 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs @@ -20,6 +20,7 @@ using 
System.Linq; using Tensorflow.Operations.ControlFlows; using static Tensorflow.ControlFlowContextDef; using static Tensorflow.Binding; +using util = Tensorflow.control_flow_util; namespace Tensorflow.Operations { @@ -146,6 +147,14 @@ namespace Tensorflow.Operations graph._set_control_flow_context(last_context); } + public void ExitResult(Tensor[] result) + { + if(_outer_context != null) + { + throw new NotImplementedException("ExitResult"); + } + } + /// /// Add `op` to the current context. /// @@ -172,6 +181,11 @@ namespace Tensorflow.Operations return null; } + public void AddName(string name) + { + _values.Add(name); + } + /// /// Notifies a scope about an operator added to an inner scope. /// @@ -246,9 +260,11 @@ namespace Tensorflow.Operations } else { - foreach(Tensor x in op.control_inputs) + foreach(Operation x in op.control_inputs) { - throw new NotImplementedException(""); + var ctxt = util.GetOutputContext(x); + if (ctxt != null && ctxt.GetWhileContext() == while_ctxt) + internal_control_inputs.append(x); } } @@ -288,6 +304,14 @@ namespace Tensorflow.Operations throw new NotImplementedException($"Unknown ControlFlowContextDef field: {context_def.CtxtCase}"); } + public virtual bool IsWhileContext() + { + throw new NotImplementedException("IsWhileContext"); + } + + public virtual bool IsCondContext() + => false; + public object to_proto() { throw new NotImplementedException(); diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs index fd41c045..d1be6f31 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs @@ -14,6 +14,12 @@ limitations under the License. 
******************************************************************************/ +using System; +using System.Linq; +using System.Collections.Generic; +using util = Tensorflow.control_flow_util; +using static Tensorflow.Binding; + namespace Tensorflow.Operations.ControlFlows { /// @@ -21,6 +27,7 @@ namespace Tensorflow.Operations.ControlFlows /// public class ControlFlowState { + Dictionary _map; //class ControlFlowState(object): // """Maintain the mapping from the loops to their grad states.""" @@ -40,51 +47,67 @@ namespace Tensorflow.Operations.ControlFlows // return self._map.get(forward_ctxt) // return None - // def ProcessUnusedLoopExits(self, pending_count, to_ops_set): - // """Process all the "unused" loop exits. - - // The "unused" exits of the loops are added to `unused_exits`. An exit is - // unused if its pending_count is 0. If there is an exit with real gradient, - // all these deferred exits will enter the backprop loop with zero gradient. - // Otherwise, they will enter the backprop loop with None. As an example, - // people often write: - - // ```python - // v1, _ = tf.while_loop(p, b, [x1, x2]) - // result = gradients(v1, x1) - // ``` - - // The exit node for x2 is not included by the betweenness analysis. But we - // need to backprop x2 if x2 is involved in computing v1. - - // Args: - // pending_count: The number of backprop inputs for every op. - // to_ops_set: The set of ops for ys in gradients(ys, xs) - - // Returns: - // The set of unused loop exits that we know at this point we need - // to backprop. - // """ - // loop_exits = [] - // for grad_state in self._map.values(): - // for y in grad_state.forward_loop_exits: - // if pending_count[y.op] == 0: - // grad_state.pending_exits_count -= 1 - // if y.op not in to_ops_set: - // grad_state.unused_exits.append(y) - // if grad_state.pending_exits_count == 0: - // loop_exits.extend(grad_state.unused_exits) - // # Need to include Enters in backprop for higher-order gradients. 
- // for y in grad_state.forward_context.loop_enters: - // if pending_count[y.op] == 0: - // pending_count[y.op] = 1 - // return loop_exits - - // def EnterGradWhileContext(self, op, before): - // """Enter the WhileContext for gradient computation.""" - // grad_state = self.GetGradState(op, before) - // if grad_state: - // grad_state.grad_context.Enter() + public ControlFlowState() + { + _map = new Dictionary(); + } + + /// + /// Return the grad state for this op if it's in a forward loop context. + /// + /// + /// + /// + public GradLoopState GetGradState(Operation op, bool before) + { + ControlFlowContext forward_ctxt = null; + if (before && util.IsLoopExit(op)) + { + forward_ctxt = op._get_control_flow_context(); + forward_ctxt = forward_ctxt.outer_context; + if (forward_ctxt != null) + forward_ctxt = forward_ctxt.GetWhileContext(); + } + else + forward_ctxt = util.GetWhileContext(op); + if (forward_ctxt != null) + return _map.get(forward_ctxt); + return null; + } + + public Tensor[] ProcessUnusedLoopExits(Dictionary pending_count, List to_ops_set) + { + var loop_exits = new List(); + foreach(var grad_state in _map.Values) + { + foreach(var y in grad_state.forward_loop_exits) + { + if(!pending_count.ContainsKey(y.op.name)) + { + grad_state.pending_exits_count -= 1; + if (!to_ops_set.Contains(y.op)) + grad_state.unused_exits.append(y); + if (grad_state.pending_exits_count == 0) + loop_exits.extend(grad_state.unused_exits); + } + } + + foreach(var y in grad_state.forward_context.loop_enters) + { + if (!pending_count.ContainsKey(y.op.name)) + pending_count[y.op.name] = 1; + } + } + + return loop_exits.ToArray(); + } + + public void EnterGradWhileContext(Operation op, bool before) + { + var grad_state = GetGradState(op, before); + if (grad_state != null) + grad_state.grad_context.Enter(); + } // def ExitGradWhileContext(self, op, before): // """Exit the WhileContext for gradient computation.""" @@ -118,6 +141,32 @@ namespace Tensorflow.Operations.ControlFlows // if 
loop_exit.op not in between_ops: // between_ops.add(loop_exit.op) // between_op_list.append(loop_exit.op) + public void AddWhileContext(Operation op, List between_op_list, List between_ops) + { + var forward_ctxt = op.GetWhileContext(); + var grad_state = _map.ContainsKey(forward_ctxt) ? _map[forward_ctxt] : null; + if(grad_state == null) + { + GradLoopState outer_grad_state = null; + var outer_forward_ctxt = forward_ctxt.outer_context; + if (outer_forward_ctxt != null) + outer_forward_ctxt = outer_forward_ctxt.GetWhileContext(); + if (outer_forward_ctxt != null) + outer_grad_state = _map[outer_forward_ctxt]; + grad_state = new GradLoopState(forward_ctxt, outer_grad_state); + _map[forward_ctxt] = grad_state; + + // We need to include all exits of a loop for backprop. + foreach (var loop_exit in grad_state.forward_loop_exits) + { + if(!between_ops.Contains(loop_exit.op)) + { + between_ops.add(loop_exit.op); + between_op_list.append(loop_exit.op); + } + } + } + } // def ZerosLikeForExit(self, val): // """Create zeros_like gradient for a loop exit. @@ -174,70 +223,69 @@ namespace Tensorflow.Operations.ControlFlows // result = array_ops.zeros_like(val, optimize=False) // return result - // def ZerosLike(self, op, index): - // """Create zeros_like for the specified output of an op. - - // If op is in a while loop that is part of gradients(), this method - // must be called in its grad loop context. - - // Args: - // op: A tensorflow operation. - // index: the index for a specific output of the op. - - // Returns: - // A zero tensor of the same shape of op.outputs[index]. 
- // """ - // if util.IsLoopSwitch(op): - // return None - // if op.graph._building_function: # pylint: disable=protected-access - // # The optimization here is tricky to apply to functions - // return array_ops.zeros_like(op.outputs[index]) - // dead_branch = util.IsSwitch(op) - // forward_ctxt = _GetWhileContext(op) - // grad_state = self._map.get(forward_ctxt) - // if grad_state is None: - // # op is not in a while loop that is part of gradients(). - // return ZerosLikeOutsideLoop(op, index) - // op_ctxt = op._get_control_flow_context() - // val = ops.convert_to_tensor(op.outputs[index], name="tensor") - // shape = val.get_shape() - // if shape.is_fully_defined(): - // # If the shape is known statically, just create a zero tensor with - // # the right shape in the grad loop context. - // result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype) - // if dead_branch: - // # op is a cond switch. Guard the zero tensor with a switch. - // pred = grad_state.history_map.get(op_ctxt.pred.name) - // branch = op_ctxt.branch - // result = _SwitchRefOrTensor(result, pred)[1 - branch] - // else: - // # Unknown shape so keep a history of the shape at runtime. - // if dead_branch: - // # Need to add a special switch to guard the value. - // pred = op_ctxt.pred - // branch = op_ctxt.branch - // op_ctxt.outer_context.Enter() - // val = _SwitchRefOrTensor(op.inputs[0], pred)[1 - branch] - // zeros_shape = array_ops.shape_internal(val, optimize=False) - // op_ctxt.outer_context.Exit() - // val.op._set_control_flow_context(op_ctxt) - // zeros_shape.op._set_control_flow_context(op_ctxt) - // else: - // op_ctxt.Enter() - // zeros_shape = array_ops.shape_internal(val, optimize=False) - // op_ctxt.Exit() - - // # Add forward accumulator for shape. - // grad_state.grad_context.Exit() - // history_zeros_shape = grad_state.AddForwardAccumulator( - // zeros_shape, dead_branch=dead_branch) - // grad_state.grad_context.Enter() - - // # Create a zero tensor with the right shape. 
- // shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape, - // zeros_shape, dead_branch) - // result = array_ops.zeros(shape, val.dtype) - // return result + public Tensor ZerosLike(Operation op, int index) + { + if (util.IsLoopSwitch(op)) + return null; + if (op.graph.building_function) + return array_ops.zeros_like(op.outputs[index]); + var dead_branch = util.IsSwitch(op); + var forward_ctxt = util.GetWhileContext(op); + var grad_state = _map.get(forward_ctxt); + // op is not in a while loop that is part of gradients(). + if (grad_state == null) + return ZerosLikeOutsideLoop(op, index); + throw new NotImplementedException("ZerosLike"); + } + + public Tensor ZerosLikeOutsideLoop(Operation op, int index) + { + var val = op.outputs[index]; + if (!util.IsSwitch(op)) + { + if (val.dtype == dtypes.resource) + throw new NotImplementedException("ZerosLikeOutsideLoop"); + /*return array_ops.zeros( + gen_resource_variable_ops.variable_shape(val), + dtype: default_gradient.get_zeros_dtype(val));*/ + return array_ops.zeros_like(val, optimize: false); + } + else + throw new NotImplementedException("ZerosLikeOutsideLoop"); + } + + /// + /// Create zeros_like gradient for a loop exit. + /// + /// + /// + public Tensor ZerosLikeForExit(Tensor val) + { + Tensor result = null; + var val_shape = val.TensorShape; + var forward_ctxt = val.op._get_control_flow_context(); + var outer_forward_ctxt = forward_ctxt.outer_context; + if (outer_forward_ctxt != null) + outer_forward_ctxt = outer_forward_ctxt.GetWhileContext(); + GradLoopState outer_grad_state = null; + if (outer_forward_ctxt != null) + outer_grad_state = _map.get(outer_forward_ctxt); + // This is a nested loop. + if (outer_grad_state != null) + { + throw new NotImplementedException("ZerosLikeForExit"); + } + else + { + // If the shape is known statically, just create a zero tensor + // with the right shape. 
+ if (val_shape.is_fully_defined()) + result = array_ops.zeros(val_shape.dims, val.dtype); + else + result = array_ops.zeros_like(val, optimize: false); + } + return result; + } // def PostProcessing(self): // """Perform postprocessing at the end of gradients(). diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs index a37a6566..e17ab8ba 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs @@ -16,41 +16,16 @@ using System; using System.Collections; +using System.Collections.Generic; +using static Tensorflow.Binding; namespace Tensorflow.Operations.ControlFlows { + /// + /// The state used for constructing the gradient graph for a while loop. + /// public class GradLoopState { - - //class GradLoopState(object): - // """The state used for constructing the gradient graph for a while loop. - - // We create a GradLoopState for each while loop in forward and its - // corresponding while loop in backprop. This gives us access to both - // the forward and the backprop WhileContexts. - - // During the construction of gradient graph, any time when we detect - // a forward value that is needed for backprop, we create a history - // accumulator and add it to `history_map`. Any time when we backprop - // a loop switch op (in _SwitchGrad), we add the grad merge op in - // `switch_map`. - // """ - - // def __init__(self, forward_ctxt, outer_grad_state): - // # The grad loop state for the outer while loop. - // self._outer_grad_state = None - - // # The while loop context for forward. - // self._forward_context = None - - // # The loop counter added by AddForwardLoopCounter. It is the value - // # of the loop counter for the next iteration. - // self._forward_index = None - - // # A sync op for forward. - // self._forward_sync = None - - // # The while loop context for backprop. 
private WhileContext _grad_context = null; public WhileContext grad_context => _grad_context; @@ -65,156 +40,91 @@ namespace Tensorflow.Operations.ControlFlows // # Information needed by backprop. private Hashtable _history_map = new Hashtable(); public Hashtable history_map => _history_map; - private Hashtable _switch_map = new Hashtable(); - public Hashtable switch_map => _switch_map; - // self._unused_exits = [] - // self._deferred_exits = [] - // self._forward_loop_exits = list(forward_ctxt.loop_exits) - // self._pending_exits_count = len(forward_ctxt.loop_exits) - - // self._outer_grad_state = outer_grad_state - // if outer_grad_state: - // outer_forward_ctxt = outer_grad_state.forward_context - // else: - // if not hasattr(forward_ctxt, "outer_context"): - // raise ValueError("Failed to call gradients on a while loop without" - // "properly serializing graph via MetaGraphDef") - // outer_forward_ctxt = forward_ctxt.outer_context - - // # Add the forward loop counter. - // with forward_ctxt._graph.as_default(): # pylint: disable=protected-access - // if outer_forward_ctxt: - // outer_forward_ctxt.Enter() - // cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state) - // if outer_forward_ctxt: - // outer_forward_ctxt.Exit() - // self._forward_context = forward_ctxt - // self._forward_index = forward_index - - // # Add the backprop WhileContext, and the backprop loop counter. - // if outer_grad_state: - // # This is a nested loop. Remember the iteration counts for each - // # execution of this inner loop. 
- // outer_forward_ctxt.AddName(cnt.name) - // history_cnt = outer_grad_state.AddForwardAccumulator(cnt) - - // outer_grad_ctxt = outer_grad_state.grad_context - // outer_grad_ctxt.Enter() - // self._grad_context = WhileContext( - // maximum_iterations=forward_ctxt.maximum_iterations, - // parallel_iterations=forward_ctxt.parallel_iterations, - // back_prop=forward_ctxt.back_prop, - // swap_memory=forward_ctxt.swap_memory, - // name=forward_ctxt.name, - // grad_state=self) - // real_cnt = outer_grad_state.AddBackpropAccumulatedValue(history_cnt, cnt) - // self._grad_index = self._grad_context.AddBackpropLoopCounter( - // real_cnt, outer_grad_state) - // outer_grad_ctxt.Exit() - // else: - // if outer_forward_ctxt: - // outer_forward_ctxt.Enter() - // self._grad_context = WhileContext( - // maximum_iterations=forward_ctxt.maximum_iterations, - // parallel_iterations=forward_ctxt.parallel_iterations, - // back_prop=forward_ctxt.back_prop, - // swap_memory=forward_ctxt.swap_memory, - // name=forward_ctxt.name, - // grad_state=self) - // self._grad_index = self._grad_context.AddBackpropLoopCounter( - // cnt, outer_grad_state) - // if outer_forward_ctxt: - // outer_forward_ctxt.Exit() - - // @property - // def outer_grad_state(self): - // """The grad loop state for outer loop.""" - // return self._outer_grad_state - - // @property - // def forward_context(self): - // """The while loop context for forward.""" - // return self._forward_context - - // @property - // def forward_index(self): - // """The loop index of forward loop.""" - // return self._forward_index - - // @property - // def forward_sync(self): - // """A control trigger node for synchronization in the forward loop. - - // One main use is to keep the push ops of a stack executed in the - // iteration order. 
- // """ - // if self._forward_sync is None: - // with ops.control_dependencies(None): - // self._forward_sync = control_trigger(name="f_sync") - // self._forward_sync._set_control_flow_context(self._forward_context) - // self._forward_index.op._add_control_input(self._forward_sync) - // return self._forward_sync - - // @property - // def grad_context(self): - // """The corresponding WhileContext for gradient.""" - // return self._grad_context - - // @property - // def grad_index(self): - // """The loop index of backprop loop.""" - // return self._grad_index - - // @property - // def grad_sync(self): - // """A control trigger node for synchronization in the grad loop. - - // One main use is to keep the pop ops of a stack executed in the - // iteration order. - // """ - // if self._grad_sync is None: - // with ops.control_dependencies(None): - // self._grad_sync = control_trigger(name="b_sync") - // self._grad_sync._set_control_flow_context(self._grad_context) - // self._grad_index.op._add_control_input(self._grad_sync) - // if self._grad_context.outer_context: - // self._grad_context.outer_context.AddInnerOp(self._grad_sync) - // return self._grad_sync - - // @property - // def history_map(self): - // """The map that records all the tensors needed for backprop.""" - // return self._history_map - - // @property - // def switch_map(self): - // """The map that records all the Switch ops for the while loop.""" - // return self._switch_map - - // @property - // def unused_exits(self): - // """The list of "unused" exits.""" - // return self._unused_exits - - // @property - // def deferred_exits(self): - // """The list of "deferred" exits.""" - // return self._deferred_exits - - // @property - // def forward_loop_exits(self): - // """The list of exits of the forward loop.""" - // return self._forward_loop_exits - - // @property - // def pending_exits_count(self): - // """The number of exits we expect to see but haven't.""" - // return self._pending_exits_count - - // 
@pending_exits_count.setter - // def pending_exits_count(self, cnt): - // """Set the pending count to cnt.""" - // self._pending_exits_count = cnt + Dictionary _switch_map = new Dictionary(); + public Dictionary switch_map => _switch_map; + + /// + /// The while loop context for forward. + /// + WhileContext _forward_context; + public WhileContext forward_context => _forward_context; + + /// + /// The grad loop state for the outer while loop. + /// + GradLoopState _outer_grad_state; + public GradLoopState outer_grad_state => _outer_grad_state; + + Tensor _forward_index; + Tensor _grad_index; + + Tensor[] _forward_loop_exits; + /// + /// The list of exits of the forward loop. + /// + public Tensor[] forward_loop_exits => _forward_loop_exits; + + List _deferred_exits; + public List deferred_exits => _deferred_exits; + + List _unused_exits; + public List unused_exits => _unused_exits; + + /// + /// The number of exits we expect to see but haven't. + /// + public int pending_exits_count { get; set; } + + public GradLoopState(WhileContext forward_ctxt, GradLoopState outer_grad_state_) + { + // Information needed by backprop. + _unused_exits = new List(); + _deferred_exits = new List(); + _forward_loop_exits = list(forward_ctxt.loop_exits); + pending_exits_count = len(forward_ctxt.loop_exits); + + _outer_grad_state = outer_grad_state_; + + ControlFlowContext outer_forward_ctxt = null; + if (outer_grad_state_ != null) + outer_forward_ctxt = outer_grad_state_.forward_context; + + // Add the forward loop counter. + // with forward_ctxt._graph.as_default(): + Tensor cnt, forward_index; + { + if (outer_forward_ctxt != null) + outer_forward_ctxt.Enter(); + (cnt, forward_index) = forward_ctxt.AddForwardLoopCounter(outer_grad_state); + if (outer_forward_ctxt != null) + outer_forward_ctxt.Exit(); + } + _forward_context = forward_ctxt; + _forward_index = forward_index; + + // Add the backprop WhileContext, and the backprop loop counter. 
+ if (outer_grad_state != null) + { + // This is a nested loop. Remember the iteration counts for each + // execution of this inner loop. + throw new NotImplementedException("GradLoopState"); + } + else + { + if (outer_forward_ctxt != null) + outer_forward_ctxt.Enter(); + _grad_context = new WhileContext( + maximum_iterations: forward_ctxt.maximum_iterations, + parallel_iterations: forward_ctxt.parallel_iterations, + back_prop: forward_ctxt.back_prop, + swap_memory: forward_ctxt.swap_memory, + name: forward_ctxt.name, + grad_state: this); + _grad_index = _grad_context.AddBackpropLoopCounter(cnt, outer_grad_state); + if (outer_forward_ctxt != null) + outer_forward_ctxt.Exit(); + } + } /// /// Add an accumulator for each forward tensor that is needed in backprop. diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/MergeOutput.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/MergeOutput.cs new file mode 100644 index 00000000..5b6ae944 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/MergeOutput.cs @@ -0,0 +1,36 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Operations +{ + public class MergeOutput + { + Tensor output; + Tensor value_index; + public MergeOutput(Tensor[] values) + { + output = values[0]; + value_index = values[1]; + } + + public Tensor this[int idx] + { + get + { + switch(idx) + { + case 0: + return output; + case 1: + return value_index; + default: + return null; + } + } + } + + public static implicit operator Tensor(MergeOutput merge) + => merge.output; + } +} diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs index 715c68c6..56bcf897 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs @@ -32,12 +32,17 @@ namespace Tensorflow.Operations bool _back_prop=true; GradLoopState _grad_state =null; 
Tensor _maximum_iterations; + public Tensor maximum_iterations => _maximum_iterations; int _parallel_iterations; + public int parallel_iterations => _parallel_iterations; bool _swap_memory; + public bool swap_memory => _swap_memory; Tensor _pivot_for_pred; Tensor _pivot_for_body; List _loop_exits; + public List loop_exits => _loop_exits; List _loop_enters; + public List loop_enters => _loop_enters; Graph _graph; public override GradLoopState grad_state => _grad_state; public override bool back_prop => _back_prop; @@ -109,7 +114,7 @@ namespace Tensorflow.Operations /// /// Add the loop termination condition and body to the graph. /// - internal Tensor[] BuildLoop(Func, Tensor> pred, + internal LoopVar BuildLoop(Func, Tensor> pred, Func, LoopVar> body, LoopVar loop_vars, TensorShape[] shape_invariants, @@ -132,14 +137,16 @@ namespace Tensorflow.Operations pred, body, original_loop_vars, loop_vars_tensors, shape_invariants); Exit(); - var flat_result = original_body_result; + var flat_result = nest.flatten2(original_body_result) + .Select(x => x as ITensorOrTensorArray) + .ToArray(); var exit_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_result, exit_vars); - var packed_exit_vars = nest.pack_sequence_as( + var packed_exit_vars = nest.pack_sequence_as2( structure: original_body_result, flat_sequence: exit_vars_with_tensor_arrays); - return packed_exit_vars as Tensor[]; + return packed_exit_vars; } private Tensor _convert_tensorarray_to_flow(object tensor_or_tensor_array) @@ -167,7 +174,7 @@ namespace Tensorflow.Operations /// /// /// - private (Tensor[], Tensor[]) _BuildLoop(Func, Tensor> pred, + private (LoopVar, Tensor[]) _BuildLoop(Func, Tensor> pred, Func, LoopVar> body, LoopVar original_loop_vars, Tensor[] loop_vars, @@ -221,6 +228,7 @@ namespace Tensorflow.Operations var merge_vars = enter_vars .Select(x => merge(new[] { x, x })) + .Select(m => (Tensor)m) .ToArray(); _pivot_for_pred = merge_vars[0]; @@ -250,13 +258,15 @@ namespace 
Tensorflow.Operations var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); // Store body_result to keep track of TensorArrays returned by body - var original_body_result = new[] { body_result }; + var original_body_result = body_result; // Convert TensorArrays returned by body into their flow variables - var result = new[] { body_result }; - + var result = nest.flatten2(body_result) + .Select(x => _convert_tensorarray_to_flow(x)) + .ToArray(); + // result = ops.convert_n_to_tensor_or_composite(result); var next_vars = new List(); - //foreach (var (m, v) in zip(merge_vars, result)) - //next_vars.Add(_AddNextAndBackEdge(m, v)); + foreach (var (m, v) in zip(merge_vars, result)) + next_vars.Add(_AddNextAndBackEdge(m, v)); // Add the exit ops. var exit_vars = switch_vars.Select(x => exit(x[0])).ToList(); @@ -264,7 +274,7 @@ namespace Tensorflow.Operations // Exit the loop. // ExitResult(exit_vars); - return (null, exit_vars.ToArray()); + return (original_body_result, exit_vars.ToArray()); } private void _FixControlInputsAndContext(Tensor[] enters) @@ -282,7 +292,18 @@ namespace Tensorflow.Operations var keep_as_control_input = true; var op_ctxt = control_flow_util.GetOutputContext(op); var outer_ctxt = outer_context; - throw new NotImplementedException(""); + var outer_while_context = outer_ctxt == null ? 
null : outer_ctxt.GetWhileContext(); + while (outer_ctxt != op_ctxt) + { + if (outer_ctxt == null || outer_ctxt == outer_while_context) + { + keep_as_control_input = false; + break; + } + outer_ctxt = outer_ctxt.outer_context; + } + if (keep_as_control_input) + outer_control_inputs.append(op); } // op for op in control_inputs if self._IsInOuterContext(op) /*var outer_control_inputs = control_inputs.Where(x => _IsInOuterContext(x.op)) @@ -307,10 +328,21 @@ namespace Tensorflow.Operations protected override void _AddOpInternal(Operation op) { + if (op.name == "gradients/rnn/while/basic_rnn_cell/Tanh_grad/TanhGrad") + { + + } + Operation[] external_inputs = new Operation[0]; + Operation[] control_inputs = new Operation[0]; if (op.inputs.Length == 0) { - throw new NotImplementedException(""); + // Remove any external control dependency on this op + (control_inputs, external_inputs) = _RemoveExternalControlEdges(op); + if (control_inputs.Length == 0) + op._add_control_input(GetControlPivot().op); + foreach (var x in op.outputs) + _values.Add(x.name); } else { @@ -378,6 +410,93 @@ namespace Tensorflow.Operations _AddOpInternal(op); } + /// + /// Adds a loop that counts the number of iterations. + /// + /// The outer grad state. None if not nested. + /// The number of iterations taken by the forward loop and the loop index. 
+ public (Tensor, Tensor) AddForwardLoopCounter(GradLoopState outer_grad_state) + { + var n = constant_op.constant(0, name: "f_count"); + if (outer_grad_state != null) + throw new NotImplementedException("AddForwardLoopCounter"); + + Enter(); + AddName(n.name); + var enter_n = _Enter(n, + _name, + is_constant: false, + parallel_iterations: _parallel_iterations, + name: "f_count"); + _loop_enters.Add(enter_n); + + var m1 = merge(new[] { enter_n, enter_n }); + var merge_n = m1[0]; + var switch_n = @switch (merge_n, _pivot); + + var index = math_ops.add(switch_n[1], 1); + var next_n = _NextIteration(index); + merge_n.op._update_input(1, next_n); + + var total_iterations = exit(switch_n[0], name: "f_count"); + loop_exits.append(total_iterations); + ExitResult(new[] { total_iterations }); + Exit(); + + return (total_iterations, next_n); + } + + /// + /// Add the backprop loop that controls the iterations. + /// + /// The number of iterations for backprop. + /// The outer grad state. None if not nested. + /// The loop index. 
+ public Tensor AddBackpropLoopCounter(Tensor count, GradLoopState outer_grad_state) + { + Tensor one = null; + var in_separate_functions = count.graph != ops.get_default_graph(); + if (in_separate_functions) + // Brings the count into this graph + count = array_ops.identity(count); + else + one = constant_op.constant(1, name: "b_count"); + + Enter(); + AddName(count.name); + var enter_count = _Enter( + count, + _name, + is_constant: false, + parallel_iterations: _parallel_iterations, + name: "b_count"); + loop_enters.append(enter_count); + + var merge_count = merge(new[] { enter_count, enter_count })[0]; + _pivot_for_pred = merge_count; + if (in_separate_functions) + one = constant_op.constant(1, name: "b_count"); + var pred = math_ops.greater_equal(merge_count, one); + _pivot = gen_control_flow_ops.loop_cond(pred, name: "b_count"); + var switch_count = @switch(merge_count, _pivot); + + var index = math_ops.subtract(switch_count[1], one); + _pivot_for_body = index; + var next_count = _NextIteration(index); + merge_count.op._update_input(1, next_count); + + var final_zero = exit(switch_count[0], name: "b_count"); + loop_exits.append(final_zero); + // Force the stack pops of i-th execution of an inner loop to be ordered + // before the pops of (i+1)-th execution of the same inner loop. + if (outer_grad_state != null) + throw new NotImplementedException("outer_grad_state"); + //outer_grad_state.grad_sync._add_control_input(final_zero.op); + ExitResult(new[] { final_zero }); + Exit(); + return next_count; + } + /// /// Add `val` to the current context and its outer context recursively. 
/// @@ -401,17 +520,27 @@ namespace Tensorflow.Operations grad_ctxt = grad_ctxt.GetWhileContext(); if (grad_ctxt.grad_state != null) { - throw new NotImplementedException(""); + var forward_ctxt = val.op.GetWhileContext(); + if (control_flow_util.IsLoopExit(val.op)) + { + forward_ctxt = forward_ctxt.outer_context as WhileContext; + if (forward_ctxt != null) + forward_ctxt = forward_ctxt.GetWhileContext(); + throw new NotImplementedException("control_flow_util.IsLoopExit"); + } + if(forward_ctxt == grad_ctxt.grad_state.forward_context) + { + throw new NotImplementedException("forward_ctxt == grad_ctxt.grad_state.forward_context"); + /*real_val = grad_ctxt.grad_state.GetRealValue(val); + _external_values[val.name] = real_val; + return real_val;*/ + } } } if (_outer_context != null) result = _outer_context.AddValue(val); - if (tf.get_default_graph()._nodes_by_name.Count >= 83) - { - - } // Create an Enter to make `result` known to this loop context. Tensor enter = null; tf_with(ops.control_dependencies(null), delegate @@ -443,6 +572,9 @@ namespace Tensorflow.Operations return result; } + public override bool IsWhileContext() + => true; + public override WhileContext GetWhileContext() { return this; From dd1b589d219fa1354a86038637ec0f3119815b5e Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 2 Nov 2019 08:02:07 -0500 Subject: [PATCH 25/41] Tensor.Flatten --- src/TensorFlowNET.Core/Tensors/Tensor.Flatten.cs | 15 +++++++++++++++ src/TensorFlowNET.Core/Tensors/Tensor.cs | 7 ++++++- src/TensorFlowNET.Core/Tensors/TensorArray.cs | 6 ++++++ src/TensorFlowNET.Core/Tensors/dtypes.cs | 3 +++ 4 files changed, 30 insertions(+), 1 deletion(-) create mode 100644 src/TensorFlowNET.Core/Tensors/Tensor.Flatten.cs diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Flatten.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Flatten.cs new file mode 100644 index 00000000..5e729a14 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Flatten.cs @@ -0,0 +1,15 @@ +using NumSharp; +using 
System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public partial class Tensor + { + public object[] Flatten() + { + return new Tensor[] { this }; + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 99a373c4..9f505419 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -39,7 +39,12 @@ namespace Tensorflow /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. /// [SuppressMessage("ReSharper", "ConvertToAutoProperty")] - public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike, ITensorOrTensorArray, IPackable + public partial class Tensor : DisposableObject, + ITensorOrOperation, + _TensorLike, + ITensorOrTensorArray, + IPackable, + ICanBeFlattened { private readonly int _id; private readonly Operation _op; diff --git a/src/TensorFlowNET.Core/Tensors/TensorArray.cs b/src/TensorFlowNET.Core/Tensors/TensorArray.cs index fe9e2d6d..369b9dc0 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorArray.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorArray.cs @@ -61,5 +61,11 @@ namespace Tensorflow public Tensor read(Tensor index, string name = null) => _implementation.read(index, name: name); + + public TensorArray write(Tensor index, Tensor value, string name = null) + => _implementation.write(index, value, name: name); + + public Tensor stack(string name = null) + => _implementation.stack(name: name); } } diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index 2a2a9bfa..59de20ac 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -33,6 +33,9 @@ namespace Tensorflow public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32? 
public static TF_DataType float16 = TF_DataType.TF_HALF; public static TF_DataType float64 = TF_DataType.TF_DOUBLE; + public static TF_DataType complex = TF_DataType.TF_COMPLEX; + public static TF_DataType complex64 = TF_DataType.TF_COMPLEX64; + public static TF_DataType complex128 = TF_DataType.TF_COMPLEX128; public static TF_DataType variant = TF_DataType.TF_VARIANT; public static TF_DataType resource = TF_DataType.TF_RESOURCE; From 38ad490c3ed6ae116e4c088be9af28829ce4bb11 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 2 Nov 2019 08:02:58 -0500 Subject: [PATCH 26/41] return array instead of tuple for layer.call --- .../Keras/Layers/BatchNormalization.cs | 4 ++-- src/TensorFlowNET.Core/Keras/Layers/Conv.cs | 4 ++-- src/TensorFlowNET.Core/Keras/Layers/Dense.cs | 4 ++-- .../Keras/Layers/Embedding.cs | 4 ++-- src/TensorFlowNET.Core/Keras/Layers/Layer.cs | 19 ++++++++++--------- .../Keras/Layers/Pooling2D.cs | 4 ++-- 6 files changed, 20 insertions(+), 19 deletions(-) diff --git a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs index 9b42eaaa..74432b2b 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs @@ -139,14 +139,14 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) + protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) { Tensor outputs = null; if (fused) { outputs = _fused_batch_norm(inputs, training: training); - return (outputs, outputs); + return new[] { outputs, outputs }; } throw new NotImplementedException("BatchNormalization call"); diff --git a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs index ad233d6b..7f763fb8 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs +++ 
b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs @@ -108,7 +108,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) + protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) { var outputs = _convolution_op.__call__(inputs, kernel); if (use_bias) @@ -126,7 +126,7 @@ namespace Tensorflow.Keras.Layers if (activation != null) outputs = activation.Activate(outputs); - return (outputs, outputs); + return new[] { outputs, outputs }; } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs index 74778873..bfd6f2a5 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs @@ -72,7 +72,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) + protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) { Tensor outputs = null; var rank = inputs.rank; @@ -90,7 +90,7 @@ namespace Tensorflow.Keras.Layers if (activation != null) outputs = activation.Activate(outputs); - return (outputs, outputs); + return new[] { outputs, outputs }; } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs index 95544d36..89ad4a63 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs @@ -50,14 +50,14 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) + protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) { var dtype = inputs.dtype; if (dtype != tf.int32 && dtype != tf.int64) inputs = math_ops.cast(inputs, tf.int32); var @out = 
embedding_ops.embedding_lookup(embeddings, inputs); - return (@out, @out); + return new[] { @out, @out }; } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs index d7d7e31a..3ab37a0b 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs @@ -103,14 +103,14 @@ namespace Tensorflow.Keras.Layers _inbound_nodes = new List(); } - public (Tensor, Tensor) __call__(Tensor[] inputs, + public Tensor[] __call__(Tensor[] inputs, Tensor training = null, Tensor state = null, VariableScope scope = null) { var input_list = inputs; var input = inputs[0]; - Tensor outputs = null; + Tensor[] outputs = null; // We will attempt to build a TF graph if & only if all inputs are symbolic. // This is always the case in graph mode. It can also be the case in eager @@ -142,25 +142,26 @@ namespace Tensorflow.Keras.Layers // overridden). _maybe_build(inputs[0]); - (input, outputs) = call(inputs[0], + outputs = call(inputs[0], training: training, state: state); + (input, outputs) = _set_connectivity_metadata_(input, outputs); _handle_activity_regularization(inputs[0], outputs); _set_mask_metadata(inputs[0], outputs, null); }); } - return (input, outputs); + return outputs; } - private (Tensor, Tensor) _set_connectivity_metadata_(Tensor inputs, Tensor outputs) + private (Tensor, Tensor[]) _set_connectivity_metadata_(Tensor inputs, Tensor[] outputs) { //_add_inbound_node(input_tensors: inputs, output_tensors: outputs); return (inputs, outputs); } - private void _handle_activity_regularization(Tensor inputs, Tensor outputs) + private void _handle_activity_regularization(Tensor inputs, Tensor[] outputs) { //if(_activity_regularizer != null) { @@ -168,7 +169,7 @@ namespace Tensorflow.Keras.Layers } } - private void _set_mask_metadata(Tensor inputs, Tensor outputs, Tensor previous_mask) + private void _set_mask_metadata(Tensor inputs, Tensor[] outputs, Tensor previous_mask) { } @@ -178,9 
+179,9 @@ namespace Tensorflow.Keras.Layers return null; } - protected virtual (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) + protected virtual Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) { - return (inputs, inputs); + throw new NotImplementedException(""); } protected virtual string _name_scope() diff --git a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs index 81d57abe..ccb1cd6f 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs @@ -43,7 +43,7 @@ namespace Tensorflow.Keras.Layers this.input_spec = new InputSpec(ndim: 4); } - protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) + protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) { int[] pool_shape; if (data_format == "channels_last") @@ -64,7 +64,7 @@ namespace Tensorflow.Keras.Layers padding: padding.ToUpper(), data_format: conv_utils.convert_data_format(data_format, 4)); - return (outputs, outputs); + return new[] { outputs, outputs }; } } } From d2de8bed63cf8ec9dc96531c1ae9bc17b94e1214 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 2 Nov 2019 08:58:39 -0500 Subject: [PATCH 27/41] tensor_array_write_v3 --- .../Operations/gen_data_flow_ops.cs | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs index 1d3cc047..52b0a372 100644 --- a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs @@ -220,5 +220,44 @@ namespace Tensorflow return _op.output; } + + public static Tensor tensor_array_write_v3(Tensor handle, Tensor index, Tensor value, Tensor flow_in, string name = null) + { + var _op = _op_def_lib._apply_op_helper("TensorArrayWriteV3", name, new 
+ { + handle, + index, + value, + flow_in + }); + + return _op.output; + } + + public static Tensor tensor_array_size_v3(Tensor handle, Tensor flow_in, string name = null) + { + var _op = _op_def_lib._apply_op_helper("TensorArraySizeV3", name, new + { + handle, + flow_in + }); + + return _op.output; + } + + public static Tensor tensor_array_gather_v3(Tensor handle, Tensor indices, Tensor flow_in, + TF_DataType dtype, TensorShape element_shape = null, string name = null) + { + var _op = _op_def_lib._apply_op_helper("TensorArrayGatherV3", name, new + { + handle, + indices, + dtype, + element_shape, + flow_in + }); + + return _op.output; + } } } From c8a61b21d5e378b5b9ca2a565d71e93bd7f5723c Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 2 Nov 2019 08:59:21 -0500 Subject: [PATCH 28/41] _UpdatePendingAndEnqueueReady --- .../Gradients/gradients_util.cs | 94 +++++++++++++++++-- 1 file changed, 87 insertions(+), 7 deletions(-) diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs index 15ad511b..163192ee 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs @@ -17,6 +17,7 @@ using System; using System.Collections.Generic; using System.Linq; +using Tensorflow.Operations.ControlFlows; using static Tensorflow.Binding; namespace Tensorflow @@ -82,6 +83,7 @@ namespace Tensorflow var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList(); var (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List(), xs); + // Add the initial gradients for the ys. 
foreach (var (y, grad_y) in zip(ys, grad_ys)) _SetGrad(grads, y, grad_y); @@ -103,12 +105,25 @@ namespace Tensorflow } } + if(loop_state != null) + { + var loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set); + foreach(var y in loop_exits) + { + //if(IsTrainable(y)) + throw new NotImplementedException(""); + } + } + var stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs); while (queue.Count > 0) { // generate gradient subgraph for op. var op = queue.Dequeue(); + if(op.name == "rnn/while/basic_rnn_cell/Tanh") + { + } _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops); //if (loop_state != null) //loop_state.EnterGradWhileContext(op, before: true); @@ -147,8 +162,8 @@ namespace Tensorflow } } - // if (loop_state) - //loop_state.EnterGradWhileContext(op, before: false); + if (loop_state != null) + loop_state.EnterGradWhileContext(op, before: false); if ((is_func_call || grad_fn != null) && has_out_grads) { @@ -164,7 +179,7 @@ namespace Tensorflow // will use SymbolicGradient get a zero gradient. Gradient // functions should ignore the gradient for other outputs. if (loop_state != null) - ; + out_grads[i] = new List { loop_state.ZerosLike(op, i) }; else out_grads[i] = new List { control_flow_ops.ZerosLikeOutsideLoop(op, i) }; } @@ -275,7 +290,7 @@ namespace Tensorflow /// /// /// - private static (Operation[], Dictionary, object) _PendingCount(List to_ops, List from_ops, bool colocate_gradients_with_ops, List func_graphs, Tensor[] xs) + private static (Operation[], Dictionary, ControlFlowState) _PendingCount(List to_ops, List from_ops, bool colocate_gradients_with_ops, List func_graphs, Tensor[] xs) { // Mark reachable ops from from_ops. var reached_ops = new List(); @@ -308,6 +323,7 @@ namespace Tensorflow // 'loop_state' is None if there are no while loops. 
var loop_state = control_flow_ops.MaybeCreateControlFlowState(between_op_list, between_ops, colocate_gradients_with_ops); + // Initialize pending count for between ops. var pending_count = new Dictionary(); foreach (var op in between_op_list) { @@ -550,7 +566,7 @@ namespace Tensorflow Operation op, Queue queue, Dictionary pending_count, - object loop_state, + ControlFlowState loop_state, Tensor[] xs) { foreach (var x in _NonEagerInputs(op, xs)) @@ -564,14 +580,49 @@ namespace Tensorflow if (loop_state != null && !ready) { - + ready = pending_count[x.op.name] > 0 && control_flow_util.IsLoopSwitch(x.op); } if (ready) { + // if x is an exit without real gradient, defer processing them. if (control_flow_util.IsLoopExit(x.op)) { - + var grad_state = loop_state.GetGradState(x.op, before: false); + grad_state.deferred_exits.append(x); + grad_state.pending_exits_count -= 1; + // We now have all the exits so process them. + if (grad_state.pending_exits_count == 0) + { + var has_not_none_grad = false; + foreach(var y in grad_state.deferred_exits) + { + if (_HasAnyNotNoneGrads(grads, y.op)) + { + has_not_none_grad = true; + queue.Enqueue(y.op); + } + else + grad_state.unused_exits.append(y); + } + if (has_not_none_grad) + { + // For an unused exit, if it has trainable outputs, backprop + // a zero gradient. Otherwise, just ignore it. + foreach (var y in grad_state.unused_exits) + { + if (IsTrainable(y)) + _SetGrad(grads, y, loop_state.ZerosLikeForExit(y)); + queue.Enqueue(y.op); + } + } + else + { + // All exits are "unused" so use None as gradient. 
+ foreach (var y in grad_state.unused_exits) + queue.Enqueue(y.op); + } + } } else { @@ -581,6 +632,32 @@ namespace Tensorflow } } + private static bool IsTrainable(Tensor tensor) + { + var dtype = tensor.dtype.as_base_dtype(); + return new TF_DataType[] { dtypes.float16, dtypes.float32, dtypes.float64, + dtypes.complex64, dtypes.complex128, + dtypes.resource, dtypes.variant}.Contains(dtype); + } + + /// + /// Return true if op has real gradient. + /// + /// + /// + /// + private static bool _HasAnyNotNoneGrads(Dictionary>> grads, Operation op) + { + var out_grads = _GetGrads(grads, op); + foreach(var out_grad in out_grads) + { + if (out_grad.Exists(g => g != null)) + return true; + } + return false; + } + + private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func grad_fn) { scope = scope.EndsWith("/") ? scope.Substring(0, scope.Length - 1) : scope; @@ -589,6 +666,9 @@ namespace Tensorflow private static void _VerifyGeneratedGradients(Tensor[] grads, Operation op) { + if (op.type == "While" || op.type == "StatelessWhile") + return; + if (grads.Count() != op.inputs._inputs.Count()) throw new ValueError($"Num gradients {grads.Length} generated for op {op.node_def} do not match num " + $"inputs {op.inputs._inputs.Count()}"); From a8a515682ef60b4b7dac447c1520af420b3c70e8 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 2 Nov 2019 08:59:36 -0500 Subject: [PATCH 29/41] _SwitchGrad --- .../Gradients/control_flow_grad.cs | 141 +++++++++--------- 1 file changed, 74 insertions(+), 67 deletions(-) diff --git a/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs b/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs index 76b6a7b5..acaa6de3 100644 --- a/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs @@ -45,7 +45,19 @@ namespace Tensorflow.Gradients switch (op_ctxt) { case WhileContext cwhile: - throw new NotImplementedException("_SwitchGrad WhileContext"); 
+ { + var merge_grad = grad_ctxt.grad_state.switch_map.get(op); + if (merge_grad != null) + throw new NotImplementedException("_SwitchGrad merge_grad != null"); + else if (grads[0] != null) + { + merge_grad = merge(new[] { grads[0], grads[0] }, name: "b_switch")[0]; + grad_ctxt.grad_state.switch_map[op] = merge_grad; + return new Tensor[] { merge_grad, null }; + } + else + return new Tensor[] { null, null }; + } case CondContext ccond: { var zero_grad = grads[1 - op_ctxt.branch]; @@ -74,7 +86,7 @@ namespace Tensorflow.Gradients /// /// /// - internal static Tensor[] merge(Tensor[] inputs, string name = null) + internal static MergeOutput merge(Tensor[] inputs, string name = null) { return tf_with(ops.name_scope(name, "Merge", inputs), scope => { @@ -146,7 +158,7 @@ namespace Tensorflow.Gradients } [RegisterGradient("RefMerge")] - public Tensor[] _RefMergeGrad(Operation op, Tensor[] grads) + public static Tensor[] _RefMergeGrad(Operation op, Tensor[] grads) { return _MergeGrad(op, grads); } @@ -155,43 +167,32 @@ namespace Tensorflow.Gradients /// Gradients for an exit op are calculated using an Enter op. /// [RegisterGradient("Exit")] - public Tensor[] _ExitGrad(Operation op, Tensor[] grads) + public static Tensor[] _ExitGrad(Operation op, Tensor[] grads) { - throw new NotImplementedException("_ExitGrad"); - // graph = ops.get_default_graph() - //# pylint: disable=protected-access - // op_ctxt = op._get_control_flow_context() - // grad_ctxt = graph._get_control_flow_context() - // # pylint: enable=protected-access - // if not grad_ctxt.back_prop: - // # The flag `back_prop` is set by users to suppress gradient - // # computation for this loop. If the attribute `back_prop` is false, - // # no gradient computation. 
- // return None + var grad = grads[0]; + var graph = ops.get_default_graph(); + var op_ctxt = op._get_control_flow_context(); + var grad_ctxt = graph._get_control_flow_context() as WhileContext; + // The flag `back_prop` is set by users to suppress gradient + // computation for this loop. If the attribute `back_prop` is false, + // no gradient computation. + if (!grad_ctxt.back_prop) + return null; + + if (op_ctxt.grad_state != null) + throw new TypeError("Second-order gradient for while loops not supported."); + + grad_ctxt.AddName(grad.name); - // if op_ctxt.grad_state: - // raise TypeError("Second-order gradient for while loops not supported.") + grad_ctxt.Enter(); + var result = control_flow_ops._Enter( + grad, grad_ctxt.name, is_constant: false, + parallel_iterations: grad_ctxt.parallel_iterations, + name: "b_exit"); - // if isinstance(grad, ops.Tensor) : - // grad_ctxt.AddName(grad.name) - // else: - // if not isinstance(grad, (ops.IndexedSlices, sparse_tensor.SparseTensor)): - // raise TypeError("Type %s not supported" % type(grad)) - // grad_ctxt.AddName(grad.values.name) - // grad_ctxt.AddName(grad.indices.name) - // dense_shape = grad.dense_shape - // if dense_shape is not None: - // grad_ctxt.AddName(dense_shape.name) - // grad_ctxt.Enter() - // # pylint: disable=protected-access - // result = control_flow_ops._Enter( - // grad, grad_ctxt.name, is_constant=False, - // parallel_iterations=grad_ctxt.parallel_iterations, - // name="b_exit") - // # pylint: enable=protected-access - // grad_ctxt.loop_enters.append(result) - // grad_ctxt.Exit() - // return result + grad_ctxt.loop_enters.append(result); + grad_ctxt.Exit(); + return new[] { result }; } /// @@ -200,15 +201,15 @@ namespace Tensorflow.Gradients /// Note that the backprop next_iteration is added in switch grad. 
/// [RegisterGradient("NextIteration")] - public Tensor[] _NextIterationGrad(object _, Tensor[] grad) + public static Tensor[] _NextIterationGrad(Operation op, Tensor[] grads) { - return grad; + return grads; } [RegisterGradient("RefNextIteration")] - public Tensor[] _RefNextIterationGrad(object _, Tensor[] grad) + public static Tensor[] _RefNextIterationGrad(Operation op, Tensor[] grads) { - return grad; + return grads; } /// @@ -218,33 +219,39 @@ namespace Tensorflow.Gradients /// For loop invariants, we need to add an accumulator loop. /// [RegisterGradient("Enter")] - public Tensor[] _EnterGrad(Tensor op, Tensor[] grad) + public static Tensor[] _EnterGrad(Operation op, Tensor[] grads) { - throw new NotImplementedException("_EnterGrad"); - // graph = ops.get_default_graph() - //# pylint: disable=protected-access - // grad_ctxt = graph._get_control_flow_context() - // # pylint: enable=protected-access - // if not grad_ctxt.back_prop: - // # Skip gradient computation, if the attribute `back_prop` is false. - // return grad - // if grad_ctxt.grad_state is None: - // # Pass the gradient through if we are not in a gradient while context. - // return grad - // if op.get_attr("is_constant"): - // # Add a gradient accumulator for each loop invariant. - // if isinstance(grad, ops.Tensor) : - // result = grad_ctxt.AddBackpropAccumulator(op, grad) - // elif isinstance(grad, ops.IndexedSlices) : - // result = grad_ctxt.AddBackpropIndexedSlicesAccumulator(op, grad) - // else: - // # TODO(yuanbyu, lukasr): Add support for SparseTensor. - // raise TypeError("Type %s not supported" % type(grad)) - // else: - // result = exit(grad) - // grad_ctxt.loop_exits.append(result) - // grad_ctxt.ExitResult([result]) - // return result + Tensor result = null; + var grad = grads[0]; + var graph = ops.get_default_graph(); + var grad_ctxt = graph._get_control_flow_context() as WhileContext; + if (!grad_ctxt.back_prop) + // Skip gradient computation, if the attribute `back_prop` is false. 
+ return grads; + if (grad_ctxt.grad_state == null) + // Pass the gradient through if we are not in a gradient while context. + return grads; + if (op.get_attr("is_constant")) + { + throw new NotImplementedException("_EnterGrad is_constant"); + // Add a gradient accumulator for each loop invariant. + // if isinstance(grad, ops.Tensor) : + // result = grad_ctxt.AddBackpropAccumulator(op, grad) + // elif isinstance(grad, ops.IndexedSlices) : + // result = grad_ctxt.AddBackpropIndexedSlicesAccumulator(op, grad) + // else: + // # TODO(yuanbyu, lukasr): Add support for SparseTensor. + // raise TypeError("Type %s not supported" % type(grad)) + } + + else + { + result = control_flow_ops.exit(grad); + grad_ctxt.loop_exits.append(result); + grad_ctxt.ExitResult(new[] { result }); + } + + return new Tensor[] { result }; } From 1ef2ec1ca19dddd0cfb2454de155abdfe291b4a4 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 2 Nov 2019 11:17:32 -0500 Subject: [PATCH 30/41] GraphTensorArray write, size, stack. --- .../Operations/BasicRNNCell.cs | 6 ++- .../Operations/_GraphTensorArray.cs | 54 ++++++++++++++++++- .../Operations/tensor_array_ops.cs | 19 +++++++ 3 files changed, 76 insertions(+), 3 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs b/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs index fdcc03ea..69f86349 100644 --- a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs @@ -66,12 +66,14 @@ namespace Tensorflow built = true; } - protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) + protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) { // Most basic RNN: output = new_state = act(W * input + U * state + B). 
var concat = array_ops.concat(new[] { inputs, state }, 1); var gate_inputs = math_ops.matmul(concat, _kernel as RefVariable); - return (inputs, inputs); + gate_inputs = nn_ops.bias_add(gate_inputs, _bias as RefVariable); + var output = _activation(gate_inputs, null); + return new[] { output, output }; } } } diff --git a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs index 56ac277e..ea701afc 100644 --- a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs +++ b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs @@ -22,7 +22,7 @@ using static Tensorflow.Binding; namespace Tensorflow.Operations { - internal class _GraphTensorArray + public class _GraphTensorArray { internal TF_DataType _dtype; public TF_DataType dtype => _dtype; @@ -174,5 +174,57 @@ namespace Tensorflow.Operations return value; } + + public TensorArray write(Tensor index, Tensor value, string name = null) + { + return tf_with(ops.name_scope(name, "TensorArrayWrite", new { _handle, index, value }), delegate + { + value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); + _maybe_colocate_with(value); + var flow_out = gen_data_flow_ops.tensor_array_write_v3( + handle: _handle, + index: index, + value: value, + flow_in: _flow, + name: name); + + return tensor_array_ops.build_ta_with_new_flow(this, flow_out); + }); + } + + private Tensor size(string name = null) + { + return gen_data_flow_ops.tensor_array_size_v3(_handle, _flow, name: name); + } + + public Tensor stack(string name = null) + { + ops.colocate_with(_handle); + return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate + { + return gather(math_ops.range(0, size()), name: name); + }); + } + + public Tensor gather(Tensor indices, string name = null) + { + var element_shape = new TensorShape(); + + if (_element_shape.Count > 0) + element_shape = _element_shape[0]; + + var value = gen_data_flow_ops.tensor_array_gather_v3( + handle: 
_handle, + indices: indices, + flow_in: _flow, + dtype: _dtype, + name: name, + element_shape: element_shape); + + //if (element_shape != null) + //value.set_shape(-1, element_shape.dims); + + return value; + } } } diff --git a/src/TensorFlowNET.Core/Operations/tensor_array_ops.cs b/src/TensorFlowNET.Core/Operations/tensor_array_ops.cs index 8ce3b5c7..59496943 100644 --- a/src/TensorFlowNET.Core/Operations/tensor_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/tensor_array_ops.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Operations; namespace Tensorflow { @@ -29,5 +30,23 @@ namespace Tensorflow new_impl._element_shape = impl._element_shape; return new_ta; } + + public static TensorArray build_ta_with_new_flow(_GraphTensorArray old_ta, Tensor flow) + { + var impl = old_ta; + + var new_ta = new TensorArray( + dtype: impl.dtype, + handle: impl.handle, + flow: flow, + infer_shape: impl.infer_shape, + colocate_with_first_write_call: impl.colocate_with_first_write_call); + + var new_impl = new_ta._implementation; + new_impl._dynamic_size = impl._dynamic_size; + new_impl._colocate_with = impl._colocate_with; + new_impl._element_shape = impl._element_shape; + return new_ta; + } } } From 67949251b2935f6cb9a24d197154efa6733aa251 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 2 Nov 2019 11:17:53 -0500 Subject: [PATCH 31/41] override graph --- src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs | 4 ++-- src/TensorFlowNET.Core/Sessions/_FetchHandler.cs | 2 +- src/TensorFlowNET.Core/Sessions/_FetchMapper.cs | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs index bc1ea0b7..48eddf3b 100644 --- a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs +++ b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs @@ -28,9 +28,9 @@ namespace Tensorflow { private Func, object> 
_contraction_fn; - public _ElementFetchMapper(object[] fetches, Func, object> contraction_fn) + public _ElementFetchMapper(object[] fetches, Func, object> contraction_fn, Graph graph = null) { - var g = ops.get_default_graph(); + var g = graph ?? ops.get_default_graph(); foreach(var fetch in fetches) { diff --git a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs index e1a77d90..b7434089 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs @@ -34,7 +34,7 @@ namespace Tensorflow public _FetchHandler(Graph graph, object fetches, Dictionary feeds = null, Action feed_handles = null) { - _fetch_mapper = _FetchMapper.for_fetch(fetches); + _fetch_mapper = _FetchMapper.for_fetch(fetches, graph: graph); foreach(var fetch in _fetch_mapper.unique_fetches()) { switch (fetch) diff --git a/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs index 534cdcd7..e28b76a1 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs @@ -25,7 +25,7 @@ namespace Tensorflow { protected List _unique_fetches = new List(); protected List _value_indices = new List(); - public static _FetchMapper for_fetch(object fetch) + public static _FetchMapper for_fetch(object fetch, Graph graph = null) { var fetches = fetch.GetType().IsArray ? 
(object[])fetch : new object[] { fetch }; @@ -34,7 +34,7 @@ namespace Tensorflow if (fetch.GetType().IsArray) return new _ListFetchMapper(fetches); else - return new _ElementFetchMapper(fetches, (List fetched_vals) => fetched_vals[0]); + return new _ElementFetchMapper(fetches, (List fetched_vals) => fetched_vals[0], graph: graph); } public virtual NDArray[] build_results(List values) From 59b7eb0365cbf6afca252de3d2aef73e6361ce3b Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 2 Nov 2019 11:18:37 -0500 Subject: [PATCH 32/41] MaybeCreateControlFlowState --- .../Gradients/gradients_util.cs | 178 ++++++++++-------- src/TensorFlowNET.Core/Layers/Layer.cs | 7 +- .../Operations/control_flow_ops.cs | 58 ++++-- .../Operations/control_flow_util.py.cs | 95 +++++++++- src/TensorFlowNET.Core/Operations/math_ops.cs | 2 + 5 files changed, 232 insertions(+), 108 deletions(-) diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs index 163192ee..8170ea6f 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs @@ -55,6 +55,9 @@ namespace Tensorflow * is more than one. **/ var grads = new Dictionary>>(); + Operation[] reachable_to_ops = null; + ControlFlowState loop_state = null; + Dictionary pending_count = null; tf_with(ops.name_scope(name, "gradients", values: ys.Concat(xs).Concat(stop_gradients).Concat(grad_ys)), scope => @@ -81,7 +84,7 @@ namespace Tensorflow var to_ops = ys.Select(x => x.op).ToList(); var from_ops = xs.Select(x => x.op).ToList(); var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList(); - var (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List(), xs); + (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List(), xs); // Add the initial gradients for the ys. 
foreach (var (y, grad_y) in zip(ys, grad_ys)) @@ -120,126 +123,135 @@ namespace Tensorflow { // generate gradient subgraph for op. var op = queue.Dequeue(); - if(op.name == "rnn/while/basic_rnn_cell/Tanh") + if(op.name == "rnn/while/Exit") { } _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops); - //if (loop_state != null) - //loop_state.EnterGradWhileContext(op, before: true); - var out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, aggregation_method); - - Tensor[] in_grads = null; - var is_partitioned_call = _IsPartitionedCall(op); - var is_func_call = false; - var has_out_grads = out_grads.Exists(x => x != null); - if (has_out_grads && !stop_ops.Contains(op)) { - // A grad_fn must be defined, either as a function or as None - // for ops that do not have gradients. + if (loop_state != null) + loop_state.EnterGradWhileContext(op, before: true); + var out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, aggregation_method); + if (loop_state != null) + loop_state.ExitGradWhileContext(op, before: true); - Func grad_fn = null; - try - { - grad_fn = ops.get_gradient_function(op); - } - catch (LookupError) + Tensor[] in_grads = null; + var is_partitioned_call = _IsPartitionedCall(op); + var is_func_call = false; + var has_out_grads = out_grads.Exists(x => x != null); + if (has_out_grads && !stop_ops.Contains(op)) { - if (is_func_call) + // A grad_fn must be defined, either as a function or as None + // for ops that do not have gradients. 
+ + Func grad_fn = null; + try { - if (is_partitioned_call) + grad_fn = ops.get_gradient_function(op); + } + catch (LookupError) + { + if (is_func_call) { + if (is_partitioned_call) + { + + } + else + { + } } else { - + throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})"); } } - else - { - throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})"); - } - } - if (loop_state != null) - loop_state.EnterGradWhileContext(op, before: false); + if (loop_state != null) + loop_state.EnterGradWhileContext(op, before: false); - if ((is_func_call || grad_fn != null) && has_out_grads) - { - // NOTE: If _AggregatedGrads didn't compute a value for the i'th - // output, it means that the cost does not depend on output[i], - // therefore dC/doutput[i] is 0. - foreach (var (i, out_grad) in enumerate(out_grads)) + if ((is_func_call || grad_fn != null) && has_out_grads) { - if (out_grad == null && - (grad_fn == null || _IsTrainable(op.outputs[i]))) + // NOTE: If _AggregatedGrads didn't compute a value for the i'th + // output, it means that the cost does not depend on output[i], + // therefore dC/doutput[i] is 0. + foreach (var (i, out_grad) in enumerate(out_grads)) { - // Only trainable outputs or outputs for a function call that - // will use SymbolicGradient get a zero gradient. Gradient - // functions should ignore the gradient for other outputs. - if (loop_state != null) - out_grads[i] = new List { loop_state.ZerosLike(op, i) }; - else - out_grads[i] = new List { control_flow_ops.ZerosLikeOutsideLoop(op, i) }; + if (out_grad == null && + (grad_fn == null || _IsTrainable(op.outputs[i]))) + { + // Only trainable outputs or outputs for a function call that + // will use SymbolicGradient get a zero gradient. Gradient + // functions should ignore the gradient for other outputs. 
+ if (loop_state != null) + out_grads[i] = new List { loop_state.ZerosLike(op, i) }; + else + out_grads[i] = new List { control_flow_ops.ZerosLikeOutsideLoop(op, i) }; + } } - } - tf_with(ops.name_scope(op.name + "_grad"), scope1 => - { - if (grad_fn != null) + tf_with(ops.name_scope(op.name + "_grad"), scope1 => { - in_grads = _MaybeCompile(grad_scope, - op, - out_grads.Where(x => x != null).Select(x => x[0]).ToArray(), - null, - grad_fn); - } - else - { - throw new NotImplementedException("lambda: _SymGrad(op, out_grads)"); - } - _VerifyGeneratedGradients(in_grads, op); - if (gate_gradients && in_grads.Count(x => x != null) > 1) - { - ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true); - in_grads = control_flow_ops.tuple(in_grads); - } - }); + if (grad_fn != null) + { + in_grads = _MaybeCompile(grad_scope, + op, + out_grads.Where(x => x != null).Select(x => x[0]).ToArray(), + null, + grad_fn); + } + else + { + throw new NotImplementedException("lambda: _SymGrad(op, out_grads)"); + } + _VerifyGeneratedGradients(in_grads, op); + if (gate_gradients && in_grads.Count(x => x != null) > 1) + { + ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true); + in_grads = control_flow_ops.tuple(in_grads); + } + }); + } + else + { + // If no grad_fn is defined or none of out_grads is available, + // just propagate a list of None backwards. + in_grads = new Tensor[_NonEagerInputs(op, xs).Count()]; + } } else { - // If no grad_fn is defined or none of out_grads is available, - // just propagate a list of None backwards. 
in_grads = new Tensor[_NonEagerInputs(op, xs).Count()]; } - } - else - { - in_grads = new Tensor[_NonEagerInputs(op, xs).Count()]; - } - var inputs = _NonEagerInputs(op, xs).ToList(); - foreach (var (t_in, in_grad) in zip(inputs, in_grads)) - { - if (in_grad != null) + var inputs = _NonEagerInputs(op, xs).ToList(); + foreach (var (t_in, in_grad) in zip(inputs, in_grads)) { - if (!(in_grad is null) && - in_grad.Tag == null && // maybe a IndexedSlice - t_in.dtype != TF_DataType.TF_RESOURCE) + if (in_grad != null) { - in_grad.set_shape(t_in.TensorShape); - } + if (!(in_grad is null) && + in_grad.Tag == null && // maybe a IndexedSlice + t_in.dtype != TF_DataType.TF_RESOURCE) + { + in_grad.set_shape(t_in.TensorShape); + } - _SetGrad(grads, t_in, in_grad); + _SetGrad(grads, t_in, in_grad); + } } - } + if (loop_state != null) + loop_state.ExitGradWhileContext(op, before: false); + } + // Update pending count for the inputs of op and enqueue ready ops. _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state, xs); } }); + if (loop_state != null) + loop_state.PostProcessing(); return xs.Select(x => _GetGrad(grads, x)).ToArray(); } diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs index d7cda786..39561990 100644 --- a/src/TensorFlowNET.Core/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Layers/Layer.cs @@ -50,10 +50,11 @@ namespace Tensorflow.Layers public virtual (Tensor, Tensor) apply(Tensor inputs, Tensor training = null) { - return __call__(inputs, training: training); + var results = __call__(inputs, training: training); + return (results[0], results[1]); } - public (Tensor, Tensor) __call__(Tensor inputs, + public Tensor[] __call__(Tensor inputs, Tensor training = null, Tensor state = null, VariableScope scope = null) @@ -73,7 +74,7 @@ namespace Tensorflow.Layers auxiliary_name_scope: false); } - (Tensor, Tensor) outputs = (null, null); + Tensor[] outputs = null; tf_with(scope_context_manager, scope2 => { 
_current_scope = scope2; diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs index 13182dfd..b8360939 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs @@ -151,27 +151,50 @@ namespace Tensorflow /// public static ControlFlowState MaybeCreateControlFlowState(List between_op_list, List between_ops, bool colocate_gradients_with_ops) { + var flag = new List(); ControlFlowState loop_state = null; - foreach (var op in between_op_list) + int pos = 0; + while(pos < between_op_list.Count) { + var op = between_op_list[pos]; if (IsLoopExit(op)) { - if(loop_state == null) + if (loop_state == null) { loop_state = new ControlFlowState(); } + if (colocate_gradients_with_ops) + ops.colocate_with(op); + loop_state.AddWhileContext(op, between_op_list, between_ops); } + pos++; } return loop_state; } public static bool IsLoopExit(Operation op) + => op.OpType == "Exit" || op.OpType == "RefExit"; + + public static bool IsLoopSwitch(Operation op) + { + if(IsSwitch(op)) + { + var ctxt = op._get_control_flow_context(); + return ctxt != null && ctxt.IsWhileContext() && !IsCondSwitch(op); + } + return false; + } + + public static bool IsCondSwitch(Operation op) { - return op.OpType == "Exit" || op.OpType == "RefExit"; + throw new NotImplementedException("IsCondSwitch"); } + public static bool IsSwitch(Operation op) + => op.type == "Switch" || op.type == "RefSwitch"; + public static Tensor[] tuple(Tensor[] tensors, string name = null, Operation[] control_inputs = null) { return tf_with(ops.name_scope(name, "tuple", tensors), scope => @@ -224,15 +247,10 @@ namespace Tensorflow //TODO: missing original code //if context.executing_eagerly(): // return output_tensor - var values = new List(); - values.AddRange(dependencies); - values.Add(output_tensor); - - return tf_with(ops.name_scope(name, "control_dependency", values), scope => + return 
tf_with(ops.name_scope(name, "control_dependency", new { dependencies, output_tensor }), scope => { name = scope; - // TODO: missing original code - //with ops.colocate_with(output_tensor): + ops.colocate_with(output_tensor); { return tf_with(ops.control_dependencies(dependencies), ctl => { @@ -431,6 +449,7 @@ namespace Tensorflow var merges = zip(res_f_flat, res_t_flat) .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })) + .Select(m => (Tensor)m) .ToArray(); var merges2 = _convert_flows_to_tensorarrays(new ITensorOrTensorArray[] { (Tensor)orig_res_t }, merges); @@ -479,6 +498,7 @@ namespace Tensorflow var merges = zip(res_f_flat, res_t_flat) .Select(pair => merge(new [] { pair.Item1, pair.Item2 })) + .Select(m => (Tensor)m) .ToArray(); var merges2 = _convert_flows_to_tensorarrays(orig_res_t.Select(x => (ITensorOrTensorArray)x).ToArray(), merges); @@ -519,7 +539,7 @@ namespace Tensorflow /// inputs: The input tensors, at most one of which is available. /// A name for this operation (optional). 
/// - public static Tensor merge(Tensor[] inputs, string name = null) + public static MergeOutput merge(Tensor[] inputs, string name = null) { if (inputs.Any(x => x == null)) throw new ValueError($"At least one of the merge inputs is null: {inputs}"); @@ -529,7 +549,7 @@ namespace Tensorflow inputs = inputs.Select(inp => ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref: true)) .ToArray(); - return gen_control_flow_ops.merge(inputs, name)[0]; + return gen_control_flow_ops.merge(inputs, name); }); } @@ -602,7 +622,7 @@ namespace Tensorflow /// /// /// - public static Tensor while_loop(Func cond, Func body, TItem loop_vars, + public static TItem while_loop(Func cond, Func body, TItem loop_vars, TensorShape[] shape_invariants = null, int parallel_iterations = 10, bool back_prop = true, @@ -611,7 +631,7 @@ namespace Tensorflow Tensor maximum_iterations = null, bool return_same_structure = false) { - tf_with(ops.name_scope(name, "while", loop_vars), scope => + return tf_with(ops.name_scope(name, "while", loop_vars), scope => { if (loop_vars == null) throw new ValueError("No loop variables provided"); @@ -666,13 +686,11 @@ namespace Tensorflow var results = loop_context.BuildLoop(cond_buildloop, body_buildloop, loop_vars_1, shape_invariants, return_same_structure); - if (maximum_iterations != null) - return results[1]; - else - return results[0]; + //if (maximum_iterations != null) + return results.Item; + //else + //return results; }); - - throw new NotImplementedException("while_loop"); } /// diff --git a/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs index 9dcfb2e1..5f3bc15c 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs @@ -15,6 +15,7 @@ ******************************************************************************/ using System; +using System.Linq; using Tensorflow.Operations; using static 
Tensorflow.Binding; @@ -60,6 +61,45 @@ namespace Tensorflow public static bool IsSwitch(Operation op) { return op.type == "Switch" || op.type == "RefSwitch"; + } + + public static WhileContext GetWhileContext(Operation op) + => op.GetWhileContext(); + + public static bool IsCondSwitch(Operation op) + { + if (!IsSwitch(op)) + return false; + if (op.outputs == null || op.outputs.Length == 0) + return false; + + // Switch nodes are not part of the cond control flow context that they + // represent, so consider the consumers of its outputs to determine if it is + // cond switch or not. A switch is a cond switch iff all its consumers are in + // cond contexts. + var is_cond_switch = true; + foreach(var o in op.outputs) + { + foreach(var c in o.consumers()) + { + var ctxt = c._get_control_flow_context(); + if (IsLoopEnter(c)) + ctxt = ctxt.outer_context; + is_cond_switch = is_cond_switch &&(ctxt != null && ctxt.IsCondContext()); + } + } + + return is_cond_switch; + } + + public static bool IsLoopSwitch(Operation op) + { + if (IsSwitch(op)) + { + var ctxt = op._get_control_flow_context(); + return ctxt != null && ctxt.IsWhileContext() && !IsCondSwitch(op); + } + return false; } /// @@ -87,13 +127,64 @@ namespace Tensorflow valid = true; else { - throw new NotImplementedException(""); + var while_ctxt = GetContainingWhileContext(op_ctxt); + var input_while_ctxt = GetContainingWhileContext(input_ctxt); + + if (while_ctxt == null) + { + throw new NotImplementedException("CheckInputFromValidContext"); + } + else if (IsContainingContext(while_ctxt, input_while_ctxt)) + { + // input_op is in a while loop which contains op's while loop (or not in a + // while loop at all). 
+ valid = true; + } + else if (while_ctxt.grad_state != null && + IsContainingContext(while_ctxt.grad_state.forward_context, + input_while_ctxt)) + { + valid = true; + } + else + throw new NotImplementedException("CheckInputFromValidContext"); } if (!valid) { - throw new NotImplementedException(""); + throw new NotImplementedException("CheckInputFromValidContext"); + } + } + + public static Operation GetLoopConstantEnter(Tensor value) + { + var id_ops = new string[] { "Switch", "RefSwitch", "Identity", "RefIdentity" }; + var op = value.op; + while (id_ops.Contains(op.type)) + op = op.inputs[0].op; + return IsLoopConstantEnter(op) ? op : null; + } + + public static bool IsContainingContext(WhileContext ctxt, WhileContext maybe_containing_ctxt) + { + while(ctxt != maybe_containing_ctxt) + { + if (ctxt == null) + return false; + ctxt = ctxt.outer_context as WhileContext; + } + return true; + } + + public static WhileContext GetContainingWhileContext(ControlFlowContext ctxt, ControlFlowContext stop_ctxt = null) + { + while (ctxt != null) + { + if (ctxt.IsWhileContext() || ctxt == stop_ctxt) + return ctxt as WhileContext; + ctxt = ctxt.outer_context; } + return null; } } } diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index 17cd8a99..f158ffb1 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -159,6 +159,8 @@ namespace Tensorflow }); } + public static Tensor greater_equal(Tx x, Ty y, string name = null) + => gen_math_ops.greater_equal(x, y, name: name); public static Tensor equal(Tx x, Ty y, string name = null) => gen_math_ops.equal(x, y, name: name); From efed25863a280bd2531b6ec5ad449dbb466dd03c Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 2 Nov 2019 11:18:53 -0500 Subject: [PATCH 33/41] WhileContext --- README.md | 2 +- src/TensorFlowNET.Core/Binding.Util.cs | 17 ++ .../ControlFlows/ControlFlowState.cs | 61 +---- 
.../Operations/ControlFlows/GradLoopState.cs | 209 ++++++++---------- .../Operations/ControlFlows/WhileContext.cs | 5 +- .../Operations/NnOps/rnn.cs | 42 +++- .../Operations/NnOps/rnn_cell_impl.cs | 29 +++ .../Operations/Operation.Control.cs | 9 + .../Operations/Operation.Instance.cs | 21 +- .../Operations/Operation.cs | 9 +- .../Operations/gen_array_ops.cs | 2 +- .../Operations/gen_control_flow_ops.cs | 10 +- .../Operations/gen_data_flow_ops.cs | 26 +++ .../Operations/gen_math_ops.cs | 2 +- src/TensorFlowNET.Core/Util/nest.py.cs | 8 + 15 files changed, 256 insertions(+), 196 deletions(-) diff --git a/README.md b/README.md index a80191a7..8744ba72 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ [![Documentation Status](https://readthedocs.org/projects/tensorflownet/badge/?version=latest)](https://tensorflownet.readthedocs.io/en/latest/?badge=latest) [![Badge](https://img.shields.io/badge/link-996.icu-red.svg)](https://996.icu/#/en_US) -TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp). +TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp). ![tensors_flowing](docs/assets/tensors_flowing.gif) diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index c70af1fd..34f227fc 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -30,6 +30,20 @@ namespace Tensorflow /// public static partial class Binding { + public static T2 get(this Dictionary dict, T1 key) + => key == null ? + default(T2) : + (dict.ContainsKey(key) ? 
dict[key] : default(T2)); + + public static void add(this IList list, T element) + => list.Add(element); + + public static void append(this IList list, T element) + => list.Add(element); + + public static void extend(this List list, IEnumerable elements) + => list.AddRange(elements); + private static string _tostring(object obj) { switch (obj) @@ -81,6 +95,9 @@ namespace Tensorflow throw new NotImplementedException("len() not implemented for type: " + a.GetType()); } + public static T[] list(IEnumerable list) + => list.ToArray(); + public static IEnumerable range(int end) { return Enumerable.Range(0, end); diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs index d1be6f31..1d296774 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs @@ -109,11 +109,12 @@ namespace Tensorflow.Operations.ControlFlows grad_state.grad_context.Enter(); } - // def ExitGradWhileContext(self, op, before): - // """Exit the WhileContext for gradient computation.""" - // grad_state = self.GetGradState(op, before) - // if grad_state: - // grad_state.grad_context.Exit() + public void ExitGradWhileContext(Operation op, bool before) + { + var grad_state = GetGradState(op, before); + if (grad_state != null) + grad_state.grad_context.Exit(); + } // def AddWhileContext(self, op, between_op_list, between_ops): // """Add the grad state for the while loop that op belongs to. @@ -287,51 +288,9 @@ namespace Tensorflow.Operations.ControlFlows return result; } - // def PostProcessing(self): - // """Perform postprocessing at the end of gradients(). - - // We have created the gradient graph at this point. So this function - // can be used to perform any postprocessing on the gradient graph. - // We currently perform the following postprocessing: - // 1. 
Patch the gradient graph if the output of a loop variable - // doesn't depend on its input. - // """ - // for _, grad_state in self._map.items(): - // for _, b_merge in grad_state.switch_map.items(): - // if b_merge.op.inputs[0] == b_merge.op.inputs[1]: - // # The value of this loop variable at iteration i+1 doesn't - // # depend on its value at iteration i. So use zeros as the - // # gradients for all iterations > 0. - // dtype = b_merge.op.inputs[0].dtype - // shape = b_merge.op.inputs[0].get_shape() - // # pylint: disable=protected-access - // if shape.is_fully_defined(): - // grad_state.grad_context.Enter() - // # Create a zeros and use it for iterations > 0. - // grad_val = constant_op.constant(0, dtype=dtype, shape=shape) - // next_grad_val = _NextIteration(grad_val) - // grad_state.grad_context.Exit() - // else: - // # Create a zeros in the outer grad context. - // outer_grad_ctxt = grad_state.grad_context.outer_context - // if outer_grad_ctxt: - // outer_grad_ctxt.Enter() - // enter_grad_op = b_merge.op.inputs[0].op - // enter_grad = enter_grad_op.inputs[0] - // grad_shape = array_ops.shape_internal(enter_grad, optimize=False) - // grad_val = array_ops.zeros(grad_shape) - // if outer_grad_ctxt: - // outer_grad_ctxt.Exit() - // # Use the zeros for iterations > 0. 
- // grad_state.grad_context.Enter() - // next_grad_val = _NextIteration(grad_val) - // grad_state.grad_context.Exit() - // b_merge.op._update_input(1, next_grad_val) - // # pylint: enable=protected-access - + public void PostProcessing() + { + throw new NotImplementedException("PostProcessing"); + } } - - - - } diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs index e17ab8ba..143aacb1 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs @@ -17,7 +17,9 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Linq; using static Tensorflow.Binding; +using util = Tensorflow.control_flow_util; namespace Tensorflow.Operations.ControlFlows { @@ -56,6 +58,7 @@ namespace Tensorflow.Operations.ControlFlows public GradLoopState outer_grad_state => _outer_grad_state; Tensor _forward_index; + public Tensor forward_index => _forward_index; Tensor _grad_index; Tensor[] _forward_loop_exits; @@ -152,63 +155,52 @@ namespace Tensorflow.Operations.ControlFlows /// The stack that contains the accumulated history of the tensor. public Tensor AddForwardAccumulator(Tensor value, bool dead_branch = false) { - throw new NotImplementedException("AddForwardAccumulator"); - // # curr_ctxt is the context that tf.gradients was called in. - // with self._forward_index.graph.as_default(): - // curr_ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access - // with ops.control_dependencies(None): - // if curr_ctxt: - // curr_ctxt.Enter() - // with ops.colocate_with(value): - // # We only need to pass maximum_iterations to the stack if - // # we're inside an XLA context. 
- // if not util.IsInXLAContext(value.op): - // max_size = constant_op.constant(-1, dtypes.int32) - // else: - // max_size = GetMaxSizeFromNestedMaximumIterations( - // value, self.forward_context) - // acc = gen_data_flow_ops.stack_v2( - // max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc") - // if curr_ctxt: - // curr_ctxt.Exit() - - // # Make acc available in the forward context. - // enter_acc = self.forward_context.AddValue(acc) - - // # Add the stack_push op in the context of value.op. - // swap_enabled = self.forward_context.swap_memory - // value_ctxt = util.GetOutputContext(value.op) - // if value_ctxt == self.forward_context: - // # value is not nested in the forward context. - // self.forward_context.Enter() - // push = gen_data_flow_ops.stack_push_v2( - // enter_acc, value, swap_memory=swap_enabled) - // self.forward_context.Exit() - // # Protect stack push and order it before forward_index. - // self.forward_index.op._add_control_input(push.op) - // else: - // # value is in a cond context within the forward context. - // if not isinstance(value_ctxt, CondContext): - // raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt) - // if dead_branch: - // # The special case for creating a zero tensor for a dead - // # branch of a switch. See ControlFlowState.ZerosLike(). - // value_ctxt.outer_context.Enter() - // push = gen_data_flow_ops.stack_push_v2( - // enter_acc, value, swap_memory=swap_enabled) - // value_ctxt.outer_context.Exit() - // push.op._set_control_flow_context(value_ctxt) - // else: - // value_ctxt.Enter() - // push = gen_data_flow_ops.stack_push_v2( - // enter_acc, value, swap_memory=swap_enabled) - // value_ctxt.Exit() - // # Protect stack push and order it before forward_sync. 
- // self.forward_sync._add_control_input(push.op) - // # Order stack push after the successor of forward_index - // add_op = self.forward_index.op.inputs[0].op - // push.op._add_control_input(add_op) - // return acc + using (_forward_index.graph.as_default()) + { + var curr_ctxt = ops.get_default_graph()._get_control_flow_context(); + return tf_with(ops.control_dependencies(null), delegate + { + Tensor acc = null; + Tensor push = null; + if (curr_ctxt != null) + curr_ctxt.Enter(); + ops.colocate_with(value); + { + // We only need to pass maximum_iterations to the stack if + // we're inside an XLA context. + var max_size = constant_op.constant(-1, dtypes.int32); + acc = gen_data_flow_ops.stack_v2( + max_size: max_size, elem_type: value.dtype.as_base_dtype(), name: "f_acc"); + } + if (curr_ctxt != null) + curr_ctxt.Exit(); + + // Make acc available in the forward context. + var enter_acc = forward_context.AddValue(acc); + + // Add the stack_push op in the context of value.op. + var swap_enabled = forward_context.swap_memory; + var value_ctxt = util.GetOutputContext(value.op); + if(value_ctxt == forward_context) + { + // value is not nested in the forward context. + forward_context.Enter(); + push = gen_data_flow_ops.stack_push_v2(enter_acc, value, swap_memory: swap_enabled); + forward_context.Exit(); + // Protect stack push and order it before forward_index. + forward_index.op._add_control_input(push.op); + } + else + { + throw new NotImplementedException("AddForwardAccumulator"); + } + + // Order stack push after the successor of forward_index + var add_op = forward_index.op.inputs[0].op; + push.op._add_control_input(add_op); + return acc; + }); + } } // """Add the getter for an accumulated value in the grad context. @@ -225,6 +217,7 @@ namespace Tensorflow.Operations.ControlFlows // Returns: // The current value (the top of the stack). 
// """ + public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bool dead_branch= false) { throw new NotImplementedException(); @@ -261,62 +254,50 @@ namespace Tensorflow.Operations.ControlFlows // return pop } - // def GetRealValue(self, value): - // """Get the real value of `value`. - - // If backprop "uses" a value produced by forward inference, an accumulator - // is added in the forward loop to accumulate its values. We use the - // accumulated value. This method must be called in the grad loop context. - // `value` must be in forward and needed for backprop. - - // Args: - // value: A tensor to be captured. - - // Returns: - // The same tensor obtained from the saved history. - // """ - // assert value.op.type not in ["Variable", "VariableV2"] - // real_value = self._history_map.get(value.name) - // if real_value is None: - // cur_value = value - // cur_grad_state = self - // while True: - // enter_op = util.GetLoopConstantEnter(cur_value) - // if enter_op: - // # Special case: cur_value comes from a constant Enter node. - // cur_value = enter_op.inputs[0] - // cur_grad_state = cur_grad_state.outer_grad_state - // if cur_grad_state is None: - // # We are now outside all nested loops for this gradient(), - // # so `value` is a loop invariant and there is no need to - // # save the history of value. Just make cur_value to enter - // # the right control flow context. - // real_value = self._grad_context.AddValue(cur_value) - // break - // elif constant_op.is_constant(cur_value): - // # If the value to be forwarded is a constant, clone the constant in - // # the gradient loop rather than using a stack. - // # TODO(phawkins): consider hoisting the constant out of the loop - // # instead. - // real_value = constant_op.constant( - // tensor_util.constant_value(cur_value), dtype=cur_value.dtype) - // break - // else: - // # Record the history of this value in forward_ctxt. 
- // self._grad_context.Exit() - // history_value = cur_grad_state.AddForwardAccumulator(cur_value) - // self._grad_context.Enter() - // break - - // if real_value is None: - // # Add the stack pop op in the grad context. - // real_value = cur_grad_state.AddBackpropAccumulatedValue( - // history_value, cur_value) - // if cur_grad_state != self: - // real_value = self._grad_context.AddValue(real_value) - // self._history_map[value.name] = real_value - // return real_value - - + /// + /// Get the real value of `value`. + /// + /// A tensor to be captured. + /// The same tensor obtained from the saved history. + public Tensor GetRealValue(Tensor value) + { + Tensor real_value = null; + if(real_value == null) + { + var cur_value = value; + var cur_grad_state = this; + Tensor history_value = null; + while (true) + { + var enter_op = util.GetLoopConstantEnter(cur_value); + if(enter_op != null) + { + throw new NotImplementedException("GetRealValue"); + } + else if (constant_op.is_constant(cur_value)) + { + throw new NotImplementedException("GetRealValue"); + } + else + { + // Record the history of this value in forward_ctxt. + _grad_context.Exit(); + history_value = cur_grad_state.AddForwardAccumulator(cur_value); + _grad_context.Enter(); + break; + } + } + + if(real_value == null) + { + // Add the stack pop op in the grad context. 
+ real_value = cur_grad_state.AddBackpropAccumulatedValue(history_value, cur_value); + if (cur_grad_state != this) + real_value = _grad_context.AddValue(real_value); + } + _history_map[value.name] = real_value; + } + return real_value; + } } } diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs index 56bcf897..02a5a573 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs @@ -530,10 +530,9 @@ namespace Tensorflow.Operations } if(forward_ctxt == grad_ctxt.grad_state.forward_context) { - throw new NotImplementedException("forward_ctxt == grad_ctxt.grad_state.forward_context"); - /*real_val = grad_ctxt.grad_state.GetRealValue(val); + var real_val = grad_ctxt.grad_state.GetRealValue(val); _external_values[val.name] = real_val; - return real_val;*/ + return real_val; } } } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs index a8a0e0b9..48af7d58 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs @@ -30,7 +30,7 @@ namespace Tensorflow.Operations TF_DataType dtype = TF_DataType.DtInvalid, int? 
parallel_iterations = null, bool swap_memory = false, bool time_major = false) { - tf_with(tf.variable_scope("rnn"), scope => + return tf_with(tf.variable_scope("rnn"), scope => { VariableScope varscope = scope; var flat_input = nest.flatten(inputs_tensor); @@ -64,9 +64,12 @@ namespace Tensorflow.Operations swap_memory: swap_memory, sequence_length: sequence_length, dtype: dtype); - }); - throw new NotImplementedException(""); + if (!time_major) + outputs = nest.map_structure(_transpose_batch_time, outputs); + + return (outputs, final_state); + }); } /// @@ -210,16 +213,28 @@ namespace Tensorflow.Operations var input_t_t = nest.pack_sequence_as2(structure: inputs, flat_sequence: input_t); // Keras RNN cells only accept state as list, even if it's a single tensor. // var is_keras_rnn_cell = _is_keras_rnn_cell(cell); - (Tensor, Tensor) a = (null, null); + Tensor[] outputs = null; if (sequence_length != null) throw new NotImplementedException("sequence_length != null"); else - a = cell.__call__(input_t_t, state: state1); + outputs = cell.__call__(input_t_t, state: state1); + + var (output, new_state) = (outputs[0], outputs[1]); + // Keras cells always wrap state as list, even if it's a single tensor. 
+ // if(is_keras_rnn_cell && len(new_state)) == 1 + // Pack state if using state tuples + outputs = nest.flatten2(output).Select(x => x as Tensor).ToArray(); - return item; + output_ta_t = zip(output_ta_t, outputs).Select(x => + { + var(ta, @out) = (x.Item1, x.Item2); + return ta.write(item.time, @out); + }).ToArray(); + + return new BodyItemInRnnWhileLoop(item.time + 1, output_ta_t, new_state); }; - control_flow_ops.while_loop( + var while_loop_result = control_flow_ops.while_loop( cond: cond, body: _time_step, loop_vars: new BodyItemInRnnWhileLoop(time, output_ta.ToArray(), state), @@ -227,7 +242,18 @@ namespace Tensorflow.Operations maximum_iterations: time_steps, swap_memory: swap_memory); - throw new NotImplementedException(""); + (_, TensorArray[] output_final_ta, Tensor final_state) = (while_loop_result.time, while_loop_result.output_ta_t, while_loop_result.state); + + // Unpack final output if not using output tuples. + var final_outputs = output_final_ta.Select(ta => ta.stack()).ToArray(); + // Restore some shape information + foreach (var (output, output_size) in zip(final_outputs, flat_output_size)) + { + var shape = rnn_cell_impl._concat(new[] { const_time_steps, const_batch_size }, output_size, @static: true); + output.set_shape(shape); + } + + return (final_outputs[0], final_state); } private static TensorShape _maybe_tensor_shape_from_tensor(Tensor shape) diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs index 3164ba14..cf5f1ce0 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs @@ -53,5 +53,34 @@ namespace Tensorflow.Operations return array_ops.concat(new[] { p, s }, 0); } } + + public static TensorShape _concat(int[] prefix, int suffix, bool @static = false) + { + var p = new TensorShape(prefix); + var p_static = prefix; + var p_tensor = p.is_fully_defined() ? 
constant_op.constant(p.as_list(), dtype: dtypes.int32) : null; + + var s_tensor_shape = new TensorShape(suffix); + var s_static = s_tensor_shape.ndim > -1 ? + s_tensor_shape.dims : + null; + var s_tensor = s_tensor_shape.is_fully_defined() ? + constant_op.constant(s_tensor_shape.dims, dtype: dtypes.int32) : + null; + + if (@static) + { + if (p_static is null) return null; + var shape = new TensorShape(p_static).concatenate(s_static); + return shape; + } + else + { + if (p is null || s_tensor is null) + throw new ValueError($"Provided a prefix or suffix of None: {prefix} and {suffix}"); + // return array_ops.concat(new[] { p_tensor, s_tensor }, 0); + throw new NotImplementedException(""); + } + } } } diff --git a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs index 9f0cb9a5..c9ae7071 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Control.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs @@ -52,6 +52,10 @@ namespace Tensorflow public void _set_control_flow_context(ControlFlowContext ctx) { + if (name.Contains("gradients/rnn/while/basic_rnn_cell/Tanh_grad/TanhGrad/f_acc")) + { + + } _control_flow_context = ctx; } @@ -59,5 +63,10 @@ namespace Tensorflow { return _control_flow_context; } + + public WhileContext GetWhileContext() + { + return _control_flow_context as WhileContext; + } } } diff --git a/src/TensorFlowNET.Core/Operations/Operation.Instance.cs b/src/TensorFlowNET.Core/Operations/Operation.Instance.cs index 6f6c8226..e39a34a3 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Instance.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Instance.cs @@ -15,17 +15,14 @@ ******************************************************************************/ using System; +using System.Linq; using System.Collections.Generic; +using static Tensorflow.Binding; namespace Tensorflow { public partial class Operation { - // cache the mapping between managed and unmanaged op - // some 
data is stored in managed instance, so when - // create Operation by IntPtr, it will lost some data. - private static Dictionary OpInstances = new Dictionary(); - /// /// Get operation by handle /// @@ -33,9 +30,17 @@ namespace Tensorflow /// public Operation GetOperation(IntPtr handle) { - return OpInstances.ContainsKey(handle) ? - OpInstances[handle] : - new Operation(handle); + var nodes = tf.get_default_graph()._nodes_by_name; + foreach(var node in nodes.Values) + { + if (node is Operation op) + { + if (op == handle) + return op; + } + } + + return null; } } } diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index d5068f2e..e8eb216f 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -106,7 +106,6 @@ namespace Tensorflow _control_flow_context = _graph._get_control_flow_context(); // Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor. - OpInstances[_handle] = this; } /*public Operation(Graph g, string opType, string oper_name) @@ -183,10 +182,12 @@ namespace Tensorflow // This will be set by self.inputs. if (op_def == null) op_def = g.GetOpDef(node_def.Op); - + if(node_def.Name == "gradients/rnn/while/basic_rnn_cell/Tanh_grad/TanhGrad/f_acc") + { + + } var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); - _is_stateful = op_def.IsStateful; // Initialize self._outputs. 
@@ -202,8 +203,6 @@ namespace Tensorflow if (_handle != IntPtr.Zero) _control_flow_post_processing(); - - OpInstances[_handle] = this; } public void run(FeedItem[] feed_dict = null, Session session = null) diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 01231035..cea3e440 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -183,7 +183,7 @@ namespace Tensorflow { var _op = _op_def_lib._apply_op_helper("Identity", name, new { input }); - return _op.outputs[0]; + return _op.output; } public static Tensor invert_permutation(Tensor x, string name = null) diff --git a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs index 5f0ceb48..8f9c8120 100644 --- a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs @@ -14,6 +14,8 @@ limitations under the License. 
******************************************************************************/ +using Tensorflow.Operations; + namespace Tensorflow { public class gen_control_flow_ops @@ -148,18 +150,18 @@ namespace Tensorflow return new []{_op.outputs[0], _op.outputs[1]}; } - public static Tensor[] ref_merge(Tensor[] inputs, string name = null) + public static MergeOutput ref_merge(Tensor[] inputs, string name = null) { var _op = _op_def_lib._apply_op_helper("RefMerge", name, new { inputs }); - return _op.outputs; + return new MergeOutput(_op.outputs); } - public static Tensor[] merge(Tensor[] inputs, string name = null) + public static MergeOutput merge(Tensor[] inputs, string name = null) { var _op = _op_def_lib._apply_op_helper("Merge", name, new { inputs }); - return _op.outputs; + return new MergeOutput(_op.outputs); } } } diff --git a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs index 52b0a372..fcb1000f 100644 --- a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs @@ -259,5 +259,31 @@ namespace Tensorflow return _op.output; } + + public static Tensor stack_v2(Tensor max_size, TF_DataType elem_type, string stack_name = "", + string name = null) + { + var _op = _op_def_lib._apply_op_helper("StackV2", name, new + { + max_size, + elem_type, + stack_name + }); + + return _op.output; + } + + public static Tensor stack_push_v2(Tensor handle, Tensor elem, bool swap_memory = false, + string name = null) + { + var _op = _op_def_lib._apply_op_helper("StackPushV2", name, new + { + handle, + elem, + swap_memory + }); + + return _op.output; + } } } diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 08431089..7e54349f 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -282,7 +282,7 @@ namespace Tensorflow /// /// /// 
- public static Tensor tanh_grad(Tensor y, Tensor dy, string name = "TanhGrad") + public static Tensor tanh_grad(Tensor y, Tensor dy, string name = null) => _op_def_lib._apply_op_helper("TanhGrad", name: name, args: new { y, dy }).output; public static Tensor floor(Tensor x, string name = null) diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs index 54149fe1..7dbacea0 100644 --- a/src/TensorFlowNET.Core/Util/nest.py.cs +++ b/src/TensorFlowNET.Core/Util/nest.py.cs @@ -526,6 +526,14 @@ namespace Tensorflow.Util return pack_sequence_as(structure, mapped_flat_structure) as Tensor; } + public static Tensor map_structure2(Func func, T structure) + { + var flat_structure = flatten(structure); + var mapped_flat_structure = flat_structure.Select(func).ToList(); + + return pack_sequence_as(structure, mapped_flat_structure) as Tensor; + } + /// /// Same as map_structure, but with only one structure (no combining of multiple structures) /// From 3e0737285558d88ed6ca46fbca3e211ea23a9f35 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 3 Nov 2019 07:22:30 -0600 Subject: [PATCH 34/41] WhileContext AddBackpropAccumulator --- .../Operations/ControlFlows/WhileContext.cs | 78 +++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs index 02a5a573..fa7a77a6 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs @@ -446,6 +446,84 @@ namespace Tensorflow.Operations return (total_iterations, next_n); } + /// + /// Add an accumulation loop for every loop invariant. + /// + /// The Enter op for a loop invariant. + /// The partial gradient of an iteration for a loop invariant. + /// The gradient for a loop invariant. 
+ public Tensor AddBackpropAccumulator(Operation op, Tensor grad) + { + Tensor acc = null; + Exit(); + // Create a zeros tensor with the right shape for acc. If we don't + // know the full shape statically, we will have to get the shape + // dynamically from the forward inference. Getting the shape right + // for the zeros is only needed for the base case when the loop exits + // without running any iterations. + var shape = grad.TensorShape; + if (shape.is_fully_defined()) + { + if (outer_context != null) + outer_context.Enter(); + acc = constant_op.constant(0, grad.dtype, shape: shape, name: "b_acc"); + if (outer_context != null) + outer_context.Exit(); + } + else + { + var value = op.inputs[0]; + if(outer_context is WhileContext wc) + { + // We are in a nested while loop. + var forward_ctxt = grad_state.forward_context; + forward_ctxt.outer_context.Enter(); + var zeros_shape = array_ops.shape_internal(value, optimize: false); + forward_ctxt.outer_context.Exit(); + var outer_grad_state = grad_state.outer_grad_state; + var history_zeros_shape = outer_grad_state.AddForwardAccumulator(zeros_shape); + outer_context.Enter(); + var real_shape = outer_grad_state.AddBackpropAccumulatedValue( + history_zeros_shape, zeros_shape); + acc = array_ops.zeros(real_shape, grad.dtype); + outer_context.Exit(); + } + else + { + if (outer_context != null) + outer_context.Enter(); + var zeros_shape = array_ops.shape_internal(value, optimize: false); + acc = array_ops.zeros(zeros_shape, grad.dtype); + if (outer_context != null) + outer_context.Exit(); + } + throw new NotImplementedException("AddBackpropAccumulator"); + } + + Enter(); + AddName(acc.name); + var enter_acc = _Enter( + acc, + _name, + is_constant: false, + parallel_iterations: _parallel_iterations, + name: "b_acc"); + loop_enters.append(enter_acc); + var merge_acc = merge(new[] { enter_acc, enter_acc }, name: "b_acc")[0]; + + var switch_result = @switch(merge_acc, _pivot); + var (switch_acc_false, switch_acc_true) = 
(switch_result[0], switch_result[1]); + + var add_acc = math_ops.add(switch_acc_true, grad); + var next_acc = _NextIteration(add_acc); + merge_acc.op._update_input(1, next_acc); + + var result_acc = exit(switch_acc_false, name: "b_acc"); + loop_exits.append(result_acc); + ExitResult(new[] { result_acc }); + return result_acc; + } + /// /// Add the backprop loop that controls the iterations. /// From 07ddb74738849a2d59940af4d5b423ed5932f5dd Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 3 Nov 2019 07:22:41 -0600 Subject: [PATCH 35/41] GradLoopState.AddBackpropAccumulatedValue --- .../Operations/ControlFlows/GradLoopState.cs | 102 ++++++++++++------ 1 file changed, 67 insertions(+), 35 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs index 143aacb1..2552df8a 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs @@ -78,6 +78,26 @@ namespace Tensorflow.Operations.ControlFlows /// public int pending_exits_count { get; set; } + Operation _grad_sync; + public Operation grad_sync + { + get + { + if(_grad_sync == null) + { + tf_with(ops.control_dependencies(null), delegate + { + _grad_sync = gen_control_flow_ops.control_trigger(name: "b_sync"); + }); + _grad_sync._set_control_flow_context(_grad_context); + _grad_index.op._add_control_input(_grad_sync); + if (_grad_context.outer_context != null) + _grad_context.outer_context.AddInnerOp(_grad_sync); + } + return _grad_sync; + } + } + public GradLoopState(WhileContext forward_ctxt, GradLoopState outer_grad_state_) { // Information needed by backprop. @@ -155,7 +175,7 @@ namespace Tensorflow.Operations.ControlFlows /// The stack that contains the accumulated history of the tensor. 
public Tensor AddForwardAccumulator(Tensor value, bool dead_branch = false) { - using (_forward_index.graph.as_default()) + _forward_index.graph.as_default(); { var curr_ctxt = ops.get_default_graph()._get_control_flow_context(); return tf_with(ops.control_dependencies(null), delegate @@ -220,38 +240,33 @@ namespace Tensorflow.Operations.ControlFlows public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bool dead_branch= false) { - throw new NotImplementedException(); - // history_ctxt = history_value.op._get_control_flow_context() - // # Find the cond context that controls history_value if any. - // cond_ctxt = None - // value_ctxt = value.op._get_control_flow_context() - // while value_ctxt and value_ctxt != history_ctxt: - // if isinstance(value_ctxt, CondContext): - // cond_ctxt = value_ctxt - // break - // value_ctxt = value_ctxt.outer_context - // with ops.control_dependencies(None): - // self.grad_context.Enter() - // if cond_ctxt: - // # Guard stack pop with a switch if it is controlled by a cond. - // grad_state = self - // pred = None - // while pred is None and grad_state: - // pred = grad_state.history_map.get(cond_ctxt.pred.name) - // grad_state = grad_state.outer_grad_state - // if pred is None: - // pred = cond_ctxt.pred - // branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch - // history_value = _SwitchRefOrTensor(history_value, pred)[branch] - // pop = gen_data_flow_ops.stack_pop_v2(history_value, - // value.dtype.base_dtype) - // pop.set_shape(value.get_shape()) - // self.grad_context.Exit() - // parallel_iterations = self.grad_context.parallel_iterations - // if parallel_iterations > 1: - // # All pops are ordered after pivot_for_body and before grad_sync. - // self.grad_sync._add_control_input(pop.op) - // return pop + var history_ctxt = history_value.op._get_control_flow_context(); + // Find the cond context that controls history_value if any. 
+ CondContext cond_ctxt = null; + Tensor pop = null; + var value_ctxt = value.op._get_control_flow_context(); + while(value_ctxt != null && value_ctxt != history_ctxt) + { + if (value_ctxt is CondContext cc) + cond_ctxt = cc; + value_ctxt = value_ctxt.outer_context; + } + tf_with(ops.control_dependencies(null), delegate + { + grad_context.Enter(); + if(cond_ctxt != null) + { + throw new NotImplementedException("AddBackpropAccumulatedValue"); + } + pop = gen_data_flow_ops.stack_pop_v2(history_value, value.dtype.as_base_dtype()); + pop.set_shape(value.TensorShape); + grad_context.Exit(); + }); + var parallel_iterations = grad_context.parallel_iterations; + if (parallel_iterations > 1) + // All pops are ordered after pivot_for_body and before grad_sync. + grad_sync._add_control_input(pop.op); + return pop; } /// @@ -272,11 +287,28 @@ namespace Tensorflow.Operations.ControlFlows var enter_op = util.GetLoopConstantEnter(cur_value); if(enter_op != null) { - throw new NotImplementedException("GetRealValue"); + // Special case: cur_value comes from a constant Enter node. + cur_value = enter_op.inputs[0]; + cur_grad_state = cur_grad_state.outer_grad_state; + if(cur_grad_state == null) + { + // We are now outside all nested loops for this gradient(), + // so `value` is a loop invariant and there is no need to + // save the history of value. Just make cur_value to enter + // the right control flow context. + real_value = _grad_context.AddValue(cur_value); + break; + } } else if (constant_op.is_constant(cur_value)) { - throw new NotImplementedException("GetRealValue"); + // We are now outside all nested loops for this gradient(), + // so `value` is a loop invariant and there is no need to + // save the history of value. Just make cur_value to enter + // the right control flow context. 
+ real_value = constant_op.constant( + tensor_util.constant_value(cur_value), dtype: cur_value.dtype); + break; } else { From 57f77ab233dd74b64c1549bfe71bbd6f694e133a Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 3 Nov 2019 07:23:12 -0600 Subject: [PATCH 36/41] data_flow_ops.stack_pop_v2 --- .../Operations/gen_data_flow_ops.cs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs index fcb1000f..65b86f04 100644 --- a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs @@ -285,5 +285,16 @@ namespace Tensorflow return _op.output; } + + public static Tensor stack_pop_v2(Tensor handle, TF_DataType elem_type, string name = null) + { + var _op = _op_def_lib._apply_op_helper("StackPopV2", name, new + { + handle, + elem_type + }); + + return _op.output; + } } } From 3377f12b352edadfba23d680e021ca65e0b60a13 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 3 Nov 2019 07:23:32 -0600 Subject: [PATCH 37/41] control_flow_ops.control_trigger --- .../Operations/gen_control_flow_ops.cs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs index 8f9c8120..0a2d82d7 100644 --- a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs @@ -22,6 +22,15 @@ namespace Tensorflow { public static OpDefLibrary _op_def_lib = new OpDefLibrary(); + public static Operation control_trigger(string name = null) + { + var _op = _op_def_lib._apply_op_helper("ControlTrigger", name, new + { + }); + + return _op; + } + /// /// Creates or finds a child frame, and makes `data` available to the child frame. 
/// From 59cbca5c1703ed5fe607ffa11cb8e54edbe9fdbb Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 3 Nov 2019 10:41:10 -0600 Subject: [PATCH 38/41] ControlFlowState.PostProcessing --- .../ControlFlows/ControlFlowState.cs | 30 ++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs index 1d296774..9351cab4 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs @@ -290,7 +290,35 @@ namespace Tensorflow.Operations.ControlFlows public void PostProcessing() { - throw new NotImplementedException("PostProcessing"); + foreach(var grad_state in _map.Values) + { + foreach(var b_merge in grad_state.switch_map.Values) + { + if(b_merge.op.inputs[0] == b_merge.op.inputs[1]) + { + Tensor next_grad_val = null; + // The value of this loop variable at iteration i+1 doesn't + // depend on its value at iteration i. So use zeros as the + // gradients for all iterations > 0. + var dtype = b_merge.op.inputs[0].dtype; + var shape = b_merge.op.inputs[0].TensorShape; + if (shape.is_fully_defined()) + { + grad_state.grad_context.Enter(); + // Create a zeros and use it for iterations > 0. 
+ var grad_val = constant_op.constant(0, dtype: dtype, shape: shape); + next_grad_val = control_flow_ops._NextIteration(grad_val); + grad_state.grad_context.Exit(); + } + else + { + throw new NotImplementedException("PostProcessing shape is not fully defined."); + } + + b_merge.op._update_input(1, next_grad_val); + } + } + } } } } From fcd2cd6573ee1608c092ed9092769c67f0bb19b1 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 3 Nov 2019 10:41:30 -0600 Subject: [PATCH 39/41] nn_ops.in_top_kv2 --- src/TensorFlowNET.Core/APIs/tf.nn.cs | 2 +- .../Operations/NnOps/gen_nn_ops.cs | 22 ++++++++++++++++++- src/TensorFlowNET.Core/Operations/nn_ops.cs | 8 +++++++ 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs index 5b5786d1..64d47acd 100644 --- a/src/TensorFlowNET.Core/APIs/tf.nn.cs +++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs @@ -134,7 +134,7 @@ namespace Tensorflow => nn_ops.max_pool(value, ksize, strides, padding, data_format: data_format, name: name); public Tensor in_top_k(Tensor predictions, Tensor targets, int k, string name = "InTopK") - => gen_ops.in_top_k(predictions, targets, k, name); + => nn_ops.in_top_k(predictions, targets, k, name); public Tensor[] top_k(Tensor input, int k = 1, bool sorted = true, string name = null) => gen_nn_ops.top_kv2(input, k: k, sorted: sorted, name: name); diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs index f3a63d68..fbc68dbf 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs @@ -244,7 +244,27 @@ namespace Tensorflow.Operations logits }); - return _op.outputs[0]; + return _op.output; + } + + /// + /// Says whether the targets are in the top `K` predictions. + /// + /// + /// + /// + /// + /// A `Tensor` of type `bool`. 
+ public static Tensor in_top_kv2(Tensor predictions, Tensor targets, int k, string name = null) + { + var _op = _op_def_lib._apply_op_helper("InTopKV2", name: name, args: new + { + predictions, + targets, + k + }); + + return _op.output; } public static Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null) diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs index 7ae1f3a9..124fd72b 100644 --- a/src/TensorFlowNET.Core/Operations/nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs @@ -111,6 +111,14 @@ namespace Tensorflow return noise_shape; } + public static Tensor in_top_k(Tensor predictions, Tensor targets, int k, string name = null) + { + return tf_with(ops.name_scope(name, "in_top_k"), delegate + { + return gen_nn_ops.in_top_kv2(predictions, targets, k, name: name); + }); + } + public static Tensor log_softmax(Tensor logits, int axis = -1, string name = null) { return _softmax(logits, gen_nn_ops.log_softmax, axis, name); From ad250d0c796b12e496ee3a18998d46da48284547 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 3 Nov 2019 10:42:17 -0600 Subject: [PATCH 40/41] implement _SwitchGrad when merge_grad is not null. 
--- .../Gradients/control_flow_grad.cs | 19 ++- .../Gradients/gradients_util.cs | 111 +++++++++--------- .../Operations/Operation.cs | 5 +- 3 files changed, 67 insertions(+), 68 deletions(-) diff --git a/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs b/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs index acaa6de3..3ae890fb 100644 --- a/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs @@ -48,7 +48,12 @@ namespace Tensorflow.Gradients { var merge_grad = grad_ctxt.grad_state.switch_map.get(op); if (merge_grad != null) - throw new NotImplementedException("_SwitchGrad merge_grad != null"); + { + if (grads[1] != null) + control_flow_ops._AddNextAndBackEdge(merge_grad, grads[1], + enforce_shape_invariant: false); + return new Tensor[] { null, null }; + } else if (grads[0] != null) { merge_grad = merge(new[] { grads[0], grads[0] }, name: "b_switch")[0]; @@ -233,17 +238,9 @@ namespace Tensorflow.Gradients return grads; if (op.get_attr("is_constant")) { - throw new NotImplementedException("_EnterGrad is_constant"); - // Add a gradient accumulator for each loop invariant. - // if isinstance(grad, ops.Tensor) : - // result = grad_ctxt.AddBackpropAccumulator(op, grad) - // elif isinstance(grad, ops.IndexedSlices) : - // result = grad_ctxt.AddBackpropIndexedSlicesAccumulator(op, grad) - // else: - // # TODO(yuanbyu, lukasr): Add support for SparseTensor. - // raise TypeError("Type %s not supported" % type(grad)) + // Add a gradient accumulator for each loop invariant. 
+ result = grad_ctxt.AddBackpropAccumulator(op, grad); } - else { result = control_flow_ops.exit(grad); diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs index 8170ea6f..55b771ed 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs @@ -123,10 +123,7 @@ namespace Tensorflow { // generate gradient subgraph for op. var op = queue.Dequeue(); - if(op.name == "rnn/while/Exit") - { - } _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops); { if (loop_state != null) @@ -136,6 +133,7 @@ namespace Tensorflow loop_state.ExitGradWhileContext(op, before: true); Tensor[] in_grads = null; + Func grad_fn = null; var is_partitioned_call = _IsPartitionedCall(op); var is_func_call = false; var has_out_grads = out_grads.Exists(x => x != null); @@ -143,8 +141,6 @@ namespace Tensorflow { // A grad_fn must be defined, either as a function or as None // for ops that do not have gradients. - - Func grad_fn = null; try { grad_fn = ops.get_gradient_function(op); @@ -167,61 +163,57 @@ namespace Tensorflow throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})"); } } + } - if (loop_state != null) - loop_state.EnterGradWhileContext(op, before: false); + if (loop_state != null) + loop_state.EnterGradWhileContext(op, before: false); - if ((is_func_call || grad_fn != null) && has_out_grads) + if ((is_func_call || grad_fn != null) && has_out_grads) + { + // NOTE: If _AggregatedGrads didn't compute a value for the i'th + // output, it means that the cost does not depend on output[i], + // therefore dC/doutput[i] is 0. + foreach (var (i, out_grad) in enumerate(out_grads)) { - // NOTE: If _AggregatedGrads didn't compute a value for the i'th - // output, it means that the cost does not depend on output[i], - // therefore dC/doutput[i] is 0. 
- foreach (var (i, out_grad) in enumerate(out_grads)) - { - if (out_grad == null && - (grad_fn == null || _IsTrainable(op.outputs[i]))) - { - // Only trainable outputs or outputs for a function call that - // will use SymbolicGradient get a zero gradient. Gradient - // functions should ignore the gradient for other outputs. - if (loop_state != null) - out_grads[i] = new List { loop_state.ZerosLike(op, i) }; - else - out_grads[i] = new List { control_flow_ops.ZerosLikeOutsideLoop(op, i) }; - } - } - - tf_with(ops.name_scope(op.name + "_grad"), scope1 => + if (out_grad == null && + (grad_fn == null || _IsTrainable(op.outputs[i]))) { - if (grad_fn != null) - { - in_grads = _MaybeCompile(grad_scope, - op, - out_grads.Where(x => x != null).Select(x => x[0]).ToArray(), - null, - grad_fn); - } + // Only trainable outputs or outputs for a function call that + // will use SymbolicGradient get a zero gradient. Gradient + // functions should ignore the gradient for other outputs. + if (loop_state != null) + out_grads[i] = new List { loop_state.ZerosLike(op, i) }; else - { - throw new NotImplementedException("lambda: _SymGrad(op, out_grads)"); - } - _VerifyGeneratedGradients(in_grads, op); - if (gate_gradients && in_grads.Count(x => x != null) > 1) - { - ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true); - in_grads = control_flow_ops.tuple(in_grads); - } - }); + out_grads[i] = new List { control_flow_ops.ZerosLikeOutsideLoop(op, i) }; + } } - else + + tf_with(ops.name_scope(op.name + "_grad"), scope1 => { - // If no grad_fn is defined or none of out_grads is available, - // just propagate a list of None backwards. 
- in_grads = new Tensor[_NonEagerInputs(op, xs).Count()]; - } + if (grad_fn != null) + { + in_grads = _MaybeCompile(grad_scope, + op, + out_grads.Where(x => x != null).Select(x => x[0]).ToArray(), + null, + grad_fn); + } + else + { + throw new NotImplementedException("lambda: _SymGrad(op, out_grads)"); + } + _VerifyGeneratedGradients(in_grads, op); + if (gate_gradients && in_grads.Count(x => x != null) > 1) + { + ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true); + in_grads = control_flow_ops.tuple(in_grads); + } + }); } else { + // If no grad_fn is defined or none of out_grads is available, + // just propagate a list of None backwards. in_grads = new Tensor[_NonEagerInputs(op, xs).Count()]; } @@ -370,7 +362,16 @@ namespace Tensorflow grads[op.name] = op_grads; } var t_grads = op_grads[t.value_index]; - t_grads.Add(grad); + if (t_grads.Count == 0) + t_grads.Add(grad); + else + op_grads[t.value_index][0] = grad; + + /*if (control_flow_util.IsLoopSwitch(op) && + t_grads[0] == null) + op_grads[t.value_index] = new List { grad }; + else + t_grads.Add(grad);*/ } private static IEnumerable _NonEagerInputs(Operation op, Tensor[] xs) @@ -379,7 +380,8 @@ namespace Tensorflow yield return op.inputs[i]; } - private static List> _AggregatedGrads(Dictionary>> grads, Operation op, string gradient_uid, object loop_state, int aggregation_method = 0) + private static List> _AggregatedGrads(Dictionary>> grads, Operation op, string gradient_uid, + ControlFlowState loop_state, int aggregation_method = 0) { var out_grads = _GetGrads(grads, op); @@ -387,7 +389,10 @@ namespace Tensorflow { if (loop_state != null) { - + if (out_grads.Count > 1 && + out_grads[1].Count > 0 && + control_flow_util.IsLoopSwitch(op)) + continue; } // Aggregate multiple gradients, and convert [] to None. 
diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index e8eb216f..0f9ed2eb 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -182,10 +182,7 @@ namespace Tensorflow // This will be set by self.inputs. if (op_def == null) op_def = g.GetOpDef(node_def.Op); - if(node_def.Name == "gradients/rnn/while/basic_rnn_cell/Tanh_grad/TanhGrad/f_acc") - { - - } + var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); _is_stateful = op_def.IsStateful; From 3090e45837da6a72baacd0859e0b369889bdc6ed Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 3 Nov 2019 20:52:59 -0600 Subject: [PATCH 41/41] v0.12 released --- README.md | 7 +++++++ src/TensorFlowNET.Core/Gradients/gradients_util.cs | 11 +++-------- src/TensorFlowNET.Core/TensorFlowNET.Core.csproj | 3 ++- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 8744ba72..87e8042d 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,13 @@ In comparison to other projects, like for instance TensorFlowSharp which only pr ### How to use +| TensorFlow | tf 1.13 | tf 1.14 | tf 1.15 | tf 2.0 | +| ----------- | ------- | ------- | ------- | ------ | +| tf.net 0.12 | | x | | | +| tf.net 0.11 | x | x | | | +| tf.net 0.10 | x | x | | | +| tf.net 0.9 | x | | | | + Install TF.NET and TensorFlow binary through NuGet. 
```sh ### install tensorflow C# binding diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs index 55b771ed..c9322105 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs @@ -362,16 +362,11 @@ namespace Tensorflow grads[op.name] = op_grads; } var t_grads = op_grads[t.value_index]; - if (t_grads.Count == 0) - t_grads.Add(grad); - else + if (t_grads.Count > 0 && + control_flow_util.IsLoopSwitch(op)) op_grads[t.value_index][0] = grad; - - /*if (control_flow_util.IsLoopSwitch(op) && - t_grads[0] == null) - op_grads[t.value_index] = new List { grad }; else - t_grads.Add(grad);*/ + t_grads.Add(grad); } private static IEnumerable _NonEagerInputs(Operation op, Tensor[] xs) diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index fbad178e..de6e92cf 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -21,7 +21,8 @@ https://tensorflownet.readthedocs.io 0.12.0.0 Changes since v0.11.0: 1: Add ICanBeFlattened for nest.flatten2. -2: +2: Complete the WhileContext. +3: Add tf.nn.rnn_cell.BasicRNNCell and tf.nn.dynamic_rnn. 7.3 0.12.0.0 LICENSE