| @@ -4,6 +4,7 @@ using System.Collections.Generic; | |||||
| using System.IO; | using System.IO; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Operations; | |||||
| using static Tensorflow.CollectionDef; | using static Tensorflow.CollectionDef; | ||||
| using static Tensorflow.MetaGraphDef.Types; | using static Tensorflow.MetaGraphDef.Types; | ||||
| @@ -95,15 +96,29 @@ namespace Tensorflow | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); | |||||
| foreach(var value in col.Value.BytesList.Value) | |||||
| { | |||||
| switch (col.Key) | |||||
| { | |||||
| case "cond_context": | |||||
| var proto = CondContextDef.Parser.ParseFrom(value); | |||||
| var condContext = new CondContext().from_proto(proto, import_scope); | |||||
| graph.add_to_collection(col.Key, condContext); | |||||
| break; | |||||
| default: | |||||
| throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| break; | break; | ||||
| default: | |||||
| throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); | |||||
| } | } | ||||
| } | } | ||||
| var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, | |||||
| scope: scope_to_prepend_to_names) as List<RefVariable>; | |||||
| var variables = graph.get_collection<RefVariable>(ops.GraphKeys.GLOBAL_VARIABLES, | |||||
| scope: scope_to_prepend_to_names); | |||||
| var var_list = new Dictionary<string, RefVariable>(); | var var_list = new Dictionary<string, RefVariable>(); | ||||
| variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v); | variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v); | ||||
| @@ -412,6 +412,11 @@ namespace Tensorflow | |||||
| return _collections.ContainsKey(name) ? _collections[name] : null; | return _collections.ContainsKey(name) ? _collections[name] : null; | ||||
| } | } | ||||
| public List<T> get_collection<T>(string name, string scope = null) | |||||
| { | |||||
| return _collections.ContainsKey(name) ? _collections[name] as List<T> : new List<T>(); | |||||
| } | |||||
| public object get_collection_ref(string name) | public object get_collection_ref(string name) | ||||
| { | { | ||||
| if (!_collections.ContainsKey(name)) | if (!_collections.ContainsKey(name)) | ||||
| @@ -8,7 +8,7 @@ namespace Tensorflow.Operations | |||||
| /// <summary> | /// <summary> | ||||
| /// The context for the conditional construct. | /// The context for the conditional construct. | ||||
| /// </summary> | /// </summary> | ||||
| public class CondContext : ControlFlowContext | |||||
| public class CondContext : ControlFlowContext, IProtoBuf<CondContextDef, CondContext> | |||||
| { | { | ||||
| @@ -35,16 +35,20 @@ namespace Tensorflow.Operations | |||||
| /// <param name="name">Name of the `CondContext` python object.</param> | /// <param name="name">Name of the `CondContext` python object.</param> | ||||
| /// <param name="context_def"></param> | /// <param name="context_def"></param> | ||||
| /// <param name="import_scope"></param> | /// <param name="import_scope"></param> | ||||
| public CondContext(Tensor pred, | |||||
| Tensor pivot, | |||||
| int branch, | |||||
| public CondContext(Tensor pred = null, | |||||
| Tensor pivot = null, | |||||
| int? branch = null, | |||||
| string name = "cond_text", | string name = "cond_text", | ||||
| object context_def = null, | |||||
| CondContextDef context_def = null, | |||||
| string import_scope = null) | string import_scope = null) | ||||
| { | { | ||||
| if (pred == null && context_def == null) return; | |||||
| _name = ops.get_default_graph().unique_name(name); | _name = ops.get_default_graph().unique_name(name); | ||||
| if (context_def != null) | |||||
| throw new NotImplementedException("CondContext context_def is not null"); | |||||
| if (context_def != null) | |||||
| { | |||||
| _init_from_proto(context_def, import_scope: import_scope); | |||||
| } | |||||
| else | else | ||||
| { | { | ||||
| // Initializes the default fields. | // Initializes the default fields. | ||||
| @@ -61,6 +65,18 @@ namespace Tensorflow.Operations | |||||
| } | } | ||||
| } | } | ||||
| private void _init_from_proto(CondContextDef context_def, string import_scope = null) | |||||
| { | |||||
| var g = ops.get_default_graph(); | |||||
| _name = ops.prepend_name_scope(context_def.ContextName, import_scope); | |||||
| var p1 = ops.prepend_name_scope(context_def.PredName, import_scope); | |||||
| _pred = g.as_graph_element(p1) as Tensor; | |||||
| var p2 = ops.prepend_name_scope(context_def.PivotName, import_scope); | |||||
| _pivot = g.as_graph_element(p2) as Tensor; | |||||
| _branch = context_def.Branch; | |||||
| __init__(values_def: context_def.ValuesDef, import_scope: import_scope); | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Add `val` to the current context and its outer context recursively. | /// Add `val` to the current context and its outer context recursively. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -230,6 +246,22 @@ namespace Tensorflow.Operations | |||||
| public override void AddInnerOp(Operation resultOp) | public override void AddInnerOp(Operation resultOp) | ||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| } | |||||
| } | |||||
| public CondContextDef to_proto(string export_scope) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public CondContext from_proto(CondContextDef proto, string import_scope) | |||||
| { | |||||
| var ret = new CondContext(context_def: proto, import_scope: import_scope); | |||||
| ret.Enter(); | |||||
| foreach (var nested_def in proto.NestedContexts) | |||||
| throw new NotImplementedException(""); | |||||
| ret.Exit(); | |||||
| return ret; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -32,6 +32,8 @@ namespace Tensorflow.Operations | |||||
| protected Stack<IControlFlowContext> _context_stack; | protected Stack<IControlFlowContext> _context_stack; | ||||
| protected IControlFlowContext _outer_context; | protected IControlFlowContext _outer_context; | ||||
| protected Dictionary<string, ITensorOrOperation> _external_values; | |||||
| public ControlFlowContext() | public ControlFlowContext() | ||||
| { | { | ||||
| _context_stack = new Stack<IControlFlowContext>(); | _context_stack = new Stack<IControlFlowContext>(); | ||||
| @@ -40,15 +42,43 @@ namespace Tensorflow.Operations | |||||
| public string name { get => _name; } | public string name { get => _name; } | ||||
| protected string _name; | protected string _name; | ||||
| public void __init__() | |||||
| public void __init__(ValuesDef values_def = null, string import_scope = null) | |||||
| { | { | ||||
| _outer_context = ops.get_default_graph()._get_control_flow_context(); | |||||
| if (values_def != null) | |||||
| _init_values_from_proto(values_def, import_scope: import_scope); | |||||
| } | } | ||||
| public void __enter__() | public void __enter__() | ||||
| { | { | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Initializes values and external_values from `ValuesDef` protocol buffer. | |||||
| /// </summary> | |||||
| /// <param name="values_def"></param> | |||||
| /// <param name="import_scope"></param> | |||||
| protected void _init_values_from_proto(ValuesDef values_def, string import_scope = null) | |||||
| { | |||||
| _external_values = new Dictionary<string, ITensorOrOperation>(); | |||||
| foreach (var value in values_def.Values) | |||||
| _values.Add(value); | |||||
| var g = ops.get_default_graph(); | |||||
| foreach(var value in values_def.ExternalValues) | |||||
| { | |||||
| var k = ops.prepend_name_scope(value.Key, import_scope); | |||||
| var v = value.Value; | |||||
| _external_values[k] = g.as_graph_element(ops.prepend_name_scope(v, import_scope)); | |||||
| } | |||||
| var op_names = _values.Where(x => !_external_values.ContainsKey(x)) | |||||
| .Select(x => x.Split(':')[0]) | |||||
| .ToArray(); | |||||
| foreach (var op in op_names) | |||||
| (g.as_graph_element(op) as Operation)._set_control_flow_context(this); | |||||
| } | |||||
| public void __exit__() | public void __exit__() | ||||
| { | { | ||||
| } | } | ||||
| @@ -287,7 +287,7 @@ namespace Tensorflow | |||||
| // Reset cached inputs. | // Reset cached inputs. | ||||
| _inputs = null;// new InputList(new Tensor[] { tensor }); // is this right? original code: self._inputs_val=None | _inputs = null;// new InputList(new Tensor[] { tensor }); // is this right? original code: self._inputs_val=None | ||||
| // TODO: implement below code dependencies | // TODO: implement below code dependencies | ||||
| // c_api.TF_UpdateEdge(graph, output, input, status); | |||||
| c_api.TF_UpdateEdge(graph, output, input, status); | |||||
| } | } | ||||
| private void _assert_same_graph(Tensor tensor) | private void _assert_same_graph(Tensor tensor) | ||||
| @@ -330,7 +330,7 @@ namespace Tensorflow | |||||
| tensor.op.graph.prevent_fetching(tensor.op); | tensor.op.graph.prevent_fetching(tensor.op); | ||||
| // Build the graph for the true branch in a new context. | // Build the graph for the true branch in a new context. | ||||
| var context_t = new CondContext(pred, pivot_1, branch: 1); | |||||
| var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1); | |||||
| ITensorOrOperation orig_res_t; | ITensorOrOperation orig_res_t; | ||||
| Tensor res_t; | Tensor res_t; | ||||
| try | try | ||||
| @@ -343,7 +343,7 @@ namespace Tensorflow | |||||
| context_t.Exit(); | context_t.Exit(); | ||||
| } | } | ||||
| // Build the graph for the false branch in a new context. | // Build the graph for the false branch in a new context. | ||||
| var context_f = new CondContext(pred, pivot_2, branch: 0); | |||||
| var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0); | |||||
| ITensorOrOperation orig_res_f; | ITensorOrOperation orig_res_f; | ||||
| Tensor res_f; | Tensor res_f; | ||||
| try | try | ||||
| @@ -411,13 +411,13 @@ namespace Tensorflow | |||||
| tensor.op.graph.prevent_fetching(tensor.op); | tensor.op.graph.prevent_fetching(tensor.op); | ||||
| // Build the graph for the true branch in a new context. | // Build the graph for the true branch in a new context. | ||||
| var context_t = new CondContext(pred, pivot_1, branch: 1); | |||||
| var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1); | |||||
| context_t.Enter(); | context_t.Enter(); | ||||
| var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); | var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); | ||||
| context_t.Exit(); | context_t.Exit(); | ||||
| // Build the graph for the false branch in a new context. | // Build the graph for the false branch in a new context. | ||||
| var context_f = new CondContext(pred, pivot_2, branch: 0); | |||||
| var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0); | |||||
| context_f.Enter(); | context_f.Enter(); | ||||
| var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); | var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); | ||||
| context_f.Exit(); | context_f.Exit(); | ||||
| @@ -8,7 +8,7 @@ namespace Tensorflow | |||||
| /// In order for a object to be serialized to and from MetaGraphDef, | /// In order for a object to be serialized to and from MetaGraphDef, | ||||
| /// the class must implement to_proto() and from_proto() methods | /// the class must implement to_proto() and from_proto() methods | ||||
| /// </summary> | /// </summary> | ||||
| public interface IProtoBuf | |||||
| public interface IProtoBuf<TProtoDef, TDef> | |||||
| { | { | ||||
| string name { get; } | string name { get; } | ||||
| @@ -17,15 +17,15 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="export_scope"></param> | /// <param name="export_scope"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| VariableDef to_proto(string export_scope); | |||||
| TProtoDef to_proto(string export_scope); | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns a `Variable` object created from `variable_def`. | /// Returns a `Variable` object created from `variable_def`. | ||||
| /// </summary> | /// </summary> | ||||
| /// <typeparam name="T"></typeparam> | /// <typeparam name="T"></typeparam> | ||||
| /// <param name="variable_def"></param> | |||||
| /// <param name="proto"></param> | |||||
| /// <param name="import_scope"></param> | /// <param name="import_scope"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| T from_proto<T>(VariableDef variable_def, string import_scope); | |||||
| TDef from_proto(TProtoDef proto, string import_scope); | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,10 +1,12 @@ | |||||
| ### Download compiler from https://github.com/protocolbuffers/protobuf/releases | ### Download compiler from https://github.com/protocolbuffers/protobuf/releases | ||||
| Work in command line | |||||
| ```shell | ```shell | ||||
| cd tensorflow | |||||
| set SRC_DIR=D:/Projects/tensorflow | set SRC_DIR=D:/Projects/tensorflow | ||||
| set DST_DIR=D:/Projects/TensorFlow.NET/src/TensorFlowNET.Core/Protobuf | set DST_DIR=D:/Projects/TensorFlow.NET/src/TensorFlowNET.Core/Protobuf | ||||
| cd tensorflow | |||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/resource_handle.proto | protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/resource_handle.proto | ||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/tensor_shape.proto | protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/tensor_shape.proto | ||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/types.proto | protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/types.proto | ||||
| @@ -32,6 +34,7 @@ protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/cluster.prot | |||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/config.proto | protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/config.proto | ||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/debug.proto | protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/debug.proto | ||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/rewriter_config.proto | protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/rewriter_config.proto | ||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/control_flow.proto | |||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/python/training/checkpoint_state.proto | protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/python/training/checkpoint_state.proto | ||||
| ``` | ``` | ||||
| @@ -7,7 +7,7 @@ using System.Text; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public partial class RefVariable : VariableV1, IProtoBuf | |||||
| public partial class RefVariable : VariableV1, IProtoBuf<VariableDef, RefVariable> | |||||
| { | { | ||||
| public bool _in_graph_mode = true; | public bool _in_graph_mode = true; | ||||
| public Tensor _initial_value; | public Tensor _initial_value; | ||||
| @@ -288,7 +288,7 @@ namespace Tensorflow | |||||
| throw new NotImplementedException("to_proto RefVariable"); | throw new NotImplementedException("to_proto RefVariable"); | ||||
| } | } | ||||
| public T from_proto<T>(VariableDef variable_def, string import_scope) | |||||
| public RefVariable from_proto(VariableDef proto, string import_scope) | |||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| } | } | ||||
| @@ -376,7 +376,7 @@ namespace Tensorflow | |||||
| if (import_scope.EndsWith("/")) | if (import_scope.EndsWith("/")) | ||||
| import_scope = import_scope.Substring(0, import_scope.Length - 1); | import_scope = import_scope.Substring(0, import_scope.Length - 1); | ||||
| throw new NotImplementedException("prepend_name_scope"); | |||||
| return $"{import_scope}/{name}"; | |||||
| } | } | ||||
| else | else | ||||
| return name; | return name; | ||||
| @@ -10,17 +10,28 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||||
| [TestClass] | [TestClass] | ||||
| public class CondTestCases : PythonTest | public class CondTestCases : PythonTest | ||||
| { | { | ||||
| [TestMethod] | [TestMethod] | ||||
| public void testCondTrue() | public void testCondTrue() | ||||
| { | { | ||||
| with(tf.Session(), sess => | |||||
| var graph = tf.Graph().as_default(); | |||||
| with(tf.Session(graph), sess => | |||||
| { | { | ||||
| var x = tf.constant(2); | var x = tf.constant(2); | ||||
| var y = tf.constant(5); | var y = tf.constant(5); | ||||
| var z = control_flow_ops.cond(tf.less(x, y), | |||||
| () => tf.multiply(x, 17), | |||||
| () => tf.add(y, 23)); | |||||
| var pred = tf.less(x, y); | |||||
| Func<ITensorOrOperation> if_true = delegate | |||||
| { | |||||
| return tf.multiply(x, 17); | |||||
| }; | |||||
| Func<ITensorOrOperation> if_false = delegate | |||||
| { | |||||
| return tf.add(y, 23); | |||||
| }; | |||||
| var z = control_flow_ops.cond(pred, if_true, if_false); | |||||
| int result = z.eval(sess); | int result = z.eval(sess); | ||||
| assertEquals(result, 34); | assertEquals(result, 34); | ||||
| }); | }); | ||||