diff --git a/src/TensorFlowNET.Core/Device/c_api.device.cs b/src/TensorFlowNET.Core/Device/c_api.device.cs new file mode 100644 index 00000000..2ce79a3e --- /dev/null +++ b/src/TensorFlowNET.Core/Device/c_api.device.cs @@ -0,0 +1,32 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Runtime.InteropServices; + +namespace Tensorflow +{ + public partial class c_api + { + /// + /// Specify the device for `desc`. Defaults to empty, meaning unconstrained. + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern void TF_SetDevice(IntPtr desc, string device); + } +} diff --git a/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs index 55e321df..63285bae 100644 --- a/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs +++ b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs @@ -69,7 +69,9 @@ namespace Tensorflow _new_stack = false; } - _seen_nodes = new List(); + _seen_nodes = new List(); + _old_stack = null; + _old_control_flow_context = null; } public void add_op(ITensorOrOperation op) diff --git a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs index 57311e8b..9b42eaaa 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs @@ -139,7 +139,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null) + protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) { Tensor outputs = null; diff --git a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs index 6a7c58cc..ad233d6b 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs @@ -108,7 +108,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null) + protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) { var outputs = _convolution_op.__call__(inputs, kernel); if (use_bias) diff --git a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs index 212035cb..74778873 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs @@ -72,7 +72,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null) + protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) { Tensor outputs = null; var rank = inputs.rank; diff --git a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs index f15c01b8..95544d36 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs @@ -50,7 +50,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null) + protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) { var dtype = inputs.dtype; if (dtype != tf.int32 && dtype != tf.int64) diff --git a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs index 46d45862..d7d7e31a 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs @@ -52,6 +52,7 @@ namespace Tensorflow.Keras.Layers protected InputSpec input_spec; protected bool supports_masking; protected List _trainable_weights; + protected List _non_trainable_weights; private string _name; public string name => _name; protected string _base_name; @@ -84,6 +85,7 @@ namespace Tensorflow.Keras.Layers _init_set_name(name); _trainable_weights = new List(); + _non_trainable_weights = new List(); _compute_previous_mask = false; _updates = new List(); @@ -103,6 +105,7 @@ namespace Tensorflow.Keras.Layers public (Tensor, Tensor) __call__(Tensor[] inputs, Tensor training = null, + Tensor state = null, VariableScope scope = null) { var input_list = inputs; @@ -139,7 +142,9 @@ namespace Tensorflow.Keras.Layers // overridden). _maybe_build(inputs[0]); - (input, outputs) = call(inputs[0], training: training); + (input, outputs) = call(inputs[0], + training: training, + state: state); (input, outputs) = _set_connectivity_metadata_(input, outputs); _handle_activity_regularization(inputs[0], outputs); _set_mask_metadata(inputs[0], outputs, null); @@ -173,7 +178,7 @@ namespace Tensorflow.Keras.Layers return null; } - protected virtual (Tensor, Tensor) call(Tensor inputs, Tensor training = null) + protected virtual (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) { return (inputs, inputs); } @@ -233,7 +238,10 @@ namespace Tensorflow.Keras.Layers initializer: initializer, trainable: trainable.Value); //backend.track_variable(variable); - _trainable_weights.Add(variable); + if (trainable == true) + _trainable_weights.Add(variable); + else + _non_trainable_weights.Add(variable); return variable; } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs index e9008543..81d57abe 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs @@ -43,7 +43,7 @@ namespace Tensorflow.Keras.Layers this.input_spec = new InputSpec(ndim: 4); } - protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null) + protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) { int[] pool_shape; if (data_format == "channels_last") diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs index 2ea427c3..d7cda786 100644 --- a/src/TensorFlowNET.Core/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Layers/Layer.cs @@ -43,6 +43,7 @@ namespace Tensorflow.Layers // Avoid an incorrect lint error _trainable_weights = new List(); + _non_trainable_weights = new List(); this.built = false; _keras_style = false; } @@ -54,6 +55,7 @@ namespace Tensorflow.Layers public (Tensor, Tensor) __call__(Tensor inputs, Tensor training = null, + Tensor state = null, VariableScope scope = null) { _set_scope(scope); @@ -76,7 +78,9 @@ namespace Tensorflow.Layers { _current_scope = scope2; // Actually call layer - outputs = base.__call__(new Tensor[] { inputs }, training: training); + outputs = base.__call__(new Tensor[] { inputs }, + training: training, + state: state); }); @@ -121,6 +125,11 @@ namespace Tensorflow.Layers Graph init_graph = null; VariableV1[] existing_variables = null; + if (synchronization == VariableSynchronization.OnRead) + trainable = false; + else if (!trainable.HasValue) + trainable = true; + if (default_graph.building_function) { throw new NotImplementedException("add_weight"); diff --git a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs b/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs index 9911212b..fdcc03ea 100644 --- a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs @@ -66,7 +66,7 @@ namespace Tensorflow built = true; } - protected override (Tensor, Tensor) call(Tensor inputs, Tensor state = null) + protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) { // Most basic RNN: output = new_state = act(W * input + U * state + B). var concat = array_ops.concat(new[] { inputs, state }, 1); diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs index 1d071856..715c68c6 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs @@ -307,12 +307,6 @@ namespace Tensorflow.Operations protected override void _AddOpInternal(Operation op) { - if(op.name == "rnn/while/basic_rnn_cell/MatMul" || - op.name == "rnn/while/TensorArrayReadV3") - { - - } - Operation[] external_inputs = new Operation[0]; if (op.inputs.Length == 0) { @@ -412,10 +406,12 @@ namespace Tensorflow.Operations } if (_outer_context != null) - { result = _outer_context.AddValue(val); - } + if (tf.get_default_graph()._nodes_by_name.Count >= 83) + { + + } // Create an Enter to make `result` known to this loop context. Tensor enter = null; tf_with(ops.control_dependencies(null), delegate diff --git a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs index e1ac0204..636b1451 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs @@ -16,6 +16,7 @@ using System; using System.Linq; +using static Tensorflow.Binding; namespace Tensorflow.Operations.Initializers { diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs index 41a4622a..a8a0e0b9 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs @@ -214,7 +214,7 @@ namespace Tensorflow.Operations if (sequence_length != null) throw new NotImplementedException("sequence_length != null"); else - a = cell.__call__(input_t_t, state1); + a = cell.__call__(input_t_t, state: state1); return item; }; diff --git a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs index 8e660797..9f0cb9a5 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Control.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs @@ -32,9 +32,7 @@ namespace Tensorflow public void _control_flow_post_processing() { foreach(Tensor input_tensor in inputs) - { control_flow_util.CheckInputFromValidContext(this, input_tensor.op); - } if (_control_flow_context != null) _control_flow_context.AddOp(this); diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 8bdaaa7b..d5068f2e 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -78,6 +78,7 @@ namespace Tensorflow #if SERIALIZABLE [JsonIgnore] #endif + bool _is_stateful; public NodeDef node_def { get @@ -173,6 +174,8 @@ namespace Tensorflow } } + _id_value = _graph._next_id(); + // Dict mapping op name to file and line information for op colocation // context managers. _control_flow_context = graph._get_control_flow_context(); @@ -184,6 +187,8 @@ namespace Tensorflow var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); + _is_stateful = op_def.IsStateful; + // Initialize self._outputs. output_types = new TF_DataType[NumOutputs]; for (int i = 0; i < NumOutputs; i++) diff --git a/src/TensorFlowNET.Core/Operations/random_ops.py.cs b/src/TensorFlowNET.Core/Operations/random_ops.py.cs index 9251f867..be4aef55 100644 --- a/src/TensorFlowNET.Core/Operations/random_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/random_ops.py.cs @@ -71,7 +71,7 @@ namespace Tensorflow return tf_with(ops.name_scope(name, "random_uniform", new { shape, minval, maxval }), scope => { name = scope; - var tensorShape = _ShapeTensor(shape); + var tensorShape = tensor_util.shape_tensor(shape); var minTensor = ops.convert_to_tensor(minval, dtype: dtype, name: "min"); var maxTensor = ops.convert_to_tensor(maxval, dtype: dtype, name: "max"); var rnd = gen_random_ops.random_uniform(tensorShape, dtype); diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 142afe06..0989db4f 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -335,5 +335,10 @@ namespace Tensorflow return shape; } + + public static Tensor shape_tensor(int[] shape) + { + return ops.convert_to_tensor(shape, dtype: TF_DataType.TF_INT32, name: "shape"); + } } } diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 4b0a35fb..c79c5b7f 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -133,66 +133,69 @@ namespace Tensorflow if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES)) collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES); - ops.init_scope(); - var values = init_from_fn ? new object[0] : new object[] { initial_value }; - tf_with(ops.name_scope(name, "Variable", values), scope => + tf_with(ops.init_scope2(), delegate { - name = scope; - if (init_from_fn) + var values = init_from_fn ? new object[0] : new object[] { initial_value }; + tf_with(ops.name_scope(name, "Variable", values), scope => { - // Use attr_scope and device(None) to simulate the behavior of - // colocate_with when the variable we want to colocate with doesn't - // yet exist. - string true_name = ops.name_from_scope_name(name); - var attr = new AttrValue + name = scope; + + if (init_from_fn) { - List = new AttrValue.Types.ListValue() - }; - attr.List.S.Add(ByteString.CopyFromUtf8($"loc:{true_name}")); - tf_with(ops.name_scope("Initializer"), scope2 => + // Use attr_scope and device(None) to simulate the behavior of + // colocate_with when the variable we want to colocate with doesn't + // yet exist. + string true_name = ops.name_from_scope_name(name); + var attr = new AttrValue + { + List = new AttrValue.Types.ListValue() + }; + attr.List.S.Add(ByteString.CopyFromUtf8($"loc:{true_name}")); + tf_with(ops.name_scope("Initializer"), scope2 => + { + _initial_value = (initial_value as Func)(); + _initial_value = ops.convert_to_tensor(_initial_value, name: "initial_value", dtype: dtype); + }); + _variable = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name); + } + // Or get the initial value from a Tensor or Python object. + else { - _initial_value = (initial_value as Func)(); - _initial_value = ops.convert_to_tensor(_initial_value, name: "initial_value", dtype: dtype); - }); - _variable = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name); - } - // Or get the initial value from a Tensor or Python object. - else - { - _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value", dtype: dtype); + _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value", dtype: dtype); - var shape = _initial_value.shape; - dtype = _initial_value.dtype; - _variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), scope); - } + var shape = _initial_value.shape; + dtype = _initial_value.dtype; + _variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), scope); + } - // Manually overrides the variable's shape with the initial value's. - if (validate_shape) - { - var initial_value_shape = _initial_value.TensorShape; - if (!initial_value_shape.is_fully_defined()) - throw new ValueError($"initial_value must have a shape specified: {_initial_value}"); - } + // Manually overrides the variable's shape with the initial value's. + if (validate_shape) + { + var initial_value_shape = _initial_value.TensorShape; + if (!initial_value_shape.is_fully_defined()) + throw new ValueError($"initial_value must have a shape specified: {_initial_value}"); + } - // If 'initial_value' makes use of other variables, make sure we don't - // have an issue if these other variables aren't initialized first by - // using their initialized_value() method. - var _initial_value2 = _try_guard_against_uninitialized_dependencies(name, _initial_value); + // If 'initial_value' makes use of other variables, make sure we don't + // have an issue if these other variables aren't initialized first by + // using their initialized_value() method. + var _initial_value2 = _try_guard_against_uninitialized_dependencies(name, _initial_value); - _initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op; + _initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op; - if (!String.IsNullOrEmpty(caching_device)) - { + if (!String.IsNullOrEmpty(caching_device)) + { - } - else - { - ops.colocate_with(_initializer_op); + } + else + { + ops.colocate_with(_initializer_op); - _snapshot = gen_array_ops.identity(_variable, name = "read"); - } + _snapshot = gen_array_ops.identity(_variable, name = "read"); + } - ops.add_to_collections(collections, this as VariableV1); + ops.add_to_collections(collections, this as VariableV1); + }); }); } diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 3549b07e..02417594 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -186,12 +186,7 @@ namespace Tensorflow /// operations constructed within the context. /// public static _ControlDependenciesController control_dependencies(object[] control_inputs) - { - return get_default_graph().control_dependencies(control_inputs); - } - - public static _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) - => control_dependencies(control_inputs == null ? null : control_inputs.OfType().ToArray()); + => get_default_graph().control_dependencies(control_inputs); /// /// Creates a TF_Operation. @@ -212,9 +207,9 @@ namespace Tensorflow { var op_desc = graph.NewOperation(node_def.Op, node_def.Name); - //TODO: Implement TF_SetDevice - //if node_def.device: - // c_api.TF_SetDevice(op_desc, compat.as_str(node_def.device)) + if (!string.IsNullOrEmpty(node_def.Device)) + c_api.TF_SetDevice(op_desc, node_def.Device); + // Add inputs foreach (var op_input in inputs) { @@ -310,6 +305,22 @@ namespace Tensorflow }); } + public static IObjectLife init_scope2() + { + // Retrieve the active name scope: entering an `init_scope` preserves + // the name scope of the current context. + var default_graph = get_default_graph(); + var scope = default_graph.get_name_scope(); + if (!String.IsNullOrEmpty(scope) && !scope.EndsWith("/")) + // Names that end with trailing slashes are treated by `name_scope` as + // absolute. + scope += "/"; + // inner_device_stack = default_graph._device_function_stack + // var outer_context = default_graph.as_default; + + return ops.control_dependencies(null); + } + private static int uid_number = 0; ///