diff --git a/src/TensorFlowNET.Core/Framework/smart_module.cs b/src/TensorFlowNET.Core/Framework/smart_module.cs new file mode 100644 index 00000000..2ba80cbc --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/smart_module.cs @@ -0,0 +1,23 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Framework +{ + public class smart_module + { + public static object smart_cond(Tensor pred, Action true_fn = null, Action false_fn = null, string name = null) + { + return control_flow_ops.cond(pred, + true_fn: true_fn, + false_fn: false_fn, + name: name); + } + + public static bool smart_constant_value(Tensor pred) + { + var pred_value = tensor_util.constant_value(pred); + return pred_value; + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 8a5d929f..71faa045 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -20,7 +20,7 @@ namespace Tensorflow private Dictionary _names_in_use; public int _version; private int _next_id_counter; - private List _unfetchable_ops = new List(); + private List _unfetchable_ops = new List(); private List _unfeedable_tensors = new List(); public string _name_stack = ""; @@ -228,13 +228,13 @@ namespace Tensorflow public bool is_fetchable(T tensor_or_op) { - if (tensor_or_op is Tensor) + if (tensor_or_op is Tensor tensor) { - return !_unfetchable_ops.Contains((tensor_or_op as Tensor).name); ; + return !_unfetchable_ops.Contains(tensor); ; } - else if (tensor_or_op is Operation) + else if (tensor_or_op is Operation op) { - return !_unfetchable_ops.Contains((tensor_or_op as Operation).name); + return !_unfetchable_ops.Contains(op); } return false; @@ -372,6 +372,11 @@ namespace Tensorflow _unfeedable_tensors.Add(tensor); } + public void prevent_fetching(Operation op) + { + _unfetchable_ops.Add(op); + } + public void Dispose() { c_api.TF_DeleteGraph(_handle); diff --git 
a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs index 48955d11..2e442e65 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs @@ -48,6 +48,7 @@ namespace Tensorflow.Keras.Engine } public Tensor __call__(Tensor inputs, + Tensor training = null, VariableScope scope = null) { var input_list = new Tensor[] { inputs }; @@ -73,7 +74,7 @@ namespace Tensorflow.Keras.Engine // Symbolic execution on symbolic tensors. We will attempt to build // the corresponding TF subgraph inside `backend.get_graph()` var graph = backend.get_graph(); - outputs = call(inputs); + outputs = call(inputs, training: training); _handle_activity_regularization(inputs, outputs); _set_mask_metadata(inputs, outputs, null); } @@ -100,7 +101,7 @@ namespace Tensorflow.Keras.Engine return null; } - protected virtual Tensor call(Tensor inputs) + protected virtual Tensor call(Tensor inputs, Tensor training = null) { throw new NotImplementedException("Layer.call"); } @@ -143,13 +144,15 @@ namespace Tensorflow.Keras.Engine protected virtual void _init_set_name(string name) { - if (string.IsNullOrEmpty(name)) - (_name, _base_name) = _make_unique_name(); + string base_name = name; + if (name == null) + (_name, base_name) = _make_unique_name(); + _base_name = base_name; } protected virtual (string, string) _make_unique_name() { - string base_name = "conv2d"; + string base_name = generic_utils.to_snake_case(this.GetType().Name); string name = base_layer_utils.unique_layer_name(base_name); return (name, base_name); } diff --git a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs index 80d3d655..8f82983e 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Linq; using System.Text; +using 
Tensorflow.Keras.Utils; using Tensorflow.Layers; namespace Tensorflow.Keras.Layers @@ -25,6 +26,7 @@ namespace Tensorflow.Keras.Layers private RefVariable gamma; private RefVariable beta; private RefVariable moving_mean; + private RefVariable moving_variance; public BatchNormalization(int axis = -1, float momentum = 0.99f, @@ -103,7 +105,56 @@ namespace Tensorflow.Keras.Layers moving_mean = add_weight("moving_mean", param_shape, - dtype: param_dtype); + dtype: param_dtype, + initializer: moving_mean_initializer, + synchronization: VariableSynchronization.ON_READ, + trainable: false, + aggregation: VariableAggregation.MEAN); + + moving_variance = add_weight("moving_variance", + shape: param_shape, + dtype: param_dtype, + initializer: moving_variance_initializer, + synchronization: VariableSynchronization.ON_READ, + trainable: false, + aggregation: VariableAggregation.MEAN); + + if (renorm) + throw new NotImplementedException("build when renorm is true"); + + built = true; + } + + protected override Tensor call(Tensor inputs, Tensor training = null) + { + Tensor outputs = null; + + if (fused) + { + outputs = _fused_batch_norm(inputs, training: training); + } + + throw new NotImplementedException("BatchNormalization call"); + } + + private Tensor _fused_batch_norm(Tensor inputs, Tensor training) + { + var beta = this.beta; + var gamma = this.gamma; + + Action _fused_batch_norm_training = () => + { + + }; + + Action _fused_batch_norm_inference = () => + { + + }; + + tf_utils.smart_cond(training, _fused_batch_norm_training, _fused_batch_norm_inference); + + throw new NotImplementedException("_fused_batch_norm"); } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs index 6661e43f..bc80ff72 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs @@ -91,7 +91,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensor call(Tensor inputs) + 
protected override Tensor call(Tensor inputs, Tensor training = null) { var outputs = _convolution_op.__call__(inputs, kernel); if (use_bias) diff --git a/src/TensorFlowNET.Core/Keras/Utils/generic_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/generic_utils.cs new file mode 100644 index 00000000..e6166e24 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Utils/generic_utils.cs @@ -0,0 +1,20 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Tensorflow.Keras.Utils +{ + public class generic_utils + { + public static string to_snake_case(string name) + { + return string.Concat(name.Select((x, i) => + { + return i > 0 && char.IsUpper(x) && !Char.IsDigit(name[i - 1]) ? + "_" + x.ToString() : + x.ToString(); + })).ToLower(); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs index 895f08c2..9a7d5ea1 100644 --- a/src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs +++ b/src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Linq; using System.Text; +using Tensorflow.Framework; namespace Tensorflow.Keras.Utils { @@ -16,5 +17,13 @@ namespace Tensorflow.Keras.Utils { return true; } + + public static object smart_cond(Tensor pred, Action true_fn = null, Action false_fn = null, string name = null) + { + return smart_module.smart_cond(pred, + true_fn: true_fn, + false_fn: false_fn, + name: name); + } } } diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs index 997153be..1ca856f0 100644 --- a/src/TensorFlowNET.Core/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Layers/Layer.cs @@ -29,10 +29,11 @@ namespace Tensorflow.Layers public virtual Tensor apply(Tensor inputs, Tensor training = null) { - return __call__(inputs); + return __call__(inputs, training: training); } public Tensor __call__(Tensor inputs, + Tensor training = null, VariableScope scope = null) { 
_set_scope(scope); @@ -51,7 +52,7 @@ namespace Tensorflow.Layers Python.with(scope_context_manager, scope2 => _current_scope = scope2); // Actually call layer - var outputs = base.__call__(inputs); + var outputs = base.__call__(inputs, training: training); // Update global default collections. //_add_elements_to_collection(updates, ops.GraphKeys.UPDATE_OPS); @@ -63,7 +64,9 @@ namespace Tensorflow.Layers int[] shape, TF_DataType dtype = TF_DataType.DtInvalid, IInitializer initializer = null, - bool? trainable = null) + bool? trainable = null, + VariableSynchronization synchronization = VariableSynchronization.AUTO, + VariableAggregation aggregation = VariableAggregation.NONE) { var default_graph = ops.get_default_graph(); Graph init_graph = null; diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs index 129e3256..e9b75ab8 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs @@ -135,5 +135,53 @@ namespace Tensorflow else return gen_array_ops.identity(data, name: name); } + + public static (Tensor, Tensor) cond(Tensor pred, + Action true_fn = null, + Action false_fn = null, + bool strict = false, + string name = null) + { + return with(ops.name_scope(name, "cond", new { pred }), delegate + { + // Add the Switch to the graph. + var (p_2, p_1) = @switch(pred, pred); + var pivot_1 = array_ops.identity(p_1, name: "switch_t"); + var pivot_2 = array_ops.identity(p_2, name: "switch_f"); + pred = array_ops.identity(pred, name: "pred_id"); + + // Disable the fetching of tensors that are only on one branch of cond. + foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred }) + tensor.op.graph.prevent_fetching(tensor.op); + + return (p_2, p_1); + }); + } + + /// + /// Forwards `data` to an output determined by `pred`. 
+ /// + /// + /// + /// + /// + public static (Tensor, Tensor) @switch(Tensor data, + Tensor pred, + TF_DataType dtype = TF_DataType.DtInvalid, + string name = null) + { + return with(ops.name_scope(name, "Switch", new { data, pred }), scope => + { + name = scope; + data = ops.internal_convert_to_tensor_or_indexed_slices(data, + dtype: dtype, + name: "data", + as_ref: true); + + pred = ops.convert_to_tensor(pred, name: "pred"); + + return gen_control_flow_ops.@switch(data, pred, name: name); + }); + } } } diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 59b77a1b..de8cc9d5 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -42,7 +42,7 @@ namespace Tensorflow public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null) { - var _op = _op_def_lib._apply_op_helper("Placeholder", args: new { dtype, shape }); + var _op = _op_def_lib._apply_op_helper("Placeholder", name: name, args: new { dtype, shape }); var _result = _op.outputs; var _inputs_flat = _op.inputs; diff --git a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs index 6b3e7107..faedfae4 100644 --- a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs @@ -14,5 +14,12 @@ namespace Tensorflow return _op; } + + public static (Tensor, Tensor) @switch(Tensor data, Tensor pred, string name = null) + { + var _op = _op_def_lib._apply_op_helper("Switch", name, new { data, pred }); + + return (_op.outputs[0], _op.outputs[1]); + } } } diff --git a/src/TensorFlowNET.Core/Variables/VariableSynchronization.cs b/src/TensorFlowNET.Core/Variables/VariableSynchronization.cs index 9d184cff..8a16f285 100644 --- a/src/TensorFlowNET.Core/Variables/VariableSynchronization.cs +++ 
b/src/TensorFlowNET.Core/Variables/VariableSynchronization.cs @@ -4,6 +4,9 @@ using System.Text; namespace Tensorflow { + /// <summary> + /// Indicates when a distributed variable will be synced. + /// </summary> public enum VariableSynchronization { AUTO = 0,