diff --git a/src/TensorFlowNET.Core/Device/c_api.device.cs b/src/TensorFlowNET.Core/Device/c_api.device.cs
new file mode 100644
index 00000000..2ce79a3e
--- /dev/null
+++ b/src/TensorFlowNET.Core/Device/c_api.device.cs
@@ -0,0 +1,32 @@
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using System;
+using System.Runtime.InteropServices;
+
+namespace Tensorflow
+{
+ public partial class c_api
+ {
+ /// <summary>
+ /// Specify the device for `desc`. Defaults to empty, meaning unconstrained.
+ /// </summary>
+ /// <param name="desc">Native handle whose device is being set; presumably a TF_OperationDescription* from the TensorFlow C API (confirm against c_api.h).</param>
+ /// <param name="device">Device string to assign; an empty string leaves placement unconstrained.</param>
+ [DllImport(TensorFlowLibName)]
+ public static extern void TF_SetDevice(IntPtr desc, string device);
+ }
+}
diff --git a/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs
index 55e321df..63285bae 100644
--- a/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs
+++ b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs
@@ -69,7 +69,9 @@ namespace Tensorflow
_new_stack = false;
}
- _seen_nodes = new List();
+ _seen_nodes = new List();
+ _old_stack = null;
+ _old_control_flow_context = null;
}
public void add_op(ITensorOrOperation op)
diff --git a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
index 57311e8b..9b42eaaa 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
@@ -139,7 +139,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}
- protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null)
+ protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
{
Tensor outputs = null;
diff --git a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs
index 6a7c58cc..ad233d6b 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs
@@ -108,7 +108,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}
- protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null)
+ protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
{
var outputs = _convolution_op.__call__(inputs, kernel);
if (use_bias)
diff --git a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs
index 212035cb..74778873 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs
@@ -72,7 +72,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}
- protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null)
+ protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
{
Tensor outputs = null;
var rank = inputs.rank;
diff --git a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs
index f15c01b8..95544d36 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs
@@ -50,7 +50,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}
- protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null)
+ protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
{
var dtype = inputs.dtype;
if (dtype != tf.int32 && dtype != tf.int64)
diff --git a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs
index 46d45862..d7d7e31a 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs
@@ -52,6 +52,7 @@ namespace Tensorflow.Keras.Layers
protected InputSpec input_spec;
protected bool supports_masking;
protected List _trainable_weights;
+ protected List _non_trainable_weights;
private string _name;
public string name => _name;
protected string _base_name;
@@ -84,6 +85,7 @@ namespace Tensorflow.Keras.Layers
_init_set_name(name);
_trainable_weights = new List();
+ _non_trainable_weights = new List();
_compute_previous_mask = false;
_updates = new List();
@@ -103,6 +105,7 @@ namespace Tensorflow.Keras.Layers
public (Tensor, Tensor) __call__(Tensor[] inputs,
Tensor training = null,
+ Tensor state = null,
VariableScope scope = null)
{
var input_list = inputs;
@@ -139,7 +142,9 @@ namespace Tensorflow.Keras.Layers
// overridden).
_maybe_build(inputs[0]);
- (input, outputs) = call(inputs[0], training: training);
+ (input, outputs) = call(inputs[0],
+ training: training,
+ state: state);
(input, outputs) = _set_connectivity_metadata_(input, outputs);
_handle_activity_regularization(inputs[0], outputs);
_set_mask_metadata(inputs[0], outputs, null);
@@ -173,7 +178,7 @@ namespace Tensorflow.Keras.Layers
return null;
}
- protected virtual (Tensor, Tensor) call(Tensor inputs, Tensor training = null)
+ protected virtual (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
{
return (inputs, inputs);
}
@@ -233,7 +238,10 @@ namespace Tensorflow.Keras.Layers
initializer: initializer,
trainable: trainable.Value);
//backend.track_variable(variable);
- _trainable_weights.Add(variable);
+ if (trainable == true)
+ _trainable_weights.Add(variable);
+ else
+ _non_trainable_weights.Add(variable);
return variable;
}
diff --git a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs
index e9008543..81d57abe 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs
@@ -43,7 +43,7 @@ namespace Tensorflow.Keras.Layers
this.input_spec = new InputSpec(ndim: 4);
}
- protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null)
+ protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
{
int[] pool_shape;
if (data_format == "channels_last")
diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs
index 2ea427c3..d7cda786 100644
--- a/src/TensorFlowNET.Core/Layers/Layer.cs
+++ b/src/TensorFlowNET.Core/Layers/Layer.cs
@@ -43,6 +43,7 @@ namespace Tensorflow.Layers
// Avoid an incorrect lint error
_trainable_weights = new List();
+ _non_trainable_weights = new List();
this.built = false;
_keras_style = false;
}
@@ -54,6 +55,7 @@ namespace Tensorflow.Layers
public (Tensor, Tensor) __call__(Tensor inputs,
Tensor training = null,
+ Tensor state = null,
VariableScope scope = null)
{
_set_scope(scope);
@@ -76,7 +78,9 @@ namespace Tensorflow.Layers
{
_current_scope = scope2;
// Actually call layer
- outputs = base.__call__(new Tensor[] { inputs }, training: training);
+ outputs = base.__call__(new Tensor[] { inputs },
+ training: training,
+ state: state);
});
@@ -121,6 +125,11 @@ namespace Tensorflow.Layers
Graph init_graph = null;
VariableV1[] existing_variables = null;
+ if (synchronization == VariableSynchronization.OnRead)
+ trainable = false;
+ else if (!trainable.HasValue)
+ trainable = true;
+
if (default_graph.building_function)
{
throw new NotImplementedException("add_weight");
diff --git a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs b/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs
index 9911212b..fdcc03ea 100644
--- a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs
+++ b/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs
@@ -66,7 +66,7 @@ namespace Tensorflow
built = true;
}
- protected override (Tensor, Tensor) call(Tensor inputs, Tensor state = null)
+ protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
{
// Most basic RNN: output = new_state = act(W * input + U * state + B).
var concat = array_ops.concat(new[] { inputs, state }, 1);
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
index 1d071856..715c68c6 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
@@ -307,12 +307,6 @@ namespace Tensorflow.Operations
protected override void _AddOpInternal(Operation op)
{
- if(op.name == "rnn/while/basic_rnn_cell/MatMul" ||
- op.name == "rnn/while/TensorArrayReadV3")
- {
-
- }
-
Operation[] external_inputs = new Operation[0];
if (op.inputs.Length == 0)
{
@@ -412,10 +406,8 @@ namespace Tensorflow.Operations
 }
 if (_outer_context != null)
- {
 result = _outer_context.AddValue(val);
- }
 // Create an Enter to make `result` known to this loop context.
 Tensor enter = null;
 tf_with(ops.control_dependencies(null), delegate
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs
index e1ac0204..636b1451 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs
@@ -16,6 +16,7 @@
using System;
using System.Linq;
+using static Tensorflow.Binding;
namespace Tensorflow.Operations.Initializers
{
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
index 41a4622a..a8a0e0b9 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
@@ -214,7 +214,7 @@ namespace Tensorflow.Operations
if (sequence_length != null)
throw new NotImplementedException("sequence_length != null");
else
- a = cell.__call__(input_t_t, state1);
+ a = cell.__call__(input_t_t, state: state1);
return item;
};
diff --git a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs
index 8e660797..9f0cb9a5 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.Control.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs
@@ -32,9 +32,7 @@ namespace Tensorflow
public void _control_flow_post_processing()
{
foreach(Tensor input_tensor in inputs)
- {
control_flow_util.CheckInputFromValidContext(this, input_tensor.op);
- }
if (_control_flow_context != null)
_control_flow_context.AddOp(this);
diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs
index 8bdaaa7b..d5068f2e 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.cs
@@ -78,6 +78,7 @@ namespace Tensorflow
+ private bool _is_stateful; // Cached op_def.IsStateful; declared above the guard so [JsonIgnore] still applies to node_def.
 #if SERIALIZABLE
 [JsonIgnore]
 #endif
 public NodeDef node_def
 {
 get
@@ -173,6 +174,8 @@ namespace Tensorflow
}
}
+ _id_value = _graph._next_id();
+
// Dict mapping op name to file and line information for op colocation
// context managers.
_control_flow_context = graph._get_control_flow_context();
@@ -184,6 +187,8 @@ namespace Tensorflow
var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
_handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());
+ _is_stateful = op_def.IsStateful;
+
// Initialize self._outputs.
output_types = new TF_DataType[NumOutputs];
for (int i = 0; i < NumOutputs; i++)
diff --git a/src/TensorFlowNET.Core/Operations/random_ops.py.cs b/src/TensorFlowNET.Core/Operations/random_ops.py.cs
index 9251f867..be4aef55 100644
--- a/src/TensorFlowNET.Core/Operations/random_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/random_ops.py.cs
@@ -71,7 +71,7 @@ namespace Tensorflow
return tf_with(ops.name_scope(name, "random_uniform", new { shape, minval, maxval }), scope =>
{
name = scope;
- var tensorShape = _ShapeTensor(shape);
+ var tensorShape = tensor_util.shape_tensor(shape);
var minTensor = ops.convert_to_tensor(minval, dtype: dtype, name: "min");
var maxTensor = ops.convert_to_tensor(maxval, dtype: dtype, name: "max");
var rnd = gen_random_ops.random_uniform(tensorShape, dtype);
diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
index 142afe06..0989db4f 100644
--- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs
+++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
@@ -335,5 +335,10 @@ namespace Tensorflow
return shape;
}
+
+ /// <summary>Converts an integer shape array into a TF_INT32 tensor named "shape" via ops.convert_to_tensor.</summary>
+ /// <param name="shape">Shape values to convert.</param>
+ public static Tensor shape_tensor(int[] shape)
+ => ops.convert_to_tensor(shape, dtype: TF_DataType.TF_INT32, name: "shape");
}
}
diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs
index 4b0a35fb..c79c5b7f 100644
--- a/src/TensorFlowNET.Core/Variables/RefVariable.cs
+++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs
@@ -133,66 +133,69 @@ namespace Tensorflow
if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES))
collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES);
- ops.init_scope();
- var values = init_from_fn ? new object[0] : new object[] { initial_value };
- tf_with(ops.name_scope(name, "Variable", values), scope =>
+ tf_with(ops.init_scope2(), delegate
{
- name = scope;
- if (init_from_fn)
+ var values = init_from_fn ? new object[0] : new object[] { initial_value };
+ tf_with(ops.name_scope(name, "Variable", values), scope =>
{
- // Use attr_scope and device(None) to simulate the behavior of
- // colocate_with when the variable we want to colocate with doesn't
- // yet exist.
- string true_name = ops.name_from_scope_name(name);
- var attr = new AttrValue
+ name = scope;
+
+ if (init_from_fn)
{
- List = new AttrValue.Types.ListValue()
- };
- attr.List.S.Add(ByteString.CopyFromUtf8($"loc:{true_name}"));
- tf_with(ops.name_scope("Initializer"), scope2 =>
+ // Use attr_scope and device(None) to simulate the behavior of
+ // colocate_with when the variable we want to colocate with doesn't
+ // yet exist.
+ string true_name = ops.name_from_scope_name(name);
+ var attr = new AttrValue
+ {
+ List = new AttrValue.Types.ListValue()
+ };
+ attr.List.S.Add(ByteString.CopyFromUtf8($"loc:{true_name}"));
+ tf_with(ops.name_scope("Initializer"), scope2 =>
+ {
+ _initial_value = (initial_value as Func)();
+ _initial_value = ops.convert_to_tensor(_initial_value, name: "initial_value", dtype: dtype);
+ });
+ _variable = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name);
+ }
+ // Or get the initial value from a Tensor or Python object.
+ else
{
- _initial_value = (initial_value as Func)();
- _initial_value = ops.convert_to_tensor(_initial_value, name: "initial_value", dtype: dtype);
- });
- _variable = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name);
- }
- // Or get the initial value from a Tensor or Python object.
- else
- {
- _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value", dtype: dtype);
+ _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value", dtype: dtype);
- var shape = _initial_value.shape;
- dtype = _initial_value.dtype;
- _variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), scope);
- }
+ var shape = _initial_value.shape;
+ dtype = _initial_value.dtype;
+ _variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), scope);
+ }
- // Manually overrides the variable's shape with the initial value's.
- if (validate_shape)
- {
- var initial_value_shape = _initial_value.TensorShape;
- if (!initial_value_shape.is_fully_defined())
- throw new ValueError($"initial_value must have a shape specified: {_initial_value}");
- }
+ // Manually overrides the variable's shape with the initial value's.
+ if (validate_shape)
+ {
+ var initial_value_shape = _initial_value.TensorShape;
+ if (!initial_value_shape.is_fully_defined())
+ throw new ValueError($"initial_value must have a shape specified: {_initial_value}");
+ }
- // If 'initial_value' makes use of other variables, make sure we don't
- // have an issue if these other variables aren't initialized first by
- // using their initialized_value() method.
- var _initial_value2 = _try_guard_against_uninitialized_dependencies(name, _initial_value);
+ // If 'initial_value' makes use of other variables, make sure we don't
+ // have an issue if these other variables aren't initialized first by
+ // using their initialized_value() method.
+ var _initial_value2 = _try_guard_against_uninitialized_dependencies(name, _initial_value);
- _initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op;
+ _initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op;
- if (!String.IsNullOrEmpty(caching_device))
- {
+ if (!String.IsNullOrEmpty(caching_device))
+ {
- }
- else
- {
- ops.colocate_with(_initializer_op);
+ }
+ else
+ {
+ ops.colocate_with(_initializer_op);
- _snapshot = gen_array_ops.identity(_variable, name = "read");
- }
+ _snapshot = gen_array_ops.identity(_variable, name = "read");
+ }
- ops.add_to_collections(collections, this as VariableV1);
+ ops.add_to_collections(collections, this as VariableV1);
+ });
});
}
diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs
index 3549b07e..02417594 100644
--- a/src/TensorFlowNET.Core/ops.cs
+++ b/src/TensorFlowNET.Core/ops.cs
@@ -186,12 +186,7 @@ namespace Tensorflow
/// operations constructed within the context.
///
public static _ControlDependenciesController control_dependencies(object[] control_inputs)
- {
- return get_default_graph().control_dependencies(control_inputs);
- }
-
- public static _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs)
- => control_dependencies(control_inputs == null ? null : control_inputs.OfType