# Conflicts: # src/TensorFlowNET.Core/Operations/Operation.cs # test/TensorFlowNET.UnitTest/PythonTest.cs # test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cstags/v0.9
| @@ -27,7 +27,7 @@ namespace Tensorflow | |||||
| public static Tensor asin(Tensor x, string name = null) | public static Tensor asin(Tensor x, string name = null) | ||||
| => gen_math_ops.asin(x, name); | => gen_math_ops.asin(x, name); | ||||
| public static Tensor add(Tensor a, Tensor b) | |||||
| public static Tensor add<Tx, Ty>(Tx a, Ty b) | |||||
| => gen_math_ops.add(a, b); | => gen_math_ops.add(a, b); | ||||
| /// <summary> | /// <summary> | ||||
| @@ -251,7 +251,7 @@ namespace Tensorflow | |||||
| public static Tensor minimum<T1, T2>(T1 x, T2 y, string name = null) | public static Tensor minimum<T1, T2>(T1 x, T2 y, string name = null) | ||||
| => gen_math_ops.minimum(x, y, name: name); | => gen_math_ops.minimum(x, y, name: name); | ||||
| public static Tensor multiply(Tensor x, Tensor y) | |||||
| public static Tensor multiply<Tx, Ty>(Tx x, Ty y) | |||||
| => gen_math_ops.mul(x, y); | => gen_math_ops.mul(x, y); | ||||
| public static Tensor negative(Tensor x, string name = null) | public static Tensor negative(Tensor x, string name = null) | ||||
| @@ -4,6 +4,7 @@ using System.Collections.Generic; | |||||
| using System.IO; | using System.IO; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Operations; | |||||
| using static Tensorflow.CollectionDef; | using static Tensorflow.CollectionDef; | ||||
| using static Tensorflow.MetaGraphDef.Types; | using static Tensorflow.MetaGraphDef.Types; | ||||
| @@ -95,15 +96,29 @@ namespace Tensorflow | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); | |||||
| foreach(var value in col.Value.BytesList.Value) | |||||
| { | |||||
| switch (col.Key) | |||||
| { | |||||
| case "cond_context": | |||||
| var proto = CondContextDef.Parser.ParseFrom(value); | |||||
| var condContext = new CondContext().from_proto(proto, import_scope); | |||||
| graph.add_to_collection(col.Key, condContext); | |||||
| break; | |||||
| default: | |||||
| throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| break; | break; | ||||
| default: | |||||
| throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); | |||||
| } | } | ||||
| } | } | ||||
| var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, | |||||
| scope: scope_to_prepend_to_names) as List<RefVariable>; | |||||
| var variables = graph.get_collection<RefVariable>(ops.GraphKeys.GLOBAL_VARIABLES, | |||||
| scope: scope_to_prepend_to_names); | |||||
| var var_list = new Dictionary<string, RefVariable>(); | var var_list = new Dictionary<string, RefVariable>(); | ||||
| variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v); | variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v); | ||||
| @@ -412,6 +412,11 @@ namespace Tensorflow | |||||
| return _collections.ContainsKey(name) ? _collections[name] : null; | return _collections.ContainsKey(name) ? _collections[name] : null; | ||||
| } | } | ||||
| public List<T> get_collection<T>(string name, string scope = null) | |||||
| { | |||||
| return _collections.ContainsKey(name) ? _collections[name] as List<T> : new List<T>(); | |||||
| } | |||||
| public object get_collection_ref(string name) | public object get_collection_ref(string name) | ||||
| { | { | ||||
| if (!_collections.ContainsKey(name)) | if (!_collections.ContainsKey(name)) | ||||
| @@ -8,7 +8,7 @@ namespace Tensorflow.Operations | |||||
| /// <summary> | /// <summary> | ||||
| /// The context for the conditional construct. | /// The context for the conditional construct. | ||||
| /// </summary> | /// </summary> | ||||
| public class CondContext : ControlFlowContext | |||||
| public class CondContext : ControlFlowContext, IProtoBuf<CondContextDef, CondContext> | |||||
| { | { | ||||
| @@ -35,16 +35,20 @@ namespace Tensorflow.Operations | |||||
| /// <param name="name">Name of the `CondContext` python object.</param> | /// <param name="name">Name of the `CondContext` python object.</param> | ||||
| /// <param name="context_def"></param> | /// <param name="context_def"></param> | ||||
| /// <param name="import_scope"></param> | /// <param name="import_scope"></param> | ||||
| public CondContext(Tensor pred, | |||||
| Tensor pivot, | |||||
| int branch, | |||||
| public CondContext(Tensor pred = null, | |||||
| Tensor pivot = null, | |||||
| int? branch = null, | |||||
| string name = "cond_text", | string name = "cond_text", | ||||
| object context_def = null, | |||||
| CondContextDef context_def = null, | |||||
| string import_scope = null) | string import_scope = null) | ||||
| { | { | ||||
| if (pred == null && context_def == null) return; | |||||
| _name = ops.get_default_graph().unique_name(name); | _name = ops.get_default_graph().unique_name(name); | ||||
| if (context_def != null) | |||||
| throw new NotImplementedException("CondContext context_def is not null"); | |||||
| if (context_def != null) | |||||
| { | |||||
| _init_from_proto(context_def, import_scope: import_scope); | |||||
| } | |||||
| else | else | ||||
| { | { | ||||
| // Initializes the default fields. | // Initializes the default fields. | ||||
| @@ -61,6 +65,18 @@ namespace Tensorflow.Operations | |||||
| } | } | ||||
| } | } | ||||
| private void _init_from_proto(CondContextDef context_def, string import_scope = null) | |||||
| { | |||||
| var g = ops.get_default_graph(); | |||||
| _name = ops.prepend_name_scope(context_def.ContextName, import_scope); | |||||
| var p1 = ops.prepend_name_scope(context_def.PredName, import_scope); | |||||
| _pred = g.as_graph_element(p1) as Tensor; | |||||
| var p2 = ops.prepend_name_scope(context_def.PivotName, import_scope); | |||||
| _pivot = g.as_graph_element(p2) as Tensor; | |||||
| _branch = context_def.Branch; | |||||
| __init__(values_def: context_def.ValuesDef, import_scope: import_scope); | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Add `val` to the current context and its outer context recursively. | /// Add `val` to the current context and its outer context recursively. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -230,6 +246,22 @@ namespace Tensorflow.Operations | |||||
| public override void AddInnerOp(Operation resultOp) | public override void AddInnerOp(Operation resultOp) | ||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| } | |||||
| } | |||||
| public CondContextDef to_proto(string export_scope) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public CondContext from_proto(CondContextDef proto, string import_scope) | |||||
| { | |||||
| var ret = new CondContext(context_def: proto, import_scope: import_scope); | |||||
| ret.Enter(); | |||||
| foreach (var nested_def in proto.NestedContexts) | |||||
| throw new NotImplementedException(""); | |||||
| ret.Exit(); | |||||
| return ret; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -32,6 +32,8 @@ namespace Tensorflow.Operations | |||||
| protected Stack<IControlFlowContext> _context_stack; | protected Stack<IControlFlowContext> _context_stack; | ||||
| protected IControlFlowContext _outer_context; | protected IControlFlowContext _outer_context; | ||||
| protected Dictionary<string, ITensorOrOperation> _external_values; | |||||
| public ControlFlowContext() | public ControlFlowContext() | ||||
| { | { | ||||
| _context_stack = new Stack<IControlFlowContext>(); | _context_stack = new Stack<IControlFlowContext>(); | ||||
| @@ -40,15 +42,43 @@ namespace Tensorflow.Operations | |||||
| public string name { get => _name; } | public string name { get => _name; } | ||||
| protected string _name; | protected string _name; | ||||
| public void __init__() | |||||
| public void __init__(ValuesDef values_def = null, string import_scope = null) | |||||
| { | { | ||||
| _outer_context = ops.get_default_graph()._get_control_flow_context(); | |||||
| if (values_def != null) | |||||
| _init_values_from_proto(values_def, import_scope: import_scope); | |||||
| } | } | ||||
| public void __enter__() | public void __enter__() | ||||
| { | { | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Initializes values and external_values from `ValuesDef` protocol buffer. | |||||
| /// </summary> | |||||
| /// <param name="values_def"></param> | |||||
| /// <param name="import_scope"></param> | |||||
| protected void _init_values_from_proto(ValuesDef values_def, string import_scope = null) | |||||
| { | |||||
| _external_values = new Dictionary<string, ITensorOrOperation>(); | |||||
| foreach (var value in values_def.Values) | |||||
| _values.Add(value); | |||||
| var g = ops.get_default_graph(); | |||||
| foreach(var value in values_def.ExternalValues) | |||||
| { | |||||
| var k = ops.prepend_name_scope(value.Key, import_scope); | |||||
| var v = value.Value; | |||||
| _external_values[k] = g.as_graph_element(ops.prepend_name_scope(v, import_scope)); | |||||
| } | |||||
| var op_names = _values.Where(x => !_external_values.ContainsKey(x)) | |||||
| .Select(x => x.Split(':')[0]) | |||||
| .ToArray(); | |||||
| foreach (var op in op_names) | |||||
| (g.as_graph_element(op) as Operation)._set_control_flow_context(this); | |||||
| } | |||||
| public void __exit__() | public void __exit__() | ||||
| { | { | ||||
| } | } | ||||
| @@ -42,8 +42,8 @@ namespace Tensorflow | |||||
| if (NumControlOutputs > 0) | if (NumControlOutputs > 0) | ||||
| { | { | ||||
| IntPtr control_output_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlOutputs); | IntPtr control_output_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlOutputs); | ||||
| c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlInputs); | |||||
| for (int i = 0; i < NumControlInputs; i++) | |||||
| c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlOutputs); | |||||
| for (int i = 0; i < NumControlOutputs; i++) | |||||
| { | { | ||||
| var handle = control_output_handle + Marshal.SizeOf<IntPtr>() * i; | var handle = control_output_handle + Marshal.SizeOf<IntPtr>() * i; | ||||
| control_outputs[i] = new Operation(*(IntPtr*)handle); | control_outputs[i] = new Operation(*(IntPtr*)handle); | ||||
| @@ -1,319 +1,318 @@ | |||||
| using Google.Protobuf.Collections; | |||||
| //using Newtonsoft.Json; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Runtime.InteropServices; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| using Google.Protobuf.Collections; | |||||
| //using Newtonsoft.Json; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Runtime.InteropServices; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | { | ||||
| /// <summary> | |||||
| /// Represents a graph node that performs computation on tensors. | |||||
| /// | |||||
| /// An `Operation` is a node in a TensorFlow `Graph` that takes zero or | |||||
| /// more `Tensor` objects as input, and produces zero or more `Tensor` | |||||
| /// objects as output. Objects of type `Operation` are created by | |||||
| /// calling an op constructor(such as `tf.matmul`) | |||||
| /// or `tf.Graph.create_op`. | |||||
| /// | |||||
| /// For example `c = tf.matmul(a, b)` creates an `Operation` of type | |||||
| /// "MatMul" that takes tensors `a` and `b` as input, and produces `c` | |||||
| /// as output. | |||||
| /// | |||||
| /// After the graph has been launched in a session, an `Operation` can | |||||
| /// be executed by passing it to | |||||
| /// `tf.Session.run`. | |||||
| /// <summary> | |||||
| /// Represents a graph node that performs computation on tensors. | |||||
| /// | |||||
| /// An `Operation` is a node in a TensorFlow `Graph` that takes zero or | |||||
| /// more `Tensor` objects as input, and produces zero or more `Tensor` | |||||
| /// objects as output. Objects of type `Operation` are created by | |||||
| /// calling an op constructor(such as `tf.matmul`) | |||||
| /// or `tf.Graph.create_op`. | |||||
| /// | |||||
| /// For example `c = tf.matmul(a, b)` creates an `Operation` of type | |||||
| /// "MatMul" that takes tensors `a` and `b` as input, and produces `c` | |||||
| /// as output. | |||||
| /// | |||||
| /// After the graph has been launched in a session, an `Operation` can | |||||
| /// be executed by passing it to | |||||
| /// `tf.Session.run`. | |||||
| /// `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`. | /// `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`. | ||||
| /// </summary> | |||||
| public partial class Operation : ITensorOrOperation | |||||
| { | |||||
| private readonly IntPtr _handle; // _c_op in python | |||||
| private readonly IntPtr _operDesc; | |||||
| private Graph _graph; | |||||
| //[JsonIgnore] | |||||
| public Graph graph => _graph; | |||||
| //[JsonIgnore] | |||||
| public int _id => _id_value; | |||||
| //[JsonIgnore] | |||||
| public int _id_value; | |||||
| public string type => OpType; | |||||
| //[JsonIgnore] | |||||
| public Operation op => this; | |||||
| public TF_DataType dtype => TF_DataType.DtInvalid; | |||||
| private Status status = new Status(); | |||||
| public string name => c_api.StringPiece(c_api.TF_OperationName(_handle)); | |||||
| public string OpType => c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | |||||
| public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | |||||
| private NodeDef _node_def; | |||||
| public NodeDef node_def | |||||
| { | |||||
| get | |||||
| { | |||||
| if(_node_def == null) | |||||
| _node_def = GetNodeDef(); | |||||
| return _node_def; | |||||
| } | |||||
| } | |||||
| public Operation(IntPtr handle, Graph g=null) | |||||
| { | |||||
| if (handle == IntPtr.Zero) | |||||
| return; | |||||
| _handle = handle; | |||||
| _graph = g ?? ops.get_default_graph(); | |||||
| _outputs = new Tensor[NumOutputs]; | |||||
| for (int i = 0; i < NumOutputs; i++) | |||||
| _outputs[i] = new Tensor(this, i, OutputType(i)); | |||||
| // Dict mapping op name to file and line information for op colocation | |||||
| // context managers. | |||||
| _control_flow_context = graph._get_control_flow_context(); | |||||
| // Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor. | |||||
| } | |||||
| public Operation(Graph g, string opType, string oper_name) | |||||
| { | |||||
| _graph = g; | |||||
| _operDesc = c_api.TF_NewOperation(g, opType, oper_name); | |||||
| c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32); | |||||
| _handle = c_api.TF_FinishOperation(_operDesc, status); | |||||
| // Dict mapping op name to file and line information for op colocation | |||||
| // context managers. | |||||
| _control_flow_context = graph._get_control_flow_context(); | |||||
| } | |||||
| /// <summary> | |||||
| /// Creates an `Operation`. | |||||
| /// </summary> | |||||
| /// <param name="node_def">`node_def_pb2.NodeDef`. `NodeDef` for the `Operation`.</param> | |||||
| /// <param name="g">`Graph`. The parent graph.</param> | |||||
| /// <param name="inputs">list of `Tensor` objects. The inputs to this `Operation`.</param> | |||||
| /// <param name="output_types">list of `DType` objects.</param> | |||||
| /// <param name="control_inputs"> | |||||
| /// list of operations or tensors from which to have a | |||||
| /// control dependency. | |||||
| /// </param> | |||||
| /// <param name="input_types"> | |||||
| /// List of `DType` objects representing the | |||||
| /// types of the tensors accepted by the `Operation`. By default | |||||
| /// uses `[x.dtype.base_dtype for x in inputs]`. Operations that expect | |||||
| /// reference-typed inputs must specify these explicitly. | |||||
| /// </param> | |||||
| /// <param name="original_op"></param> | |||||
| /// <param name="op_def"></param> | |||||
| public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) | |||||
| { | |||||
| _graph = g; | |||||
| // Build the list of control inputs. | |||||
| var control_input_ops = new List<Operation>(); | |||||
| if(control_inputs != null) | |||||
| { | |||||
| foreach(var c in control_inputs) | |||||
| { | |||||
| switch (c) | |||||
| { | |||||
| case Operation c1: | |||||
| control_input_ops.Add(c1); | |||||
| break; | |||||
| case Tensor tensor: | |||||
| control_input_ops.Add(tensor.op); | |||||
| break; | |||||
| // TODO: IndexedSlices don't yet exist, but once they do, this needs to be uncommented | |||||
| //case IndexedSlices islices: | |||||
| // control_input_ops.Add(islices.op); | |||||
| // break; | |||||
| default: | |||||
| throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}"); | |||||
| } | |||||
| } | |||||
| } | |||||
| // Dict mapping op name to file and line information for op colocation | |||||
| // context managers. | |||||
| _control_flow_context = graph._get_control_flow_context(); | |||||
| // This will be set by self.inputs. | |||||
| if (op_def == null) | |||||
| op_def = g.GetOpDef(node_def.Op); | |||||
| var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | |||||
| (_handle, _operDesc) = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); | |||||
| // Initialize self._outputs. | |||||
| output_types = new TF_DataType[NumOutputs]; | |||||
| for (int i = 0; i < NumOutputs; i++) | |||||
| output_types[i] = OutputType(i); | |||||
| _outputs = new Tensor[NumOutputs]; | |||||
| for (int i = 0; i < NumOutputs; i++) | |||||
| _outputs[i] = new Tensor(this, i, OutputType(i)); | |||||
| graph._add_op(this); | |||||
| if (_handle != IntPtr.Zero) | |||||
| _control_flow_post_processing(); | |||||
| } | |||||
| public void run(FeedItem[] feed_dict = null, Session session = null) | |||||
| { | |||||
| ops._run_using_default_session(this, feed_dict, graph, session); | |||||
| } | |||||
| private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField<string, AttrValue> attrs) | |||||
| { | |||||
| var grouped_inputs = new List<object>(); | |||||
| int i = 0; | |||||
| int input_len = 0; | |||||
| bool is_sequence = false; | |||||
| foreach (var input_arg in op_def.InputArg) | |||||
| { | |||||
| if (!string.IsNullOrEmpty(input_arg.NumberAttr)) | |||||
| { | |||||
| input_len = (int)attrs[input_arg.NumberAttr].I; | |||||
| is_sequence = true; | |||||
| } | |||||
| else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) | |||||
| { | |||||
| input_len = attrs[input_arg.TypeListAttr].List.Type.Count; | |||||
| is_sequence = true; | |||||
| } | |||||
| else | |||||
| { | |||||
| input_len = 1; | |||||
| is_sequence = false; | |||||
| } | |||||
| if (is_sequence) | |||||
| grouped_inputs.Add(inputs.Skip(i).Take(input_len).ToArray()); | |||||
| else | |||||
| grouped_inputs.Add(inputs[i]); | |||||
| i += input_len; | |||||
| } | |||||
| return grouped_inputs.ToArray(); | |||||
| } | |||||
| public object get_attr(string name) | |||||
| { | |||||
| AttrValue x = null; | |||||
| using (var buf = new Buffer()) | |||||
| { | |||||
| c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | |||||
| status.Check(true); | |||||
| x = AttrValue.Parser.ParseFrom(buf); | |||||
| } | |||||
| string oneof_value = x.ValueCase.ToString(); | |||||
| if (string.IsNullOrEmpty(oneof_value)) | |||||
| return null; | |||||
| if(oneof_value == "list") | |||||
| throw new NotImplementedException($"Unsupported field type in {x.ToString()}"); | |||||
| if (oneof_value == "type") | |||||
| return x.Type; | |||||
| object result = x.GetType().GetProperty(oneof_value).GetValue(x); | |||||
| if (result is Google.Protobuf.ByteString byteString) | |||||
| return byteString.ToStringUtf8(); | |||||
| return result; | |||||
| } | |||||
| public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s) | |||||
| { | |||||
| return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s); | |||||
| } | |||||
| private NodeDef GetNodeDef() | |||||
| { | |||||
| using (var s = new Status()) | |||||
| using (var buffer = new Buffer()) | |||||
| { | |||||
| c_api.TF_OperationToNodeDef(_handle, buffer, s); | |||||
| s.Check(); | |||||
| return NodeDef.Parser.ParseFrom(buffer); | |||||
| } | |||||
| } | |||||
| public override string ToString() | |||||
| { | |||||
| return _handle == IntPtr.Zero ? "tf.Operation Undefined" : $"tf.Operation '{name}' type={OpType}"; | |||||
| } | |||||
| public static implicit operator Operation(IntPtr handle) => new Operation(handle); | |||||
| public static implicit operator IntPtr(Operation op) => op._handle; | |||||
| public override bool Equals(object obj) | |||||
| { | |||||
| switch (obj) | |||||
| { | |||||
| case IntPtr val: | |||||
| return val == _handle; | |||||
| case Operation val: | |||||
| return val._handle == _handle; | |||||
| } | |||||
| return base.Equals(obj); | |||||
| /// </summary> | |||||
| public partial class Operation : ITensorOrOperation | |||||
| { | |||||
| private readonly IntPtr _handle; // _c_op in python | |||||
| private readonly IntPtr _operDesc; | |||||
| private Graph _graph; | |||||
| //[JsonIgnore] | |||||
| public Graph graph => _graph; | |||||
| //[JsonIgnore] | |||||
| public int _id => _id_value; | |||||
| //[JsonIgnore] | |||||
| public int _id_value; | |||||
| public string type => OpType; | |||||
| //[JsonIgnore] | |||||
| public Operation op => this; | |||||
| public TF_DataType dtype => TF_DataType.DtInvalid; | |||||
| private Status status = new Status(); | |||||
| public string name => c_api.StringPiece(c_api.TF_OperationName(_handle)); | |||||
| public string OpType => c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | |||||
| public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | |||||
| private NodeDef _node_def; | |||||
| public NodeDef node_def | |||||
| { | |||||
| get | |||||
| { | |||||
| if(_node_def == null) | |||||
| _node_def = GetNodeDef(); | |||||
| return _node_def; | |||||
| } | |||||
| } | |||||
| public Operation(IntPtr handle, Graph g=null) | |||||
| { | |||||
| if (handle == IntPtr.Zero) | |||||
| return; | |||||
| _handle = handle; | |||||
| _graph = g ?? ops.get_default_graph(); | |||||
| _outputs = new Tensor[NumOutputs]; | |||||
| for (int i = 0; i < NumOutputs; i++) | |||||
| _outputs[i] = new Tensor(this, i, OutputType(i)); | |||||
| // Dict mapping op name to file and line information for op colocation | |||||
| // context managers. | |||||
| _control_flow_context = graph._get_control_flow_context(); | |||||
| // Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor. | |||||
| } | |||||
| public Operation(Graph g, string opType, string oper_name) | |||||
| { | |||||
| _graph = g; | |||||
| _operDesc = c_api.TF_NewOperation(g, opType, oper_name); | |||||
| c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32); | |||||
| _handle = c_api.TF_FinishOperation(_operDesc, status); | |||||
| // Dict mapping op name to file and line information for op colocation | |||||
| // context managers. | |||||
| _control_flow_context = graph._get_control_flow_context(); | |||||
| } | } | ||||
| /// <summary> | |||||
| /// Update the input to this operation at the given index. | |||||
| /// | |||||
| /// NOTE: This is for TF internal use only.Please don't use it. | |||||
| /// </summary> | |||||
| /// <param name="index">the index of the input to update.</param> | |||||
| /// <param name="tensor"> the Tensor to be used as the input at the given index.</param> | |||||
| public void _update_input(int index, Tensor tensor) | |||||
| { | |||||
| _assert_same_graph(tensor); | |||||
| var input = _tf_input(index); | |||||
| /// <summary> | |||||
| /// Creates an `Operation`. | |||||
| /// </summary> | |||||
| /// <param name="node_def">`node_def_pb2.NodeDef`. `NodeDef` for the `Operation`.</param> | |||||
| /// <param name="g">`Graph`. The parent graph.</param> | |||||
| /// <param name="inputs">list of `Tensor` objects. The inputs to this `Operation`.</param> | |||||
| /// <param name="output_types">list of `DType` objects.</param> | |||||
| /// <param name="control_inputs"> | |||||
| /// list of operations or tensors from which to have a | |||||
| /// control dependency. | |||||
| /// </param> | |||||
| /// <param name="input_types"> | |||||
| /// List of `DType` objects representing the | |||||
| /// types of the tensors accepted by the `Operation`. By default | |||||
| /// uses `[x.dtype.base_dtype for x in inputs]`. Operations that expect | |||||
| /// reference-typed inputs must specify these explicitly. | |||||
| /// </param> | |||||
| /// <param name="original_op"></param> | |||||
| /// <param name="op_def"></param> | |||||
| public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) | |||||
| { | |||||
| _graph = g; | |||||
| // Build the list of control inputs. | |||||
| var control_input_ops = new List<Operation>(); | |||||
| if(control_inputs != null) | |||||
| { | |||||
| foreach(var c in control_inputs) | |||||
| { | |||||
| switch (c) | |||||
| { | |||||
| case Operation c1: | |||||
| control_input_ops.Add(c1); | |||||
| break; | |||||
| case Tensor tensor: | |||||
| control_input_ops.Add(tensor.op); | |||||
| break; | |||||
| // TODO: IndexedSlices don't yet exist, but once they do, this needs to be uncommented | |||||
| //case IndexedSlices islices: | |||||
| // control_input_ops.Add(islices.op); | |||||
| // break; | |||||
| default: | |||||
| throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}"); | |||||
| } | |||||
| } | |||||
| } | |||||
| // Dict mapping op name to file and line information for op colocation | |||||
| // context managers. | |||||
| _control_flow_context = graph._get_control_flow_context(); | |||||
| // This will be set by self.inputs. | |||||
| if (op_def == null) | |||||
| op_def = g.GetOpDef(node_def.Op); | |||||
| var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | |||||
| (_handle, _operDesc) = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); | |||||
| // Initialize self._outputs. | |||||
| output_types = new TF_DataType[NumOutputs]; | |||||
| for (int i = 0; i < NumOutputs; i++) | |||||
| output_types[i] = OutputType(i); | |||||
| _outputs = new Tensor[NumOutputs]; | |||||
| for (int i = 0; i < NumOutputs; i++) | |||||
| _outputs[i] = new Tensor(this, i, OutputType(i)); | |||||
| graph._add_op(this); | |||||
| if (_handle != IntPtr.Zero) | |||||
| _control_flow_post_processing(); | |||||
| } | |||||
| public void run(FeedItem[] feed_dict = null, Session session = null) | |||||
| { | |||||
| ops._run_using_default_session(this, feed_dict, graph, session); | |||||
| } | |||||
| private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField<string, AttrValue> attrs) | |||||
| { | |||||
| var grouped_inputs = new List<object>(); | |||||
| int i = 0; | |||||
| int input_len = 0; | |||||
| bool is_sequence = false; | |||||
| foreach (var input_arg in op_def.InputArg) | |||||
| { | |||||
| if (!string.IsNullOrEmpty(input_arg.NumberAttr)) | |||||
| { | |||||
| input_len = (int)attrs[input_arg.NumberAttr].I; | |||||
| is_sequence = true; | |||||
| } | |||||
| else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) | |||||
| { | |||||
| input_len = attrs[input_arg.TypeListAttr].List.Type.Count; | |||||
| is_sequence = true; | |||||
| } | |||||
| else | |||||
| { | |||||
| input_len = 1; | |||||
| is_sequence = false; | |||||
| } | |||||
| if (is_sequence) | |||||
| grouped_inputs.Add(inputs.Skip(i).Take(input_len).ToArray()); | |||||
| else | |||||
| grouped_inputs.Add(inputs[i]); | |||||
| i += input_len; | |||||
| } | |||||
| return grouped_inputs.ToArray(); | |||||
| } | |||||
| public object get_attr(string name) | |||||
| { | |||||
| AttrValue x = null; | |||||
| using (var buf = new Buffer()) | |||||
| { | |||||
| c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | |||||
| status.Check(true); | |||||
| x = AttrValue.Parser.ParseFrom(buf); | |||||
| } | |||||
| string oneof_value = x.ValueCase.ToString(); | |||||
| if (string.IsNullOrEmpty(oneof_value)) | |||||
| return null; | |||||
| if(oneof_value == "list") | |||||
| throw new NotImplementedException($"Unsupported field type in {x.ToString()}"); | |||||
| if (oneof_value == "type") | |||||
| return x.Type; | |||||
| object result = x.GetType().GetProperty(oneof_value).GetValue(x); | |||||
| if (result is Google.Protobuf.ByteString byteString) | |||||
| return byteString.ToStringUtf8(); | |||||
| return result; | |||||
| } | |||||
| public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s) | |||||
| { | |||||
| return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s); | |||||
| } | |||||
| private NodeDef GetNodeDef() | |||||
| { | |||||
| using (var s = new Status()) | |||||
| using (var buffer = new Buffer()) | |||||
| { | |||||
| c_api.TF_OperationToNodeDef(_handle, buffer, s); | |||||
| s.Check(); | |||||
| return NodeDef.Parser.ParseFrom(buffer); | |||||
| } | |||||
| } | |||||
| public override string ToString() | |||||
| { | |||||
| return _handle == IntPtr.Zero ? "tf.Operation Undefined" : $"tf.Operation '{name}' type={OpType}"; | |||||
| } | |||||
| public static implicit operator Operation(IntPtr handle) => new Operation(handle); | |||||
| public static implicit operator IntPtr(Operation op) => op._handle; | |||||
| public override bool Equals(object obj) | |||||
| { | |||||
| switch (obj) | |||||
| { | |||||
| case IntPtr val: | |||||
| return val == _handle; | |||||
| case Operation val: | |||||
| return val._handle == _handle; | |||||
| } | |||||
| return base.Equals(obj); | |||||
| } | |||||
| /// <summary> | |||||
| /// Update the input to this operation at the given index. | |||||
| /// | |||||
| /// NOTE: This is for TF internal use only.Please don't use it. | |||||
| /// </summary> | |||||
| /// <param name="index">the index of the input to update.</param> | |||||
| /// <param name="tensor"> the Tensor to be used as the input at the given index.</param> | |||||
| public void _update_input(int index, Tensor tensor) | |||||
| { | |||||
| _assert_same_graph(tensor); | |||||
| var input = _tf_input(index); | |||||
| var output = tensor._as_tf_output(); | var output = tensor._as_tf_output(); | ||||
| // Reset cached inputs. | // Reset cached inputs. | ||||
| _inputs = null; | _inputs = null; | ||||
| // after the c_api call next time _inputs is accessed | // after the c_api call next time _inputs is accessed | ||||
| // the updated inputs are reloaded from the c_api | |||||
| c_api.TF_UpdateEdge(_graph, output, input, status); | |||||
| //var updated_inputs = inputs; | |||||
| } | |||||
| private void _assert_same_graph(Tensor tensor) | |||||
| { | |||||
| //TODO: implement | |||||
| } | |||||
| /// <summary> | |||||
| /// Create and return a new TF_Output for output_idx'th output of this op. | |||||
| /// </summary> | |||||
| public TF_Output _tf_output(int output_idx) | |||||
| { | |||||
| var tf_output = new TF_Output(op, output_idx); | |||||
| return tf_output; | |||||
| } | |||||
| /// <summary> | |||||
| /// Create and return a new TF_Input for input_idx'th input of this op. | |||||
| /// </summary> | |||||
| public TF_Input _tf_input(int input_idx) | |||||
| { | |||||
| var tf_input = new TF_Input(op, input_idx); | |||||
| return tf_input; | |||||
| } | |||||
| } | |||||
| } | |||||
| // the updated inputs are reloaded from the c_api | |||||
| c_api.TF_UpdateEdge(_graph, output, input, status); | |||||
| //var updated_inputs = inputs; | |||||
| } | |||||
| private void _assert_same_graph(Tensor tensor) | |||||
| { | |||||
| //TODO: implement | |||||
| } | |||||
| /// <summary> | |||||
| /// Create and return a new TF_Output for output_idx'th output of this op. | |||||
| /// </summary> | |||||
| public TF_Output _tf_output(int output_idx) | |||||
| { | |||||
| var tf_output = new TF_Output(op, output_idx); | |||||
| return tf_output; | |||||
| } | |||||
| /// <summary> | |||||
| /// Create and return a new TF_Input for input_idx'th input of this op. | |||||
| /// </summary> | |||||
| public TF_Input _tf_input(int input_idx) | |||||
| { | |||||
| var tf_input = new TF_Input(op, input_idx); | |||||
| return tf_input; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -308,7 +308,7 @@ namespace Tensorflow | |||||
| tensor.op.graph.prevent_fetching(tensor.op); | tensor.op.graph.prevent_fetching(tensor.op); | ||||
| // Build the graph for the true branch in a new context. | // Build the graph for the true branch in a new context. | ||||
| var context_t = new CondContext(pred, pivot_1, branch: 1); | |||||
| var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1); | |||||
| ITensorOrOperation orig_res_t; | ITensorOrOperation orig_res_t; | ||||
| Tensor res_t; | Tensor res_t; | ||||
| try | try | ||||
| @@ -321,7 +321,7 @@ namespace Tensorflow | |||||
| context_t.Exit(); | context_t.Exit(); | ||||
| } | } | ||||
| // Build the graph for the false branch in a new context. | // Build the graph for the false branch in a new context. | ||||
| var context_f = new CondContext(pred, pivot_2, branch: 0); | |||||
| var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0); | |||||
| ITensorOrOperation orig_res_f; | ITensorOrOperation orig_res_f; | ||||
| Tensor res_f; | Tensor res_f; | ||||
| try | try | ||||
| @@ -389,13 +389,13 @@ namespace Tensorflow | |||||
| tensor.op.graph.prevent_fetching(tensor.op); | tensor.op.graph.prevent_fetching(tensor.op); | ||||
| // Build the graph for the true branch in a new context. | // Build the graph for the true branch in a new context. | ||||
| var context_t = new CondContext(pred, pivot_1, branch: 1); | |||||
| var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1); | |||||
| context_t.Enter(); | context_t.Enter(); | ||||
| var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); | var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); | ||||
| context_t.Exit(); | context_t.Exit(); | ||||
| // Build the graph for the false branch in a new context. | // Build the graph for the false branch in a new context. | ||||
| var context_f = new CondContext(pred, pivot_2, branch: 0); | |||||
| var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0); | |||||
| context_f.Enter(); | context_f.Enter(); | ||||
| var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); | var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); | ||||
| context_f.Exit(); | context_f.Exit(); | ||||
| @@ -80,7 +80,7 @@ namespace Tensorflow | |||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| } | } | ||||
| public static Tensor add(Tensor x, Tensor y, string name = null) | |||||
| public static Tensor add<Tx, Ty>(Tx x, Ty y, string name = null) | |||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y }); | var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y }); | ||||
| @@ -300,7 +300,7 @@ namespace Tensorflow | |||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| } | } | ||||
| public static Tensor mul(Tensor x, Tensor y, string name = null) | |||||
| public static Tensor mul<Tx, Ty>(Tx x, Ty y, string name = null) | |||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y }); | var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y }); | ||||
| @@ -8,7 +8,7 @@ namespace Tensorflow | |||||
| /// In order for a object to be serialized to and from MetaGraphDef, | /// In order for a object to be serialized to and from MetaGraphDef, | ||||
| /// the class must implement to_proto() and from_proto() methods | /// the class must implement to_proto() and from_proto() methods | ||||
| /// </summary> | /// </summary> | ||||
| public interface IProtoBuf | |||||
| public interface IProtoBuf<TProtoDef, TDef> | |||||
| { | { | ||||
| string name { get; } | string name { get; } | ||||
| @@ -17,15 +17,15 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="export_scope"></param> | /// <param name="export_scope"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| VariableDef to_proto(string export_scope); | |||||
| TProtoDef to_proto(string export_scope); | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns a `Variable` object created from `variable_def`. | /// Returns a `Variable` object created from `variable_def`. | ||||
| /// </summary> | /// </summary> | ||||
| /// <typeparam name="T"></typeparam> | /// <typeparam name="T"></typeparam> | ||||
| /// <param name="variable_def"></param> | |||||
| /// <param name="proto"></param> | |||||
| /// <param name="import_scope"></param> | /// <param name="import_scope"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| T from_proto<T>(VariableDef variable_def, string import_scope); | |||||
| TDef from_proto(TProtoDef proto, string import_scope); | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,10 +1,12 @@ | |||||
| ### Download compiler from https://github.com/protocolbuffers/protobuf/releases | ### Download compiler from https://github.com/protocolbuffers/protobuf/releases | ||||
| Work in command line | |||||
| ```shell | ```shell | ||||
| cd tensorflow | |||||
| set SRC_DIR=D:/Projects/tensorflow | set SRC_DIR=D:/Projects/tensorflow | ||||
| set DST_DIR=D:/Projects/TensorFlow.NET/src/TensorFlowNET.Core/Protobuf | set DST_DIR=D:/Projects/TensorFlow.NET/src/TensorFlowNET.Core/Protobuf | ||||
| cd tensorflow | |||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/resource_handle.proto | protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/resource_handle.proto | ||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/tensor_shape.proto | protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/tensor_shape.proto | ||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/types.proto | protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/types.proto | ||||
| @@ -32,6 +34,7 @@ protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/cluster.prot | |||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/config.proto | protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/config.proto | ||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/debug.proto | protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/debug.proto | ||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/rewriter_config.proto | protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/rewriter_config.proto | ||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/control_flow.proto | |||||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/python/training/checkpoint_state.proto | protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/python/training/checkpoint_state.proto | ||||
| ``` | ``` | ||||
| @@ -7,7 +7,7 @@ using System.Text; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public partial class RefVariable : VariableV1, IProtoBuf | |||||
| public partial class RefVariable : VariableV1, IProtoBuf<VariableDef, RefVariable> | |||||
| { | { | ||||
| public bool _in_graph_mode = true; | public bool _in_graph_mode = true; | ||||
| public Tensor _initial_value; | public Tensor _initial_value; | ||||
| @@ -288,7 +288,7 @@ namespace Tensorflow | |||||
| throw new NotImplementedException("to_proto RefVariable"); | throw new NotImplementedException("to_proto RefVariable"); | ||||
| } | } | ||||
| public T from_proto<T>(VariableDef variable_def, string import_scope) | |||||
| public RefVariable from_proto(VariableDef proto, string import_scope) | |||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| } | } | ||||
| @@ -376,7 +376,7 @@ namespace Tensorflow | |||||
| if (import_scope.EndsWith("/")) | if (import_scope.EndsWith("/")) | ||||
| import_scope = import_scope.Substring(0, import_scope.Length - 1); | import_scope = import_scope.Substring(0, import_scope.Length - 1); | ||||
| throw new NotImplementedException("prepend_name_scope"); | |||||
| return $"{import_scope}/{name}"; | |||||
| } | } | ||||
| else | else | ||||
| return name; | return name; | ||||
| @@ -132,10 +132,11 @@ namespace TensorFlowNET.UnitTest | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| /// Evaluates tensors and returns a dictionary of {name:result, ...}. | |||||
| /// <param name="tensors">A Tensor or a nested list/tuple of Tensors.</param> | |||||
| /// This function is used in many original tensorflow unit tests to evaluate tensors | |||||
| /// in a test session with special settings (for instance constant folding off) | |||||
| /// | |||||
| /// </summary> | /// </summary> | ||||
| public Dictionary<string, NDArray> evaluate(params Tensor[] tensors) | |||||
| public T evaluate<T>(Tensor tensor) | |||||
| { | { | ||||
| var results = new Dictionary<string, NDArray>(); | var results = new Dictionary<string, NDArray>(); | ||||
| // if context.executing_eagerly(): | // if context.executing_eagerly(): | ||||
| @@ -145,49 +146,26 @@ namespace TensorFlowNET.UnitTest | |||||
| var sess = ops.get_default_session(); | var sess = ops.get_default_session(); | ||||
| if (sess == null) | if (sess == null) | ||||
| sess = self.session(); | sess = self.session(); | ||||
| with<Session>(sess, s => | |||||
| { | |||||
| foreach (var t in tensors) | |||||
| results[t.name] = t.eval(); | |||||
| }); | |||||
| return results; | |||||
| } | |||||
| } | |||||
| public NDArray evaluate(Tensor tensor) | |||||
| { | |||||
| NDArray result = null; | |||||
| // if context.executing_eagerly(): | |||||
| // return self._eval_helper(tensors) | |||||
| // else: | |||||
| { | |||||
| var sess = ops.get_default_session(); | |||||
| if (sess == null) | |||||
| sess = self.session(); | |||||
| with<Session>(sess, s => | |||||
| { | |||||
| result = tensor.eval(); | |||||
| }); | |||||
| return result; | |||||
| } | |||||
| } | |||||
| public object eval_scalar(Tensor tensor) | |||||
| { | |||||
| NDArray result = null; | |||||
| // if context.executing_eagerly(): | |||||
| // return self._eval_helper(tensors) | |||||
| // else: | |||||
| { | |||||
| var sess = ops.get_default_session(); | |||||
| if (sess == null) | |||||
| sess = self.session(); | |||||
| T t_result = (T)(object)null; | |||||
| with<Session>(sess, s => | with<Session>(sess, s => | ||||
| { | { | ||||
| result = tensor.eval(); | |||||
| var ndarray=tensor.eval(); | |||||
| if (typeof(T) == typeof(double)) | |||||
| { | |||||
| double d = ndarray; | |||||
| t_result = (T)(object)d; | |||||
| } | |||||
| else if (typeof(T) == typeof(int)) | |||||
| { | |||||
| int d = ndarray; | |||||
| t_result = (T) (object) d; | |||||
| } | |||||
| else | |||||
| { | |||||
| t_result = (T)(object)ndarray; | |||||
| } | |||||
| }); | }); | ||||
| return result.Array.GetValue(0); | |||||
| return t_result; | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,4 +1,5 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using System; | |||||
| using Tensorflow; | using Tensorflow; | ||||
| namespace TensorFlowNET.UnitTest.control_flow_ops_test | namespace TensorFlowNET.UnitTest.control_flow_ops_test | ||||
| @@ -9,32 +10,73 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||||
| [TestClass] | [TestClass] | ||||
| public class CondTestCases : PythonTest | public class CondTestCases : PythonTest | ||||
| { | { | ||||
| [TestMethod] | [TestMethod] | ||||
| public void testCondTrue() | public void testCondTrue() | ||||
| { | { | ||||
| with(tf.Graph().as_default(), g => | |||||
| var graph = tf.Graph().as_default(); | |||||
| with(tf.Session(graph), sess => | |||||
| { | { | ||||
| var x = tf.constant(2); | var x = tf.constant(2); | ||||
| var y = tf.constant(5); | var y = tf.constant(5); | ||||
| var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)), | |||||
| () => tf.add(y, tf.constant(23))); | |||||
| //tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false); | |||||
| self.assertEquals(eval_scalar(z), 34); | |||||
| var pred = tf.less(x, y); | |||||
| Func<ITensorOrOperation> if_true = delegate | |||||
| { | |||||
| return tf.multiply(x, 17); | |||||
| }; | |||||
| Func<ITensorOrOperation> if_false = delegate | |||||
| { | |||||
| return tf.add(y, 23); | |||||
| }; | |||||
| var z = control_flow_ops.cond(pred, if_true, if_false); | |||||
| int result = z.eval(sess); | |||||
| assertEquals(result, 34); | |||||
| }); | }); | ||||
| } | } | ||||
| //[Ignore("This Test Fails due to missing edges in the graph!")] | |||||
| [TestMethod] | [TestMethod] | ||||
| public void testCondFalse() | public void testCondFalse() | ||||
| { | { | ||||
| with(tf.Graph().as_default(), g => | |||||
| /* python | |||||
| * import tensorflow as tf | |||||
| from tensorflow.python.framework import ops | |||||
| def if_true(): | |||||
| return tf.math.multiply(x, 17) | |||||
| def if_false(): | |||||
| return tf.math.add(y, 23) | |||||
| with tf.Session() as sess: | |||||
| x = tf.constant(2) | |||||
| y = tf.constant(1) | |||||
| pred = tf.math.less(x,y) | |||||
| z = tf.cond(pred, if_true, if_false) | |||||
| result = z.eval() | |||||
| print(result == 24) */ | |||||
| with(tf.Session(), sess => | |||||
| { | { | ||||
| var x = tf.constant(2); | var x = tf.constant(2); | ||||
| var y = tf.constant(1); | var y = tf.constant(1); | ||||
| var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)), | |||||
| () => tf.add(y, tf.constant(23))); | |||||
| self.assertEquals(eval_scalar(z), 24); | |||||
| var pred = tf.less(x, y); | |||||
| Func<ITensorOrOperation> if_true = delegate | |||||
| { | |||||
| return tf.multiply(x, 17); | |||||
| }; | |||||
| Func<ITensorOrOperation> if_false = delegate | |||||
| { | |||||
| return tf.add(y, 23); | |||||
| }; | |||||
| var z = control_flow_ops.cond(pred, if_true, if_false); | |||||
| int result = z.eval(sess); | |||||
| assertEquals(result, 24); | |||||
| }); | }); | ||||
| } | } | ||||
| @@ -162,7 +162,7 @@ namespace TensorFlowNET.UnitTest.ops_test | |||||
| { | { | ||||
| var z1 = tf.add(a_3, tf.multiply(a_4, a_2)); | var z1 = tf.add(a_3, tf.multiply(a_4, a_2)); | ||||
| }); | }); | ||||
| tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false); | |||||
| //tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false); | |||||
| assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op, a_2.op, a_3.op, a_4.op }); | assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op, a_2.op, a_3.op, a_4.op }); | ||||
| assertItemsEqual(b_2.op.control_inputs, b_1.op.control_inputs); | assertItemsEqual(b_2.op.control_inputs, b_1.op.control_inputs); | ||||
| } | } | ||||