From 098ffdfaf536c97fe6ac48f407382a52747efb66 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 13 Apr 2019 17:12:43 -0500 Subject: [PATCH 1/4] CondTestCases.testCondTrue passed --- test/TensorFlowNET.UnitTest/PythonTest.cs | 1 + .../control_flow_ops_test/CondTestCases.cs | 15 ++++++++++----- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/test/TensorFlowNET.UnitTest/PythonTest.cs b/test/TensorFlowNET.UnitTest/PythonTest.cs index 5d9bb374..97d49932 100644 --- a/test/TensorFlowNET.UnitTest/PythonTest.cs +++ b/test/TensorFlowNET.UnitTest/PythonTest.cs @@ -136,6 +136,7 @@ namespace TensorFlowNET.UnitTest /// A Tensor or a nested list/tuple of Tensors. /// /// tensors numpy values. + [Obsolete("Why do we need this function? we already have Tensor.eval().")] public object evaluate(params Tensor[] tensors) { // if context.executing_eagerly(): diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs b/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs index 85908baf..de391679 100644 --- a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs +++ b/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs @@ -13,11 +13,16 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test [TestMethod] public void testCondTrue() { - var x = tf.constant(2); - var y = tf.constant(5); - var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)), - () => tf.add(y, tf.constant(23))); - self.assertEquals(self.evaluate(z), 34); + with(tf.Session(), sess => + { + var x = tf.constant(2); + var y = tf.constant(5); + var z = control_flow_ops.cond(tf.less(x, y), + () => tf.multiply(x, tf.constant(17)), + () => tf.add(y, tf.constant(23))); + int result = z.eval(sess); + assertEquals(result, 34); + }); } [Ignore("Todo")] From 5627443b7efeedc914959f4ba620396d5e6ec90d Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 13 Apr 2019 18:14:51 -0500 Subject: [PATCH 2/4] change tf.math.multiply, math.add to generic. --- src/TensorFlowNET.Core/APIs/tf.math.cs | 4 +- .../Operations/Operation.cs | 2 +- .../Operations/gen_math_ops.cs | 4 +- .../control_flow_ops_test/CondTestCases.cs | 52 +++++++++++++++---- 4 files changed, 46 insertions(+), 16 deletions(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index bbf240e3..fbf3dd00 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -27,7 +27,7 @@ namespace Tensorflow public static Tensor asin(Tensor x, string name = null) => gen_math_ops.asin(x, name); - public static Tensor add(Tensor a, Tensor b) + public static Tensor add(Tx a, Ty b) => gen_math_ops.add(a, b); /// @@ -251,7 +251,7 @@ namespace Tensorflow public static Tensor minimum(T1 x, T2 y, string name = null) => gen_math_ops.minimum(x, y, name: name); - public static Tensor multiply(Tensor x, Tensor y) + public static Tensor multiply(Tx x, Ty y) => gen_math_ops.mul(x, y); public static Tensor negative(Tensor x, string name = null) diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 9f4280d9..c81ab08c 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -287,7 +287,7 @@ namespace Tensorflow // Reset cached inputs. _inputs = null;// new InputList(new Tensor[] { tensor }); // is this right? original code: self._inputs_val=None // TODO: implement below code dependencies - //c_api.UpdateEdge(_graph._c_graph, output, input); + // c_api.TF_UpdateEdge(graph, output, input, status); } private void _assert_same_graph(Tensor tensor) diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 5e58df45..56477442 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -80,7 +80,7 @@ namespace Tensorflow return _op.outputs[0]; } - public static Tensor add(Tensor x, Tensor y, string name = null) + public static Tensor add(Tx x, Ty y, string name = null) { var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y }); @@ -300,7 +300,7 @@ namespace Tensorflow return _op.outputs[0]; } - public static Tensor mul(Tensor x, Tensor y, string name = null) + public static Tensor mul(Tx x, Ty y, string name = null) { var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y }); diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs b/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs index de391679..e35923e5 100644 --- a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs +++ b/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs @@ -1,4 +1,5 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; using Tensorflow; namespace TensorFlowNET.UnitTest.control_flow_ops_test @@ -18,25 +19,54 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test var x = tf.constant(2); var y = tf.constant(5); var z = control_flow_ops.cond(tf.less(x, y), - () => tf.multiply(x, tf.constant(17)), - () => tf.add(y, tf.constant(23))); + () => tf.multiply(x, 17), + () => tf.add(y, 23)); int result = z.eval(sess); assertEquals(result, 34); }); } - [Ignore("Todo")] [TestMethod] public void testCondFalse() { - // def testCondFalse(self): - // x = constant_op.constant(2) - // y = constant_op.constant(1) - // z = control_flow_ops.cond( - // math_ops.less( - // x, - // y), lambda: math_ops.multiply(x, 17), lambda: math_ops.add(y, 23)) - // self.assertEquals(self.evaluate(z), 24) + /* python + * import tensorflow as tf + from tensorflow.python.framework import ops + + def if_true(): + return tf.math.multiply(x, 17) + def if_false(): + return tf.math.add(y, 23) + + with tf.Session() as sess: + x = tf.constant(2) + y = tf.constant(1) + pred = tf.math.less(x,y) + z = tf.cond(pred, if_true, if_false) + result = z.eval() + + print(result == 24) */ + + with(tf.Session(), sess => + { + var x = tf.constant(2); + var y = tf.constant(1); + var pred = tf.less(x, y); + + Func if_true = delegate + { + return tf.multiply(x, 17); + }; + + Func if_false = delegate + { + return tf.add(y, 23); + }; + + var z = control_flow_ops.cond(pred, if_true, if_false); + int result = z.eval(sess); + assertEquals(result, 24); + }); } [Ignore("Todo")] From f4067f28f7699a6153f79cf5b0120b3e06212cff Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 13 Apr 2019 22:31:35 -0500 Subject: [PATCH 3/4] tf.train.import_meta_graph can import CondContext. --- .../Framework/meta_graph.py.cs | 21 +- src/TensorFlowNET.Core/Graphs/Graph.cs | 5 + .../Operations/ControlFlows/CondContext.cs | 48 +- .../ControlFlows/ControlFlowContext.cs | 34 +- .../Operations/Operation.cs | 2 +- .../Operations/control_flow_ops.py.cs | 8 +- .../Protobuf/ControlFlow.cs | 1172 +++++++++++++++++ src/TensorFlowNET.Core/Protobuf/IProtoBuf.cs | 8 +- src/TensorFlowNET.Core/Protobuf/README.md | 7 +- .../Variables/RefVariable.cs | 4 +- src/TensorFlowNET.Core/ops.py.cs | 2 +- .../control_flow_ops_test/CondTestCases.cs | 21 +- 12 files changed, 1300 insertions(+), 32 deletions(-) create mode 100644 src/TensorFlowNET.Core/Protobuf/ControlFlow.cs diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs b/src/TensorFlowNET.Core/Framework/meta_graph.py.cs index c7af7051..ceebdc6e 100644 --- a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs +++ b/src/TensorFlowNET.Core/Framework/meta_graph.py.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; using System.Text; +using Tensorflow.Operations; using static Tensorflow.CollectionDef; using static Tensorflow.MetaGraphDef.Types; @@ -95,15 +96,29 @@ namespace Tensorflow } else { - throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); + foreach(var value in col.Value.BytesList.Value) + { + switch (col.Key) + { + case "cond_context": + var proto = CondContextDef.Parser.ParseFrom(value); + var condContext = new CondContext().from_proto(proto, import_scope); + graph.add_to_collection(col.Key, condContext); + break; + default: + throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); + } + } } break; + default: + throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); } } - var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, - scope: scope_to_prepend_to_names) as List; + var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, + scope: scope_to_prepend_to_names); var var_list = new Dictionary(); variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v); diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 081893c2..f1a33371 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -412,6 +412,11 @@ namespace Tensorflow return _collections.ContainsKey(name) ? _collections[name] : null; } + public List get_collection(string name, string scope = null) + { + return _collections.ContainsKey(name) ? _collections[name] as List : new List(); + } + public object get_collection_ref(string name) { if (!_collections.ContainsKey(name)) diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs index 1bfa81f2..c00e2c0e 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs @@ -8,7 +8,7 @@ namespace Tensorflow.Operations /// /// The context for the conditional construct. /// - public class CondContext : ControlFlowContext + public class CondContext : ControlFlowContext, IProtoBuf { @@ -35,16 +35,20 @@ namespace Tensorflow.Operations /// Name of the `CondContext` python object. /// /// - public CondContext(Tensor pred, - Tensor pivot, - int branch, + public CondContext(Tensor pred = null, + Tensor pivot = null, + int? branch = null, string name = "cond_text", - object context_def = null, + CondContextDef context_def = null, string import_scope = null) { + if (pred == null && context_def == null) return; + _name = ops.get_default_graph().unique_name(name); - if (context_def != null) - throw new NotImplementedException("CondContext context_def is not null"); + if (context_def != null) + { + _init_from_proto(context_def, import_scope: import_scope); + } else { // Initializes the default fields. @@ -61,6 +65,18 @@ namespace Tensorflow.Operations } } + private void _init_from_proto(CondContextDef context_def, string import_scope = null) + { + var g = ops.get_default_graph(); + _name = ops.prepend_name_scope(context_def.ContextName, import_scope); + var p1 = ops.prepend_name_scope(context_def.PredName, import_scope); + _pred = g.as_graph_element(p1) as Tensor; + var p2 = ops.prepend_name_scope(context_def.PivotName, import_scope); + _pivot = g.as_graph_element(p2) as Tensor; + _branch = context_def.Branch; + __init__(values_def: context_def.ValuesDef, import_scope: import_scope); + } + /// /// Add `val` to the current context and its outer context recursively. /// @@ -230,6 +246,22 @@ namespace Tensorflow.Operations public override void AddInnerOp(Operation resultOp) { throw new NotImplementedException(); - } + } + + public CondContextDef to_proto(string export_scope) + { + throw new NotImplementedException(); + } + + public CondContext from_proto(CondContextDef proto, string import_scope) + { + var ret = new CondContext(context_def: proto, import_scope: import_scope); + + ret.Enter(); + foreach (var nested_def in proto.NestedContexts) + throw new NotImplementedException(""); + ret.Exit(); + return ret; + } } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs index fef79c8d..86452e50 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs @@ -32,6 +32,8 @@ namespace Tensorflow.Operations protected Stack _context_stack; protected IControlFlowContext _outer_context; + protected Dictionary _external_values; + public ControlFlowContext() { _context_stack = new Stack(); @@ -40,15 +42,43 @@ namespace Tensorflow.Operations public string name { get => _name; } protected string _name; - public void __init__() + public void __init__(ValuesDef values_def = null, string import_scope = null) { - + _outer_context = ops.get_default_graph()._get_control_flow_context(); + if (values_def != null) + _init_values_from_proto(values_def, import_scope: import_scope); } public void __enter__() { } + /// + /// Initializes values and external_values from `ValuesDef` protocol buffer. + /// + /// + /// + protected void _init_values_from_proto(ValuesDef values_def, string import_scope = null) + { + _external_values = new Dictionary(); + foreach (var value in values_def.Values) + _values.Add(value); + var g = ops.get_default_graph(); + foreach(var value in values_def.ExternalValues) + { + var k = ops.prepend_name_scope(value.Key, import_scope); + var v = value.Value; + _external_values[k] = g.as_graph_element(ops.prepend_name_scope(v, import_scope)); + } + + var op_names = _values.Where(x => !_external_values.ContainsKey(x)) + .Select(x => x.Split(':')[0]) + .ToArray(); + + foreach (var op in op_names) + (g.as_graph_element(op) as Operation)._set_control_flow_context(this); + } + public void __exit__() { } diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index c81ab08c..7357f49f 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -287,7 +287,7 @@ namespace Tensorflow // Reset cached inputs. _inputs = null;// new InputList(new Tensor[] { tensor }); // is this right? original code: self._inputs_val=None // TODO: implement below code dependencies - // c_api.TF_UpdateEdge(graph, output, input, status); + c_api.TF_UpdateEdge(graph, output, input, status); } private void _assert_same_graph(Tensor tensor) diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs index aebcfaef..fad7a1e1 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs @@ -330,7 +330,7 @@ namespace Tensorflow tensor.op.graph.prevent_fetching(tensor.op); // Build the graph for the true branch in a new context. - var context_t = new CondContext(pred, pivot_1, branch: 1); + var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1); ITensorOrOperation orig_res_t; Tensor res_t; try @@ -343,7 +343,7 @@ namespace Tensorflow context_t.Exit(); } // Build the graph for the false branch in a new context. - var context_f = new CondContext(pred, pivot_2, branch: 0); + var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0); ITensorOrOperation orig_res_f; Tensor res_f; try @@ -411,13 +411,13 @@ namespace Tensorflow tensor.op.graph.prevent_fetching(tensor.op); // Build the graph for the true branch in a new context. - var context_t = new CondContext(pred, pivot_1, branch: 1); + var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1); context_t.Enter(); var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); context_t.Exit(); // Build the graph for the false branch in a new context. - var context_f = new CondContext(pred, pivot_2, branch: 0); + var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0); context_f.Enter(); var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); context_f.Exit(); diff --git a/src/TensorFlowNET.Core/Protobuf/ControlFlow.cs b/src/TensorFlowNET.Core/Protobuf/ControlFlow.cs new file mode 100644 index 00000000..108bcc45 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/ControlFlow.cs @@ -0,0 +1,1172 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/protobuf/control_flow.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/protobuf/control_flow.proto + public static partial class ControlFlowReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/protobuf/control_flow.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static ControlFlowReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cit0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvY29udHJvbF9mbG93LnByb3Rv", + "Egp0ZW5zb3JmbG93IpYBCglWYWx1ZXNEZWYSDgoGdmFsdWVzGAEgAygJEkIK", + "D2V4dGVybmFsX3ZhbHVlcxgCIAMoCzIpLnRlbnNvcmZsb3cuVmFsdWVzRGVm", + "LkV4dGVybmFsVmFsdWVzRW50cnkaNQoTRXh0ZXJuYWxWYWx1ZXNFbnRyeRIL", + "CgNrZXkYASABKAkSDQoFdmFsdWUYAiABKAk6AjgBIoMBChVDb250cm9sRmxv", + "d0NvbnRleHREZWYSLwoJY29uZF9jdHh0GAEgASgLMhoudGVuc29yZmxvdy5D", + "b25kQ29udGV4dERlZkgAEjEKCndoaWxlX2N0eHQYAiABKAsyGy50ZW5zb3Jm", + "bG93LldoaWxlQ29udGV4dERlZkgAQgYKBGN0eHQixAEKDkNvbmRDb250ZXh0", + "RGVmEhQKDGNvbnRleHRfbmFtZRgBIAEoCRIRCglwcmVkX25hbWUYAiABKAkS", + "EgoKcGl2b3RfbmFtZRgDIAEoCRIOCgZicmFuY2gYBCABKAUSKQoKdmFsdWVz", + "X2RlZhgFIAEoCzIVLnRlbnNvcmZsb3cuVmFsdWVzRGVmEjoKD25lc3RlZF9j", + "b250ZXh0cxgGIAMoCzIhLnRlbnNvcmZsb3cuQ29udHJvbEZsb3dDb250ZXh0", + "RGVmIvUCCg9XaGlsZUNvbnRleHREZWYSFAoMY29udGV4dF9uYW1lGAEgASgJ", + "EhsKE3BhcmFsbGVsX2l0ZXJhdGlvbnMYAiABKAUSEQoJYmFja19wcm9wGAMg", + "ASgIEhMKC3N3YXBfbWVtb3J5GAQgASgIEhIKCnBpdm90X25hbWUYBSABKAkS", + "GwoTcGl2b3RfZm9yX3ByZWRfbmFtZRgGIAEoCRIbChNwaXZvdF9mb3JfYm9k", + "eV9uYW1lGAcgASgJEhcKD2xvb3BfZXhpdF9uYW1lcxgIIAMoCRIYChBsb29w", + "X2VudGVyX25hbWVzGAogAygJEikKCnZhbHVlc19kZWYYCSABKAsyFS50ZW5z", + "b3JmbG93LlZhbHVlc0RlZhIfChdtYXhpbXVtX2l0ZXJhdGlvbnNfbmFtZRgL", + "IAEoCRI6Cg9uZXN0ZWRfY29udGV4dHMYDCADKAsyIS50ZW5zb3JmbG93LkNv", + "bnRyb2xGbG93Q29udGV4dERlZkJwChhvcmcudGVuc29yZmxvdy5mcmFtZXdv", + "cmtCEUNvbnRyb2xGbG93UHJvdG9zUAFaPGdpdGh1Yi5jb20vdGVuc29yZmxv", + "dy90ZW5zb3JmbG93L3RlbnNvcmZsb3cvZ28vY29yZS9wcm90b2J1ZvgBAWIG", + "cHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ValuesDef), global::Tensorflow.ValuesDef.Parser, new[]{ "Values", "ExternalValues" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, }), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ControlFlowContextDef), global::Tensorflow.ControlFlowContextDef.Parser, new[]{ "CondCtxt", "WhileCtxt" }, new[]{ "Ctxt" }, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CondContextDef), global::Tensorflow.CondContextDef.Parser, new[]{ "ContextName", "PredName", "PivotName", "Branch", "ValuesDef", "NestedContexts" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.WhileContextDef), global::Tensorflow.WhileContextDef.Parser, new[]{ "ContextName", "ParallelIterations", "BackProp", "SwapMemory", "PivotName", "PivotForPredName", "PivotForBodyName", "LoopExitNames", "LoopEnterNames", "ValuesDef", "MaximumIterationsName", "NestedContexts" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Protocol buffer representing the values in ControlFlowContext. + /// + public sealed partial class ValuesDef : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ValuesDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ValuesDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ValuesDef(ValuesDef other) : this() { + values_ = other.values_.Clone(); + externalValues_ = other.externalValues_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ValuesDef Clone() { + return new ValuesDef(this); + } + + /// Field number for the "values" field. + public const int ValuesFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_values_codec + = pb::FieldCodec.ForString(10); + private readonly pbc::RepeatedField values_ = new pbc::RepeatedField(); + /// + /// Value names that have been seen in this context. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Values { + get { return values_; } + } + + /// Field number for the "external_values" field. + public const int ExternalValuesFieldNumber = 2; + private static readonly pbc::MapField.Codec _map_externalValues_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10), pb::FieldCodec.ForString(18), 18); + private readonly pbc::MapField externalValues_ = new pbc::MapField(); + /// + /// Value names referenced by but external to this context. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::MapField ExternalValues { + get { return externalValues_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ValuesDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ValuesDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!values_.Equals(other.values_)) return false; + if (!ExternalValues.Equals(other.ExternalValues)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= values_.GetHashCode(); + hash ^= ExternalValues.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + values_.WriteTo(output, _repeated_values_codec); + externalValues_.WriteTo(output, _map_externalValues_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += values_.CalculateSize(_repeated_values_codec); + size += externalValues_.CalculateSize(_map_externalValues_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ValuesDef other) { + if (other == null) { + return; + } + values_.Add(other.values_); + externalValues_.Add(other.externalValues_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + values_.AddEntriesFrom(input, _repeated_values_codec); + break; + } + case 18: { + externalValues_.AddEntriesFrom(input, _map_externalValues_codec); + break; + } + } + } + } + + } + + /// + /// Container for any kind of control flow context. Any other control flow + /// contexts that are added below should also be added here. + /// + public sealed partial class ControlFlowContextDef : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ControlFlowContextDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ControlFlowContextDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ControlFlowContextDef(ControlFlowContextDef other) : this() { + switch (other.CtxtCase) { + case CtxtOneofCase.CondCtxt: + CondCtxt = other.CondCtxt.Clone(); + break; + case CtxtOneofCase.WhileCtxt: + WhileCtxt = other.WhileCtxt.Clone(); + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ControlFlowContextDef Clone() { + return new ControlFlowContextDef(this); + } + + /// Field number for the "cond_ctxt" field. + public const int CondCtxtFieldNumber = 1; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.CondContextDef CondCtxt { + get { return ctxtCase_ == CtxtOneofCase.CondCtxt ? (global::Tensorflow.CondContextDef) ctxt_ : null; } + set { + ctxt_ = value; + ctxtCase_ = value == null ? CtxtOneofCase.None : CtxtOneofCase.CondCtxt; + } + } + + /// Field number for the "while_ctxt" field. + public const int WhileCtxtFieldNumber = 2; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.WhileContextDef WhileCtxt { + get { return ctxtCase_ == CtxtOneofCase.WhileCtxt ? (global::Tensorflow.WhileContextDef) ctxt_ : null; } + set { + ctxt_ = value; + ctxtCase_ = value == null ? CtxtOneofCase.None : CtxtOneofCase.WhileCtxt; + } + } + + private object ctxt_; + /// Enum of possible cases for the "ctxt" oneof. + public enum CtxtOneofCase { + None = 0, + CondCtxt = 1, + WhileCtxt = 2, + } + private CtxtOneofCase ctxtCase_ = CtxtOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CtxtOneofCase CtxtCase { + get { return ctxtCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void ClearCtxt() { + ctxtCase_ = CtxtOneofCase.None; + ctxt_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ControlFlowContextDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ControlFlowContextDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(CondCtxt, other.CondCtxt)) return false; + if (!object.Equals(WhileCtxt, other.WhileCtxt)) return false; + if (CtxtCase != other.CtxtCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (ctxtCase_ == CtxtOneofCase.CondCtxt) hash ^= CondCtxt.GetHashCode(); + if (ctxtCase_ == CtxtOneofCase.WhileCtxt) hash ^= WhileCtxt.GetHashCode(); + hash ^= (int) ctxtCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (ctxtCase_ == CtxtOneofCase.CondCtxt) { + output.WriteRawTag(10); + output.WriteMessage(CondCtxt); + } + if (ctxtCase_ == CtxtOneofCase.WhileCtxt) { + output.WriteRawTag(18); + output.WriteMessage(WhileCtxt); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (ctxtCase_ == CtxtOneofCase.CondCtxt) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(CondCtxt); + } + if (ctxtCase_ == CtxtOneofCase.WhileCtxt) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(WhileCtxt); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ControlFlowContextDef other) { + if (other == null) { + return; + } + switch (other.CtxtCase) { + case CtxtOneofCase.CondCtxt: + if (CondCtxt == null) { + CondCtxt = new global::Tensorflow.CondContextDef(); + } + CondCtxt.MergeFrom(other.CondCtxt); + break; + case CtxtOneofCase.WhileCtxt: + if (WhileCtxt == null) { + WhileCtxt = new global::Tensorflow.WhileContextDef(); + } + WhileCtxt.MergeFrom(other.WhileCtxt); + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + global::Tensorflow.CondContextDef subBuilder = new global::Tensorflow.CondContextDef(); + if (ctxtCase_ == CtxtOneofCase.CondCtxt) { + subBuilder.MergeFrom(CondCtxt); + } + input.ReadMessage(subBuilder); + CondCtxt = subBuilder; + break; + } + case 18: { + global::Tensorflow.WhileContextDef subBuilder = new global::Tensorflow.WhileContextDef(); + if (ctxtCase_ == CtxtOneofCase.WhileCtxt) { + subBuilder.MergeFrom(WhileCtxt); + } + input.ReadMessage(subBuilder); + WhileCtxt = subBuilder; + break; + } + } + } + } + + } + + /// + /// Protocol buffer representing a CondContext object. + /// + public sealed partial class CondContextDef : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CondContextDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CondContextDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CondContextDef(CondContextDef other) : this() { + contextName_ = other.contextName_; + predName_ = other.predName_; + pivotName_ = other.pivotName_; + branch_ = other.branch_; + valuesDef_ = other.valuesDef_ != null ? other.valuesDef_.Clone() : null; + nestedContexts_ = other.nestedContexts_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CondContextDef Clone() { + return new CondContextDef(this); + } + + /// Field number for the "context_name" field. + public const int ContextNameFieldNumber = 1; + private string contextName_ = ""; + /// + /// Name of the context. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string ContextName { + get { return contextName_; } + set { + contextName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "pred_name" field. + public const int PredNameFieldNumber = 2; + private string predName_ = ""; + /// + /// Name of the pred tensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string PredName { + get { return predName_; } + set { + predName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "pivot_name" field. + public const int PivotNameFieldNumber = 3; + private string pivotName_ = ""; + /// + /// Name of the pivot tensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string PivotName { + get { return pivotName_; } + set { + pivotName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "branch" field. + public const int BranchFieldNumber = 4; + private int branch_; + /// + /// Branch prediction. 0 or 1. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int Branch { + get { return branch_; } + set { + branch_ = value; + } + } + + /// Field number for the "values_def" field. + public const int ValuesDefFieldNumber = 5; + private global::Tensorflow.ValuesDef valuesDef_; + /// + /// Values and external values in control flow context. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.ValuesDef ValuesDef { + get { return valuesDef_; } + set { + valuesDef_ = value; + } + } + + /// Field number for the "nested_contexts" field. + public const int NestedContextsFieldNumber = 6; + private static readonly pb::FieldCodec _repeated_nestedContexts_codec + = pb::FieldCodec.ForMessage(50, global::Tensorflow.ControlFlowContextDef.Parser); + private readonly pbc::RepeatedField nestedContexts_ = new pbc::RepeatedField(); + /// + /// Contexts contained inside this context (e.g. nested conds). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField NestedContexts { + get { return nestedContexts_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as CondContextDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(CondContextDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (ContextName != other.ContextName) return false; + if (PredName != other.PredName) return false; + if (PivotName != other.PivotName) return false; + if (Branch != other.Branch) return false; + if (!object.Equals(ValuesDef, other.ValuesDef)) return false; + if(!nestedContexts_.Equals(other.nestedContexts_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (ContextName.Length != 0) hash ^= ContextName.GetHashCode(); + if (PredName.Length != 0) hash ^= PredName.GetHashCode(); + if (PivotName.Length != 0) hash ^= PivotName.GetHashCode(); + if (Branch != 0) hash ^= Branch.GetHashCode(); + if (valuesDef_ != null) hash ^= ValuesDef.GetHashCode(); + hash ^= nestedContexts_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (ContextName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(ContextName); + } + if (PredName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(PredName); + } + if (PivotName.Length != 0) { + output.WriteRawTag(26); + output.WriteString(PivotName); + } + if (Branch != 0) { + output.WriteRawTag(32); + output.WriteInt32(Branch); + } + if (valuesDef_ != null) { + output.WriteRawTag(42); + output.WriteMessage(ValuesDef); + } + nestedContexts_.WriteTo(output, _repeated_nestedContexts_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (ContextName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ContextName); + } + if (PredName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(PredName); + } + if (PivotName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(PivotName); + } + if (Branch != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Branch); + } + if (valuesDef_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ValuesDef); + } + size += nestedContexts_.CalculateSize(_repeated_nestedContexts_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(CondContextDef other) { + if (other == null) { + return; + } + if (other.ContextName.Length != 0) { + ContextName = other.ContextName; + } + if (other.PredName.Length != 0) { + PredName = other.PredName; + } + if (other.PivotName.Length != 0) { + PivotName = other.PivotName; + } + if (other.Branch != 0) { + Branch = other.Branch; + } + if (other.valuesDef_ != null) { + if (valuesDef_ == null) { + valuesDef_ = new global::Tensorflow.ValuesDef(); + } + ValuesDef.MergeFrom(other.ValuesDef); + } + nestedContexts_.Add(other.nestedContexts_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + ContextName = input.ReadString(); + break; + } + case 18: { + PredName = input.ReadString(); + break; + } + case 26: { + PivotName = input.ReadString(); + break; + } + case 32: { + Branch = input.ReadInt32(); + break; + } + case 42: { + if (valuesDef_ == null) { + valuesDef_ = new global::Tensorflow.ValuesDef(); + } + input.ReadMessage(valuesDef_); + break; + } + case 50: { + nestedContexts_.AddEntriesFrom(input, _repeated_nestedContexts_codec); + break; + } + } + } + } + + } + + /// + /// Protocol buffer representing a WhileContext object. + /// + public sealed partial class WhileContextDef : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new WhileContextDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public WhileContextDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public WhileContextDef(WhileContextDef other) : this() { + contextName_ = other.contextName_; + parallelIterations_ = other.parallelIterations_; + backProp_ = other.backProp_; + swapMemory_ = other.swapMemory_; + pivotName_ = other.pivotName_; + pivotForPredName_ = other.pivotForPredName_; + pivotForBodyName_ = other.pivotForBodyName_; + loopExitNames_ = other.loopExitNames_.Clone(); + loopEnterNames_ = other.loopEnterNames_.Clone(); + valuesDef_ = other.valuesDef_ != null ? other.valuesDef_.Clone() : null; + maximumIterationsName_ = other.maximumIterationsName_; + nestedContexts_ = other.nestedContexts_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public WhileContextDef Clone() { + return new WhileContextDef(this); + } + + /// Field number for the "context_name" field. + public const int ContextNameFieldNumber = 1; + private string contextName_ = ""; + /// + /// Name of the context. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string ContextName { + get { return contextName_; } + set { + contextName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "parallel_iterations" field. + public const int ParallelIterationsFieldNumber = 2; + private int parallelIterations_; + /// + /// The number of iterations allowed to run in parallel. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int ParallelIterations { + get { return parallelIterations_; } + set { + parallelIterations_ = value; + } + } + + /// Field number for the "back_prop" field. + public const int BackPropFieldNumber = 3; + private bool backProp_; + /// + /// Whether backprop is enabled for this while loop. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool BackProp { + get { return backProp_; } + set { + backProp_ = value; + } + } + + /// Field number for the "swap_memory" field. + public const int SwapMemoryFieldNumber = 4; + private bool swapMemory_; + /// + /// Whether GPU-CPU memory swap is enabled for this loop. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool SwapMemory { + get { return swapMemory_; } + set { + swapMemory_ = value; + } + } + + /// Field number for the "pivot_name" field. + public const int PivotNameFieldNumber = 5; + private string pivotName_ = ""; + /// + /// Name of the pivot tensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string PivotName { + get { return pivotName_; } + set { + pivotName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "pivot_for_pred_name" field. + public const int PivotForPredNameFieldNumber = 6; + private string pivotForPredName_ = ""; + /// + /// Name of the pivot_for_pred tensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string PivotForPredName { + get { return pivotForPredName_; } + set { + pivotForPredName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "pivot_for_body_name" field. + public const int PivotForBodyNameFieldNumber = 7; + private string pivotForBodyName_ = ""; + /// + /// Name of the pivot_for_body tensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string PivotForBodyName { + get { return pivotForBodyName_; } + set { + pivotForBodyName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "loop_exit_names" field. + public const int LoopExitNamesFieldNumber = 8; + private static readonly pb::FieldCodec _repeated_loopExitNames_codec + = pb::FieldCodec.ForString(66); + private readonly pbc::RepeatedField loopExitNames_ = new pbc::RepeatedField(); + /// + /// List of names for exit tensors. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField LoopExitNames { + get { return loopExitNames_; } + } + + /// Field number for the "loop_enter_names" field. + public const int LoopEnterNamesFieldNumber = 10; + private static readonly pb::FieldCodec _repeated_loopEnterNames_codec + = pb::FieldCodec.ForString(82); + private readonly pbc::RepeatedField loopEnterNames_ = new pbc::RepeatedField(); + /// + /// List of names for enter tensors. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField LoopEnterNames { + get { return loopEnterNames_; } + } + + /// Field number for the "values_def" field. + public const int ValuesDefFieldNumber = 9; + private global::Tensorflow.ValuesDef valuesDef_; + /// + /// Values and external values in control flow context. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.ValuesDef ValuesDef { + get { return valuesDef_; } + set { + valuesDef_ = value; + } + } + + /// Field number for the "maximum_iterations_name" field. + public const int MaximumIterationsNameFieldNumber = 11; + private string maximumIterationsName_ = ""; + /// + /// Optional name of the maximum_iterations tensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string MaximumIterationsName { + get { return maximumIterationsName_; } + set { + maximumIterationsName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "nested_contexts" field. + public const int NestedContextsFieldNumber = 12; + private static readonly pb::FieldCodec _repeated_nestedContexts_codec + = pb::FieldCodec.ForMessage(98, global::Tensorflow.ControlFlowContextDef.Parser); + private readonly pbc::RepeatedField nestedContexts_ = new pbc::RepeatedField(); + /// + /// Contexts contained inside this context (e.g. nested whiles). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField NestedContexts { + get { return nestedContexts_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as WhileContextDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(WhileContextDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (ContextName != other.ContextName) return false; + if (ParallelIterations != other.ParallelIterations) return false; + if (BackProp != other.BackProp) return false; + if (SwapMemory != other.SwapMemory) return false; + if (PivotName != other.PivotName) return false; + if (PivotForPredName != other.PivotForPredName) return false; + if (PivotForBodyName != other.PivotForBodyName) return false; + if(!loopExitNames_.Equals(other.loopExitNames_)) return false; + if(!loopEnterNames_.Equals(other.loopEnterNames_)) return false; + if (!object.Equals(ValuesDef, other.ValuesDef)) return false; + if (MaximumIterationsName != other.MaximumIterationsName) return false; + if(!nestedContexts_.Equals(other.nestedContexts_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (ContextName.Length != 0) hash ^= ContextName.GetHashCode(); + if (ParallelIterations != 0) hash ^= ParallelIterations.GetHashCode(); + if (BackProp != false) hash ^= BackProp.GetHashCode(); + if (SwapMemory != false) hash ^= SwapMemory.GetHashCode(); + if (PivotName.Length != 0) hash ^= PivotName.GetHashCode(); + if (PivotForPredName.Length != 0) hash ^= PivotForPredName.GetHashCode(); + if (PivotForBodyName.Length != 0) hash ^= PivotForBodyName.GetHashCode(); + hash ^= loopExitNames_.GetHashCode(); + hash ^= loopEnterNames_.GetHashCode(); + if (valuesDef_ != null) hash ^= ValuesDef.GetHashCode(); + if (MaximumIterationsName.Length != 0) hash ^= MaximumIterationsName.GetHashCode(); + hash ^= nestedContexts_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (ContextName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(ContextName); + } + if (ParallelIterations != 0) { + output.WriteRawTag(16); + output.WriteInt32(ParallelIterations); + } + if (BackProp != false) { + output.WriteRawTag(24); + output.WriteBool(BackProp); + } + if (SwapMemory != false) { + output.WriteRawTag(32); + output.WriteBool(SwapMemory); + } + if (PivotName.Length != 0) { + output.WriteRawTag(42); + output.WriteString(PivotName); + } + if (PivotForPredName.Length != 0) { + output.WriteRawTag(50); + output.WriteString(PivotForPredName); + } + if (PivotForBodyName.Length != 0) { + output.WriteRawTag(58); + output.WriteString(PivotForBodyName); + } + loopExitNames_.WriteTo(output, _repeated_loopExitNames_codec); + if (valuesDef_ != null) { + output.WriteRawTag(74); + output.WriteMessage(ValuesDef); + } + loopEnterNames_.WriteTo(output, _repeated_loopEnterNames_codec); + if (MaximumIterationsName.Length != 0) { + output.WriteRawTag(90); + output.WriteString(MaximumIterationsName); + } + nestedContexts_.WriteTo(output, _repeated_nestedContexts_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (ContextName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ContextName); + } + if (ParallelIterations != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(ParallelIterations); + } + if (BackProp != false) { + size += 1 + 1; + } + if (SwapMemory != false) { + size += 1 + 1; + } + if (PivotName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(PivotName); + } + if (PivotForPredName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(PivotForPredName); + } + if (PivotForBodyName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(PivotForBodyName); + } + size += loopExitNames_.CalculateSize(_repeated_loopExitNames_codec); + size += loopEnterNames_.CalculateSize(_repeated_loopEnterNames_codec); + if (valuesDef_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ValuesDef); + } + if (MaximumIterationsName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(MaximumIterationsName); + } + size += nestedContexts_.CalculateSize(_repeated_nestedContexts_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(WhileContextDef other) { + if (other == null) { + return; + } + if (other.ContextName.Length != 0) { + ContextName = other.ContextName; + } + if (other.ParallelIterations != 0) { + ParallelIterations = other.ParallelIterations; + } + if (other.BackProp != false) { + BackProp = other.BackProp; + } + if (other.SwapMemory != false) { + SwapMemory = other.SwapMemory; + } + if (other.PivotName.Length != 0) { + PivotName = other.PivotName; + } + if (other.PivotForPredName.Length != 0) { + PivotForPredName = other.PivotForPredName; + } + if (other.PivotForBodyName.Length != 0) { + PivotForBodyName = other.PivotForBodyName; + } + loopExitNames_.Add(other.loopExitNames_); + loopEnterNames_.Add(other.loopEnterNames_); + if (other.valuesDef_ != null) { + if (valuesDef_ == null) { + valuesDef_ = new global::Tensorflow.ValuesDef(); + } + ValuesDef.MergeFrom(other.ValuesDef); + } + if (other.MaximumIterationsName.Length != 0) { + MaximumIterationsName = other.MaximumIterationsName; + } + nestedContexts_.Add(other.nestedContexts_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + ContextName = input.ReadString(); + break; + } + case 16: { + ParallelIterations = input.ReadInt32(); + break; + } + case 24: { + BackProp = input.ReadBool(); + break; + } + case 32: { + SwapMemory = input.ReadBool(); + break; + } + case 42: { + PivotName = input.ReadString(); + break; + } + case 50: { + PivotForPredName = input.ReadString(); + break; + } + case 58: { + PivotForBodyName = input.ReadString(); + break; + } + case 66: { + loopExitNames_.AddEntriesFrom(input, _repeated_loopExitNames_codec); + break; + } + case 74: { + if (valuesDef_ == null) { + valuesDef_ = new global::Tensorflow.ValuesDef(); + } + input.ReadMessage(valuesDef_); + break; + } + case 82: { + loopEnterNames_.AddEntriesFrom(input, _repeated_loopEnterNames_codec); + break; + } + case 90: { + MaximumIterationsName = input.ReadString(); + break; + } + case 98: { + nestedContexts_.AddEntriesFrom(input, _repeated_nestedContexts_codec); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/IProtoBuf.cs b/src/TensorFlowNET.Core/Protobuf/IProtoBuf.cs index 36ce9088..ce08f5ed 100644 --- a/src/TensorFlowNET.Core/Protobuf/IProtoBuf.cs +++ b/src/TensorFlowNET.Core/Protobuf/IProtoBuf.cs @@ -8,7 +8,7 @@ namespace Tensorflow /// In order for a object to be serialized to and from MetaGraphDef, /// the class must implement to_proto() and from_proto() methods /// - public interface IProtoBuf + public interface IProtoBuf { string name { get; } @@ -17,15 +17,15 @@ namespace Tensorflow /// /// /// - VariableDef to_proto(string export_scope); + TProtoDef to_proto(string export_scope); /// /// Returns a `Variable` object created from `variable_def`. /// /// - /// + /// /// /// - T from_proto(VariableDef variable_def, string import_scope); + TDef from_proto(TProtoDef proto, string import_scope); } } diff --git a/src/TensorFlowNET.Core/Protobuf/README.md b/src/TensorFlowNET.Core/Protobuf/README.md index 0c8bb9ed..2cc8356e 100644 --- a/src/TensorFlowNET.Core/Protobuf/README.md +++ b/src/TensorFlowNET.Core/Protobuf/README.md @@ -1,10 +1,12 @@ ### Download compiler from https://github.com/protocolbuffers/protobuf/releases +Work in command line + ```shell +cd tensorflow + set SRC_DIR=D:/Projects/tensorflow set DST_DIR=D:/Projects/TensorFlow.NET/src/TensorFlowNET.Core/Protobuf -cd tensorflow - protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/resource_handle.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/tensor_shape.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/types.proto @@ -32,6 +34,7 @@ protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/cluster.prot protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/config.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/debug.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/rewriter_config.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/control_flow.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/python/training/checkpoint_state.proto ``` diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 8d20f34d..95d5520d 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -7,7 +7,7 @@ using System.Text; namespace Tensorflow { - public partial class RefVariable : VariableV1, IProtoBuf + public partial class RefVariable : VariableV1, IProtoBuf { public bool _in_graph_mode = true; public Tensor _initial_value; @@ -288,7 +288,7 @@ namespace Tensorflow throw new NotImplementedException("to_proto RefVariable"); } - public T from_proto(VariableDef variable_def, string import_scope) + public RefVariable from_proto(VariableDef proto, string import_scope) { throw new NotImplementedException(); } diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index 34885776..c147e1b5 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -376,7 +376,7 @@ namespace Tensorflow if (import_scope.EndsWith("/")) import_scope = import_scope.Substring(0, import_scope.Length - 1); - throw new NotImplementedException("prepend_name_scope"); + return $"{import_scope}/{name}"; } else return name; diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs b/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs index e35923e5..ce267fda 100644 --- a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs +++ b/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs @@ -10,17 +10,28 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test [TestClass] public class CondTestCases : PythonTest { - [TestMethod] public void testCondTrue() { - with(tf.Session(), sess => + var graph = tf.Graph().as_default(); + + with(tf.Session(graph), sess => { var x = tf.constant(2); var y = tf.constant(5); - var z = control_flow_ops.cond(tf.less(x, y), - () => tf.multiply(x, 17), - () => tf.add(y, 23)); + var pred = tf.less(x, y); + + Func if_true = delegate + { + return tf.multiply(x, 17); + }; + + Func if_false = delegate + { + return tf.add(y, 23); + }; + + var z = control_flow_ops.cond(pred, if_true, if_false); int result = z.eval(sess); assertEquals(result, 34); }); From 6bd08d6c4c6ac651b1000fe0c6933ef5b3bcb5ea Mon Sep 17 00:00:00 2001 From: Haiping Date: Sun, 14 Apr 2019 00:37:00 -0500 Subject: [PATCH 4/4] Update Operation.Output.cs --- src/TensorFlowNET.Core/Operations/Operation.Output.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/Operation.Output.cs b/src/TensorFlowNET.Core/Operations/Operation.Output.cs index 5b0b43b3..d7e975bb 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Output.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Output.cs @@ -42,8 +42,8 @@ namespace Tensorflow if (NumControlOutputs > 0) { IntPtr control_output_handle = Marshal.AllocHGlobal(Marshal.SizeOf() * NumControlOutputs); - c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlInputs); - for (int i = 0; i < NumControlInputs; i++) + c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlOutputs); + for (int i = 0; i < NumControlOutputs; i++) { var handle = control_output_handle + Marshal.SizeOf() * i; control_outputs[i] = new Operation(*(IntPtr*)handle);