| @@ -4,6 +4,7 @@ using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Operations; | |||
| using static Tensorflow.CollectionDef; | |||
| using static Tensorflow.MetaGraphDef.Types; | |||
| @@ -95,15 +96,29 @@ namespace Tensorflow | |||
| } | |||
| else | |||
| { | |||
| throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); | |||
| foreach(var value in col.Value.BytesList.Value) | |||
| { | |||
| switch (col.Key) | |||
| { | |||
| case "cond_context": | |||
| var proto = CondContextDef.Parser.ParseFrom(value); | |||
| var condContext = new CondContext().from_proto(proto, import_scope); | |||
| graph.add_to_collection(col.Key, condContext); | |||
| break; | |||
| default: | |||
| throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); | |||
| } | |||
| } | |||
| } | |||
| break; | |||
| default: | |||
| throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); | |||
| } | |||
| } | |||
| var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, | |||
| scope: scope_to_prepend_to_names) as List<RefVariable>; | |||
| var variables = graph.get_collection<RefVariable>(ops.GraphKeys.GLOBAL_VARIABLES, | |||
| scope: scope_to_prepend_to_names); | |||
| var var_list = new Dictionary<string, RefVariable>(); | |||
| variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v); | |||
| @@ -412,6 +412,11 @@ namespace Tensorflow | |||
| return _collections.ContainsKey(name) ? _collections[name] : null; | |||
| } | |||
| public List<T> get_collection<T>(string name, string scope = null) | |||
| { | |||
| return _collections.ContainsKey(name) ? _collections[name] as List<T> : new List<T>(); | |||
| } | |||
| public object get_collection_ref(string name) | |||
| { | |||
| if (!_collections.ContainsKey(name)) | |||
| @@ -8,7 +8,7 @@ namespace Tensorflow.Operations | |||
| /// <summary> | |||
| /// The context for the conditional construct. | |||
| /// </summary> | |||
| public class CondContext : ControlFlowContext | |||
| public class CondContext : ControlFlowContext, IProtoBuf<CondContextDef, CondContext> | |||
| { | |||
| @@ -35,16 +35,20 @@ namespace Tensorflow.Operations | |||
| /// <param name="name">Name of the `CondContext` python object.</param> | |||
| /// <param name="context_def"></param> | |||
| /// <param name="import_scope"></param> | |||
| public CondContext(Tensor pred, | |||
| Tensor pivot, | |||
| int branch, | |||
| public CondContext(Tensor pred = null, | |||
| Tensor pivot = null, | |||
| int? branch = null, | |||
| string name = "cond_text", | |||
| object context_def = null, | |||
| CondContextDef context_def = null, | |||
| string import_scope = null) | |||
| { | |||
| if (pred == null && context_def == null) return; | |||
| _name = ops.get_default_graph().unique_name(name); | |||
| if (context_def != null) | |||
| throw new NotImplementedException("CondContext context_def is not null"); | |||
| if (context_def != null) | |||
| { | |||
| _init_from_proto(context_def, import_scope: import_scope); | |||
| } | |||
| else | |||
| { | |||
| // Initializes the default fields. | |||
| @@ -61,6 +65,18 @@ namespace Tensorflow.Operations | |||
| } | |||
| } | |||
| private void _init_from_proto(CondContextDef context_def, string import_scope = null) | |||
| { | |||
| var g = ops.get_default_graph(); | |||
| _name = ops.prepend_name_scope(context_def.ContextName, import_scope); | |||
| var p1 = ops.prepend_name_scope(context_def.PredName, import_scope); | |||
| _pred = g.as_graph_element(p1) as Tensor; | |||
| var p2 = ops.prepend_name_scope(context_def.PivotName, import_scope); | |||
| _pivot = g.as_graph_element(p2) as Tensor; | |||
| _branch = context_def.Branch; | |||
| __init__(values_def: context_def.ValuesDef, import_scope: import_scope); | |||
| } | |||
| /// <summary> | |||
| /// Add `val` to the current context and its outer context recursively. | |||
| /// </summary> | |||
| @@ -230,6 +246,22 @@ namespace Tensorflow.Operations | |||
| public override void AddInnerOp(Operation resultOp) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| } | |||
| public CondContextDef to_proto(string export_scope) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| public CondContext from_proto(CondContextDef proto, string import_scope) | |||
| { | |||
| var ret = new CondContext(context_def: proto, import_scope: import_scope); | |||
| ret.Enter(); | |||
| foreach (var nested_def in proto.NestedContexts) | |||
| throw new NotImplementedException(""); | |||
| ret.Exit(); | |||
| return ret; | |||
| } | |||
| } | |||
| } | |||
| @@ -32,6 +32,8 @@ namespace Tensorflow.Operations | |||
| protected Stack<IControlFlowContext> _context_stack; | |||
| protected IControlFlowContext _outer_context; | |||
| protected Dictionary<string, ITensorOrOperation> _external_values; | |||
| public ControlFlowContext() | |||
| { | |||
| _context_stack = new Stack<IControlFlowContext>(); | |||
| @@ -40,15 +42,43 @@ namespace Tensorflow.Operations | |||
| public string name { get => _name; } | |||
| protected string _name; | |||
| public void __init__() | |||
| public void __init__(ValuesDef values_def = null, string import_scope = null) | |||
| { | |||
| _outer_context = ops.get_default_graph()._get_control_flow_context(); | |||
| if (values_def != null) | |||
| _init_values_from_proto(values_def, import_scope: import_scope); | |||
| } | |||
| public void __enter__() | |||
| { | |||
| } | |||
| /// <summary> | |||
| /// Initializes values and external_values from `ValuesDef` protocol buffer. | |||
| /// </summary> | |||
| /// <param name="values_def"></param> | |||
| /// <param name="import_scope"></param> | |||
| protected void _init_values_from_proto(ValuesDef values_def, string import_scope = null) | |||
| { | |||
| _external_values = new Dictionary<string, ITensorOrOperation>(); | |||
| foreach (var value in values_def.Values) | |||
| _values.Add(value); | |||
| var g = ops.get_default_graph(); | |||
| foreach(var value in values_def.ExternalValues) | |||
| { | |||
| var k = ops.prepend_name_scope(value.Key, import_scope); | |||
| var v = value.Value; | |||
| _external_values[k] = g.as_graph_element(ops.prepend_name_scope(v, import_scope)); | |||
| } | |||
| var op_names = _values.Where(x => !_external_values.ContainsKey(x)) | |||
| .Select(x => x.Split(':')[0]) | |||
| .ToArray(); | |||
| foreach (var op in op_names) | |||
| (g.as_graph_element(op) as Operation)._set_control_flow_context(this); | |||
| } | |||
| public void __exit__() | |||
| { | |||
| } | |||
| @@ -287,7 +287,7 @@ namespace Tensorflow | |||
| // Reset cached inputs. | |||
| _inputs = null;// new InputList(new Tensor[] { tensor }); // is this right? original code: self._inputs_val=None | |||
| // TODO: implement below code dependencies | |||
| // c_api.TF_UpdateEdge(graph, output, input, status); | |||
| c_api.TF_UpdateEdge(graph, output, input, status); | |||
| } | |||
| private void _assert_same_graph(Tensor tensor) | |||
| @@ -330,7 +330,7 @@ namespace Tensorflow | |||
| tensor.op.graph.prevent_fetching(tensor.op); | |||
| // Build the graph for the true branch in a new context. | |||
| var context_t = new CondContext(pred, pivot_1, branch: 1); | |||
| var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1); | |||
| ITensorOrOperation orig_res_t; | |||
| Tensor res_t; | |||
| try | |||
| @@ -343,7 +343,7 @@ namespace Tensorflow | |||
| context_t.Exit(); | |||
| } | |||
| // Build the graph for the false branch in a new context. | |||
| var context_f = new CondContext(pred, pivot_2, branch: 0); | |||
| var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0); | |||
| ITensorOrOperation orig_res_f; | |||
| Tensor res_f; | |||
| try | |||
| @@ -411,13 +411,13 @@ namespace Tensorflow | |||
| tensor.op.graph.prevent_fetching(tensor.op); | |||
| // Build the graph for the true branch in a new context. | |||
| var context_t = new CondContext(pred, pivot_1, branch: 1); | |||
| var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1); | |||
| context_t.Enter(); | |||
| var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); | |||
| context_t.Exit(); | |||
| // Build the graph for the false branch in a new context. | |||
| var context_f = new CondContext(pred, pivot_2, branch: 0); | |||
| var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0); | |||
| context_f.Enter(); | |||
| var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); | |||
| context_f.Exit(); | |||
| @@ -8,7 +8,7 @@ namespace Tensorflow | |||
| /// In order for a object to be serialized to and from MetaGraphDef, | |||
| /// the class must implement to_proto() and from_proto() methods | |||
| /// </summary> | |||
| public interface IProtoBuf | |||
| public interface IProtoBuf<TProtoDef, TDef> | |||
| { | |||
| string name { get; } | |||
| @@ -17,15 +17,15 @@ namespace Tensorflow | |||
| /// </summary> | |||
| /// <param name="export_scope"></param> | |||
| /// <returns></returns> | |||
| VariableDef to_proto(string export_scope); | |||
| TProtoDef to_proto(string export_scope); | |||
| /// <summary> | |||
| /// Returns a `Variable` object created from `variable_def`. | |||
| /// </summary> | |||
| /// <typeparam name="T"></typeparam> | |||
| /// <param name="variable_def"></param> | |||
| /// <param name="proto"></param> | |||
| /// <param name="import_scope"></param> | |||
| /// <returns></returns> | |||
| T from_proto<T>(VariableDef variable_def, string import_scope); | |||
| TDef from_proto(TProtoDef proto, string import_scope); | |||
| } | |||
| } | |||
| @@ -1,10 +1,12 @@ | |||
| ### Download compiler from https://github.com/protocolbuffers/protobuf/releases | |||
| Work in command line | |||
| ```shell | |||
| cd tensorflow | |||
| set SRC_DIR=D:/Projects/tensorflow | |||
| set DST_DIR=D:/Projects/TensorFlow.NET/src/TensorFlowNET.Core/Protobuf | |||
| cd tensorflow | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/resource_handle.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/tensor_shape.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/types.proto | |||
| @@ -32,6 +34,7 @@ protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/cluster.prot | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/config.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/debug.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/rewriter_config.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/control_flow.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/python/training/checkpoint_state.proto | |||
| ``` | |||
| @@ -7,7 +7,7 @@ using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class RefVariable : VariableV1, IProtoBuf | |||
| public partial class RefVariable : VariableV1, IProtoBuf<VariableDef, RefVariable> | |||
| { | |||
| public bool _in_graph_mode = true; | |||
| public Tensor _initial_value; | |||
| @@ -288,7 +288,7 @@ namespace Tensorflow | |||
| throw new NotImplementedException("to_proto RefVariable"); | |||
| } | |||
| public T from_proto<T>(VariableDef variable_def, string import_scope) | |||
| public RefVariable from_proto(VariableDef proto, string import_scope) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| @@ -376,7 +376,7 @@ namespace Tensorflow | |||
| if (import_scope.EndsWith("/")) | |||
| import_scope = import_scope.Substring(0, import_scope.Length - 1); | |||
| throw new NotImplementedException("prepend_name_scope"); | |||
| return $"{import_scope}/{name}"; | |||
| } | |||
| else | |||
| return name; | |||
| @@ -10,17 +10,28 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||
| [TestClass] | |||
| public class CondTestCases : PythonTest | |||
| { | |||
| [TestMethod] | |||
| public void testCondTrue() | |||
| { | |||
| with(tf.Session(), sess => | |||
| var graph = tf.Graph().as_default(); | |||
| with(tf.Session(graph), sess => | |||
| { | |||
| var x = tf.constant(2); | |||
| var y = tf.constant(5); | |||
| var z = control_flow_ops.cond(tf.less(x, y), | |||
| () => tf.multiply(x, 17), | |||
| () => tf.add(y, 23)); | |||
| var pred = tf.less(x, y); | |||
| Func<ITensorOrOperation> if_true = delegate | |||
| { | |||
| return tf.multiply(x, 17); | |||
| }; | |||
| Func<ITensorOrOperation> if_false = delegate | |||
| { | |||
| return tf.add(y, 23); | |||
| }; | |||
| var z = control_flow_ops.cond(pred, if_true, if_false); | |||
| int result = z.eval(sess); | |||
| assertEquals(result, 34); | |||
| }); | |||