From ded16ea82f492de35892d308f1accce735ad05e3 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 27 Oct 2019 13:34:59 -0500 Subject: [PATCH] BatchNormalization return tuple for call --- src/TensorFlowNET.Core/APIs/tf.layers.cs | 8 ++--- src/TensorFlowNET.Core/APIs/tf.nn.cs | 2 +- .../Graphs/Graph.Control.cs | 1 + src/TensorFlowNET.Core/Graphs/Graph.cs | 14 ++++---- .../Graphs/_ControlDependenciesController.cs | 1 + .../Interfaces/IPackable.cs | 4 +-- .../Keras/Layers/BatchNormalization.cs | 4 +-- src/TensorFlowNET.Core/Keras/Layers/Conv.cs | 6 ++-- src/TensorFlowNET.Core/Keras/Layers/Dense.cs | 6 ++-- .../Keras/Layers/Embedding.cs | 4 +-- src/TensorFlowNET.Core/Keras/Layers/Layer.cs | 10 +++--- .../Keras/Layers/Pooling2D.cs | 4 +-- src/TensorFlowNET.Core/Layers/Layer.cs | 6 ++-- .../Operations/BasicRNNCell.cs | 33 +++++++++++++++++-- .../ControlFlows/ControlFlowContext.cs | 21 ------------ .../Operations/ControlFlows/LoopVar.cs | 12 ++++--- .../Operations/ControlFlows/WhileContext.cs | 23 ++++++++++--- .../Operations/LayerRNNCell.cs | 4 +-- .../Operations/NnOps/rnn_cell_impl.cs | 4 +-- .../Operations/OpDefLibrary.cs | 14 ++++++++ .../Operations/Operation.Control.cs | 1 + .../Operations/Operation.Input.cs | 10 +++--- .../Operations/Operation.cs | 6 ++-- .../Operations/control_flow_ops.cs | 31 +++++++++++++++++ .../Operations/gen_math_ops.cs | 2 +- src/TensorFlowNET.Core/Operations/math_ops.cs | 17 ++++++++++ src/TensorFlowNET.Core/Util/nest.py.cs | 7 ++-- src/TensorFlowNET.Core/ops.cs | 2 ++ src/TensorFlowNET.Core/ops.name_scope.cs | 3 ++ 29 files changed, 176 insertions(+), 84 deletions(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.layers.cs b/src/TensorFlowNET.Core/APIs/tf.layers.cs index 9f989bc5..25448441 100644 --- a/src/TensorFlowNET.Core/APIs/tf.layers.cs +++ b/src/TensorFlowNET.Core/APIs/tf.layers.cs @@ -63,7 +63,7 @@ namespace Tensorflow trainable: trainable, name: name); - return layer.apply(inputs); + return layer.apply(inputs).Item1; } /// @@ -117,7 +117,7 @@ namespace Tensorflow trainable: trainable, name: name); - return layer.apply(inputs, training: training); + return layer.apply(inputs, training: training).Item1; } /// @@ -143,7 +143,7 @@ namespace Tensorflow data_format: data_format, name: name); - return layer.apply(inputs); + return layer.apply(inputs).Item1; } /// @@ -179,7 +179,7 @@ namespace Tensorflow kernel_initializer: kernel_initializer, trainable: trainable); - return layer.apply(inputs); + return layer.apply(inputs).Item1; } /// diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs index e9805010..5b5786d1 100644 --- a/src/TensorFlowNET.Core/APIs/tf.nn.cs +++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs @@ -76,7 +76,7 @@ namespace Tensorflow /// /// /// A pair (outputs, state) - public (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs, + public (Tensor, Tensor) dynamic_rnn(RnnCell cell, Tensor inputs, Tensor sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid, int? parallel_iterations = null, bool swap_memory = false, bool time_major = false) => rnn.dynamic_rnn(cell, inputs, sequence_length: sequence_length, dtype: dtype, diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs index c97e1b6f..c6a5dee0 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs @@ -18,6 +18,7 @@ using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Linq; using Tensorflow.Operations; +using static Tensorflow.Binding; namespace Tensorflow { diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 87a1424f..c9ad6402 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -262,15 +262,11 @@ namespace Tensorflow if (string.IsNullOrEmpty(name)) name = op_type; + // If a names ends with a '/' it is a "name scope" and we use it as-is, // after removing the trailing '/'. name = name.EndsWith("/") ? ops.name_from_scope_name(name) : unique_name(name); var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); - - if (name.Contains("define_loss/bigger_box_loss/mul_13")) - { - - } var input_ops = inputs.Select(x => x.op).ToArray(); var control_inputs = _control_dependencies_for_inputs(input_ops); @@ -377,7 +373,11 @@ namespace Tensorflow /// A string to be passed to `create_op()` that will be used /// to name the operation being created. public string unique_name(string name, bool mark_as_used = true) - { + { + if (name.EndsWith("basic_r_n_n_cell")) + { + + } if (!String.IsNullOrEmpty(_name_stack)) name = _name_stack + "/" + name; // For the sake of checking for names in use, we treat names as case @@ -405,7 +405,7 @@ namespace Tensorflow // Return the new name with the original capitalization of the given name. name = $"{name}_{i-1}"; - } + } return name; } diff --git a/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs index 6a75c982..55e321df 100644 --- a/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs +++ b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs @@ -16,6 +16,7 @@ using System.Collections.Generic; using Tensorflow.Operations; +using static Tensorflow.Binding; namespace Tensorflow { diff --git a/src/TensorFlowNET.Core/Interfaces/IPackable.cs b/src/TensorFlowNET.Core/Interfaces/IPackable.cs index 86ceabc7..94e31ece 100644 --- a/src/TensorFlowNET.Core/Interfaces/IPackable.cs +++ b/src/TensorFlowNET.Core/Interfaces/IPackable.cs @@ -4,8 +4,8 @@ using System.Text; namespace Tensorflow { - public interface IPackable + public interface IPackable { - void Pack(object[] sequences); + T Pack(object[] sequences); } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs index 0428b2ad..57311e8b 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs @@ -139,14 +139,14 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensor call(Tensor inputs, Tensor training = null) + protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null) { Tensor outputs = null; if (fused) { outputs = _fused_batch_norm(inputs, training: training); - return outputs; + return (outputs, outputs); } throw new NotImplementedException("BatchNormalization call"); diff --git a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs index dc40ae8c..6a7c58cc 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs @@ -108,7 +108,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensor call(Tensor inputs, Tensor training = null) + protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null) { var outputs = _convolution_op.__call__(inputs, kernel); if (use_bias) @@ -124,9 +124,9 @@ namespace Tensorflow.Keras.Layers } if (activation != null) - return activation.Activate(outputs); + outputs = activation.Activate(outputs); - return outputs; + return (outputs, outputs); } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs index 2564da6d..212035cb 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs @@ -72,7 +72,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensor call(Tensor inputs, Tensor training = null) + protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null) { Tensor outputs = null; var rank = inputs.rank; @@ -88,9 +88,9 @@ namespace Tensorflow.Keras.Layers if (use_bias) outputs = tf.nn.bias_add(outputs, bias); if (activation != null) - return activation.Activate(outputs); + outputs = activation.Activate(outputs); - return outputs; + return (outputs, outputs); } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs index f10499c4..f15c01b8 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs @@ -50,14 +50,14 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensor call(Tensor inputs, Tensor training = null) + protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null) { var dtype = inputs.dtype; if (dtype != tf.int32 && dtype != tf.int64) inputs = math_ops.cast(inputs, tf.int32); var @out = embedding_ops.embedding_lookup(embeddings, inputs); - return @out; + return (@out, @out); } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs index 25161721..46d45862 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs @@ -101,7 +101,7 @@ namespace Tensorflow.Keras.Layers _inbound_nodes = new List(); } - public Tensor __call__(Tensor[] inputs, + public (Tensor, Tensor) __call__(Tensor[] inputs, Tensor training = null, VariableScope scope = null) { @@ -139,14 +139,14 @@ namespace Tensorflow.Keras.Layers // overridden). _maybe_build(inputs[0]); - outputs = call(inputs[0], training: training); + (input, outputs) = call(inputs[0], training: training); (input, outputs) = _set_connectivity_metadata_(input, outputs); _handle_activity_regularization(inputs[0], outputs); _set_mask_metadata(inputs[0], outputs, null); }); } - return outputs; + return (input, outputs); } private (Tensor, Tensor) _set_connectivity_metadata_(Tensor inputs, Tensor outputs) @@ -173,9 +173,9 @@ namespace Tensorflow.Keras.Layers return null; } - protected virtual Tensor call(Tensor inputs, Tensor training = null) + protected virtual (Tensor, Tensor) call(Tensor inputs, Tensor training = null) { - return inputs; + return (inputs, inputs); } protected virtual string _name_scope() diff --git a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs index 9774750a..e9008543 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs @@ -43,7 +43,7 @@ namespace Tensorflow.Keras.Layers this.input_spec = new InputSpec(ndim: 4); } - protected override Tensor call(Tensor inputs, Tensor training = null) + protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null) { int[] pool_shape; if (data_format == "channels_last") @@ -64,7 +64,7 @@ namespace Tensorflow.Keras.Layers padding: padding.ToUpper(), data_format: conv_utils.convert_data_format(data_format, 4)); - return outputs; + return (outputs, outputs); } } } diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs index 138f0fc7..2ea427c3 100644 --- a/src/TensorFlowNET.Core/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Layers/Layer.cs @@ -47,12 +47,12 @@ namespace Tensorflow.Layers _keras_style = false; } - public virtual Tensor apply(Tensor inputs, Tensor training = null) + public virtual (Tensor, Tensor) apply(Tensor inputs, Tensor training = null) { return __call__(inputs, training: training); } - public Tensor __call__(Tensor inputs, + public (Tensor, Tensor) __call__(Tensor inputs, Tensor training = null, VariableScope scope = null) { @@ -71,7 +71,7 @@ namespace Tensorflow.Layers auxiliary_name_scope: false); } - Tensor outputs = null; + (Tensor, Tensor) outputs = (null, null); tf_with(scope_context_manager, scope2 => { _current_scope = scope2; diff --git a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs b/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs index 554e9f1a..9911212b 100644 --- a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs @@ -16,18 +16,23 @@ using System; using Tensorflow.Keras.Engine; +using static Tensorflow.Binding; namespace Tensorflow { - public class BasicRNNCell : LayerRNNCell + public class BasicRnnCell : LayerRnnCell { int _num_units; Func _activation; public override int state_size => _num_units; public override int output_size => _num_units; + public VariableV1 _kernel; + string _WEIGHTS_VARIABLE_NAME = "kernel"; + public VariableV1 _bias; + string _BIAS_VARIABLE_NAME = "bias"; - public BasicRNNCell(int num_units, + public BasicRnnCell(int num_units, Func activation = null, bool? reuse = null, string name = null, @@ -44,5 +49,29 @@ namespace Tensorflow else _activation = activation; } + + protected override void build(TensorShape inputs_shape) + { + var input_depth = inputs_shape.dims[inputs_shape.ndim - 1]; + + _kernel = add_weight( + _WEIGHTS_VARIABLE_NAME, + shape: new[] { input_depth + _num_units, _num_units }); + + _bias = add_weight( + _BIAS_VARIABLE_NAME, + shape: new[] { _num_units }, + initializer: tf.zeros_initializer); + + built = true; + } + + protected override (Tensor, Tensor) call(Tensor inputs, Tensor state = null) + { + // Most basic RNN: output = new_state = act(W * input + U * state + B). + var concat = array_ops.concat(new[] { inputs, state }, 1); + var gate_inputs = math_ops.matmul(concat, _kernel as RefVariable); + return (inputs, inputs); + } } } diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs index 953fd6c7..97f244c4 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs @@ -136,27 +136,6 @@ namespace Tensorflow.Operations graph._set_control_flow_context(this); } - protected virtual Tensor _Enter(Tensor data, string frame_name, - bool is_constant = false, - int parallel_iterations = 10, - bool use_ref = true, - bool use_input_shape = true, - string name = null) - { - Tensor result; - data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref: true); - if (data.dtype.is_ref_dtype() && use_ref) - throw new NotImplementedException("_Enter"); - else - result = gen_control_flow_ops.enter( - data, frame_name, is_constant, parallel_iterations, name: name); - - if (use_input_shape) - result.set_shape(data.TensorShape); - - return result; - } - /// /// Exit this control flow context. /// diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs index 845ff494..5359190c 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs @@ -5,7 +5,7 @@ using System.Text; namespace Tensorflow.Operations { - internal class LoopVar : ICanBeFlattened, IPackable + internal class LoopVar : ICanBeFlattened, IPackable> { public Tensor Counter { get; set; } public TItem Item { get; set; } @@ -26,11 +26,13 @@ namespace Tensorflow.Operations return elements.ToArray(); } - public void Pack(object[] sequences) + public LoopVar Pack(object[] sequences) { - Counter = sequences[0] as Tensor; - if (typeof(TItem).GetInterface(typeof(IPackable).Name) != null) - (Item as IPackable).Pack(sequences.Skip(1).ToArray()); + var counter = sequences[0] as Tensor; + var item = default(TItem); + if (typeof(TItem).GetInterface(typeof(IPackable).Name) != null) + item = (Item as IPackable).Pack(sequences.Skip(1).ToArray()); + return new LoopVar(counter, item); } public static implicit operator (Tensor, TItem)(LoopVar loopVar) diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs index 55802fe7..1d071856 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs @@ -184,7 +184,7 @@ namespace Tensorflow.Operations Tensor[] enter_vars = null; tf_with(ops.control_dependencies(null), delegate { - enter_vars = real_vars.Select(x => _Enter(x, + enter_vars = real_vars.Select(x => control_flow_ops._Enter(x, _name, is_constant: false, parallel_iterations: _parallel_iterations, @@ -294,6 +294,10 @@ namespace Tensorflow.Operations } } + /// + /// Makes the values known to this context. + /// + /// private void _InitializeValues(Tensor[] values) { _values = new HashSet(); @@ -303,8 +307,14 @@ namespace Tensorflow.Operations protected override void _AddOpInternal(Operation op) { + if(op.name == "rnn/while/basic_rnn_cell/MatMul" || + op.name == "rnn/while/TensorArrayReadV3") + { + + } + Operation[] external_inputs = new Operation[0]; - if (op == null) + if (op.inputs.Length == 0) { throw new NotImplementedException(""); } @@ -374,6 +384,11 @@ namespace Tensorflow.Operations _AddOpInternal(op); } + /// + /// Add `val` to the current context and its outer context recursively. + /// + /// + /// public override Tensor AddValue(Tensor val) { var result = val; @@ -403,9 +418,9 @@ namespace Tensorflow.Operations // Create an Enter to make `result` known to this loop context. Tensor enter = null; - tf_with(ops.control_dependencies(new ITensorOrOperation[0]), delegate + tf_with(ops.control_dependencies(null), delegate { - enter = _Enter( + enter = control_flow_ops._Enter( result, _name, is_constant: true, diff --git a/src/TensorFlowNET.Core/Operations/LayerRNNCell.cs b/src/TensorFlowNET.Core/Operations/LayerRNNCell.cs index ca9c31bb..16aa147c 100644 --- a/src/TensorFlowNET.Core/Operations/LayerRNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/LayerRNNCell.cs @@ -16,9 +16,9 @@ namespace Tensorflow { - public class LayerRNNCell : RNNCell + public class LayerRnnCell : RnnCell { - public LayerRNNCell(bool? _reuse = null, + public LayerRnnCell(bool? _reuse = null, string name = null, TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: _reuse, name: name, diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs index c3d9cbdf..3164ba14 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs @@ -20,8 +20,8 @@ namespace Tensorflow.Operations { public class rnn_cell_impl { - public BasicRNNCell BasicRNNCell(int num_units) - => new BasicRNNCell(num_units); + public BasicRnnCell BasicRNNCell(int num_units) + => new BasicRnnCell(num_units); public static Tensor _concat(Tensor prefix, int suffix, bool @static = false) { diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index 89ddebdb..5700ccdd 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -228,6 +228,15 @@ namespace Tensorflow output_types.AddRange(types); } + // We add an explicit colocation constraint between + // the newly created op and any of its reference-typed inputs. + var must_colocate_inputs = zip(op_def.InputArg, inputs) + .Where(x => x.Item1.IsRef) + .Select(x => x.Item2) + .ToArray(); + + _MaybeColocateWith(must_colocate_inputs); + // Add Op to graph var op = g.create_op(op_type_name, inputs.ToArray(), @@ -241,6 +250,11 @@ namespace Tensorflow }); } + private void _MaybeColocateWith(ITensorOrOperation[] inputs) + { + + } + private void SetAttrs(string op_type_name, ArgDef input_arg, OpDef op_def, diff --git a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs index 5e93cfd0..8e660797 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Control.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs @@ -15,6 +15,7 @@ ******************************************************************************/ using Tensorflow.Operations; +using static Tensorflow.Binding; namespace Tensorflow { diff --git a/src/TensorFlowNET.Core/Operations/Operation.Input.cs b/src/TensorFlowNET.Core/Operations/Operation.Input.cs index 57ac8271..af3c57b2 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Input.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Input.cs @@ -44,14 +44,14 @@ namespace Tensorflow [JsonIgnore] #endif public int NumInputs => c_api.TF_OperationNumInputs(_handle); - private TF_DataType[] _input_types => _inputs._inputs.Select(x => x.dtype).ToArray(); + private TF_DataType[] _input_types => _inputs_val._inputs.Select(x => x.dtype).ToArray(); - private InputList _inputs; + private InputList _inputs_val; public InputList inputs { get { - if (_inputs == null) + if (_inputs_val == null) { var retval = new Tensor[NumInputs]; @@ -62,10 +62,10 @@ namespace Tensorflow retval[i] = op.outputs[tf_output.index]; } - _inputs = new InputList(retval); + _inputs_val = new InputList(retval); } - return _inputs; + return _inputs_val; } } diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index db001e51..8bdaaa7b 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -175,8 +175,8 @@ namespace Tensorflow // Dict mapping op name to file and line information for op colocation // context managers. - _control_flow_context = graph._get_control_flow_context(); - + _control_flow_context = graph._get_control_flow_context(); + // This will be set by self.inputs. if (op_def == null) op_def = g.GetOpDef(node_def.Op); @@ -305,7 +305,7 @@ namespace Tensorflow var output = tensor._as_tf_output(); // Reset cached inputs. - _inputs = null; + _inputs_val = null; // after the c_api call next time _inputs is accessed // the updated inputs are reloaded from the c_api lock (Locks.ProcessWide) diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs index 1229c6b7..13182dfd 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs @@ -675,5 +675,36 @@ namespace Tensorflow throw new NotImplementedException("while_loop"); } + /// + /// Creates or finds a child frame, and makes `data` available to it. + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor _Enter(Tensor data, string frame_name, + bool is_constant = false, + int parallel_iterations = 10, + bool use_ref = true, + bool use_input_shape = true, + string name = null) + { + Tensor result; + data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref: true); + if (data.dtype.is_ref_dtype() && use_ref) + throw new NotImplementedException("_Enter"); + else + result = gen_control_flow_ops.enter( + data, frame_name, is_constant, parallel_iterations, name: name); + + if (use_input_shape) + result.set_shape(data.TensorShape); + + return result; + } } } diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index e1225cc9..08431089 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -568,7 +568,7 @@ namespace Tensorflow { var _op = _op_def_lib._apply_op_helper("MatMul", name, args: new { a, b, transpose_a, transpose_b }); - return _op.outputs[0]; + return _op.output; } /// diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index d4dfc12b..17cd8a99 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -543,6 +543,23 @@ namespace Tensorflow public static Tensor maximum(Tx x, Ty y, string name = null) => gen_math_ops.maximum(x, y, name: name); + /// + /// Multiplies matrix `a` by matrix `b`, producing `a` * `b`. + /// + /// + /// + /// If `True`, `a` is transposed before multiplication. + /// If `True`, `b` is transposed before multiplication. + /// If `True`, `a` is conjugated and transposed before multiplication. + /// If `True`, `b` is conjugated and transposed before multiplication. + /// If `True`, `a` is treated as a sparse matrix. + /// If `True`, `b` is treated as a sparse matrix. + /// Name for the operation (optional). + /// + /// A `Tensor` of the same type as `a` and `b` where each inner-most matrix is + /// the product of the corresponding matrices in `a` and `b`, e.g. if all + /// transpose or adjoint attributes are `False`: + /// public static Tensor matmul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false, bool adjoint_a = false, bool adjoint_b = false, diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs index 97980203..54149fe1 100644 --- a/src/TensorFlowNET.Core/Util/nest.py.cs +++ b/src/TensorFlowNET.Core/Util/nest.py.cs @@ -402,11 +402,8 @@ namespace Tensorflow.Util private static int len(IEnumerable x) => x.Count(); public static T pack_sequence_as2(T structure, object[] flat_sequence, bool expand_composites = false) - where T : IPackable - { - structure.Pack(flat_sequence); - return structure; - } + where T : IPackable + => structure.Pack(flat_sequence); /// /// Returns a given flattened sequence packed into a given structure. diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index d1e423c9..3549b07e 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -508,6 +508,8 @@ namespace Tensorflow return null; case TensorShape ts: return constant_op.constant(ts.dims, dtype: dtype, name: name); + case int[] dims: + return constant_op.constant(dims, dtype: dtype, name: name); case object[] objects: return array_ops._autopacking_conversion_function(objects, dtype: dtype, name: name); default: diff --git a/src/TensorFlowNET.Core/ops.name_scope.cs b/src/TensorFlowNET.Core/ops.name_scope.cs index bd98f2ca..80397667 100644 --- a/src/TensorFlowNET.Core/ops.name_scope.cs +++ b/src/TensorFlowNET.Core/ops.name_scope.cs @@ -45,7 +45,10 @@ namespace Tensorflow public void __enter__() { _name = _name ?? _default_name; + if (_name.EndsWith("basic_r_n_n_cell")) + { + } Graph g = null; if (_values is List vList)