| @@ -0,0 +1,23 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Framework | |||||
| { | |||||
| public class smart_module | |||||
| { | |||||
| public static object smart_cond(Tensor pred, Action true_fn = null, Action false_fn = null, string name = null) | |||||
| { | |||||
| return control_flow_ops.cond(pred, | |||||
| true_fn: true_fn, | |||||
| false_fn: false_fn, | |||||
| name: name); | |||||
| } | |||||
| public static bool smart_constant_value(Tensor pred) | |||||
| { | |||||
| var pred_value = tensor_util.constant_value(pred); | |||||
| return pred_value; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -20,7 +20,7 @@ namespace Tensorflow | |||||
| private Dictionary<string, int> _names_in_use; | private Dictionary<string, int> _names_in_use; | ||||
| public int _version; | public int _version; | ||||
| private int _next_id_counter; | private int _next_id_counter; | ||||
| private List<String> _unfetchable_ops = new List<string>(); | |||||
| private List<Operation> _unfetchable_ops = new List<Operation>(); | |||||
| private List<Tensor> _unfeedable_tensors = new List<Tensor>(); | private List<Tensor> _unfeedable_tensors = new List<Tensor>(); | ||||
| public string _name_stack = ""; | public string _name_stack = ""; | ||||
| @@ -228,13 +228,13 @@ namespace Tensorflow | |||||
| public bool is_fetchable<T>(T tensor_or_op) | public bool is_fetchable<T>(T tensor_or_op) | ||||
| { | { | ||||
| if (tensor_or_op is Tensor) | |||||
| if (tensor_or_op is Tensor tensor) | |||||
| { | { | ||||
| return !_unfetchable_ops.Contains((tensor_or_op as Tensor).name); ; | |||||
| return !_unfetchable_ops.Contains(tensor); ; | |||||
| } | } | ||||
| else if (tensor_or_op is Operation) | |||||
| else if (tensor_or_op is Operation op) | |||||
| { | { | ||||
| return !_unfetchable_ops.Contains((tensor_or_op as Operation).name); | |||||
| return !_unfetchable_ops.Contains(op); | |||||
| } | } | ||||
| return false; | return false; | ||||
| @@ -372,6 +372,11 @@ namespace Tensorflow | |||||
| _unfeedable_tensors.Add(tensor); | _unfeedable_tensors.Add(tensor); | ||||
| } | } | ||||
| public void prevent_fetching(Operation op) | |||||
| { | |||||
| _unfetchable_ops.Add(op); | |||||
| } | |||||
| public void Dispose() | public void Dispose() | ||||
| { | { | ||||
| c_api.TF_DeleteGraph(_handle); | c_api.TF_DeleteGraph(_handle); | ||||
| @@ -48,6 +48,7 @@ namespace Tensorflow.Keras.Engine | |||||
| } | } | ||||
| public Tensor __call__(Tensor inputs, | public Tensor __call__(Tensor inputs, | ||||
| Tensor training = null, | |||||
| VariableScope scope = null) | VariableScope scope = null) | ||||
| { | { | ||||
| var input_list = new Tensor[] { inputs }; | var input_list = new Tensor[] { inputs }; | ||||
| @@ -73,7 +74,7 @@ namespace Tensorflow.Keras.Engine | |||||
| // Symbolic execution on symbolic tensors. We will attempt to build | // Symbolic execution on symbolic tensors. We will attempt to build | ||||
| // the corresponding TF subgraph inside `backend.get_graph()` | // the corresponding TF subgraph inside `backend.get_graph()` | ||||
| var graph = backend.get_graph(); | var graph = backend.get_graph(); | ||||
| outputs = call(inputs); | |||||
| outputs = call(inputs, training: training); | |||||
| _handle_activity_regularization(inputs, outputs); | _handle_activity_regularization(inputs, outputs); | ||||
| _set_mask_metadata(inputs, outputs, null); | _set_mask_metadata(inputs, outputs, null); | ||||
| } | } | ||||
| @@ -100,7 +101,7 @@ namespace Tensorflow.Keras.Engine | |||||
| return null; | return null; | ||||
| } | } | ||||
| protected virtual Tensor call(Tensor inputs) | |||||
| protected virtual Tensor call(Tensor inputs, Tensor training = null) | |||||
| { | { | ||||
| throw new NotImplementedException("Layer.call"); | throw new NotImplementedException("Layer.call"); | ||||
| } | } | ||||
| @@ -143,13 +144,15 @@ namespace Tensorflow.Keras.Engine | |||||
| protected virtual void _init_set_name(string name) | protected virtual void _init_set_name(string name) | ||||
| { | { | ||||
| if (string.IsNullOrEmpty(name)) | |||||
| (_name, _base_name) = _make_unique_name(); | |||||
| string base_name = name; | |||||
| if (name == null) | |||||
| (_name, base_name) = _make_unique_name(); | |||||
| _base_name = base_name; | |||||
| } | } | ||||
| protected virtual (string, string) _make_unique_name() | protected virtual (string, string) _make_unique_name() | ||||
| { | { | ||||
| string base_name = "conv2d"; | |||||
| string base_name = generic_utils.to_snake_case(this.GetType().Name); | |||||
| string name = base_layer_utils.unique_layer_name(base_name); | string name = base_layer_utils.unique_layer_name(base_name); | ||||
| return (name, base_name); | return (name, base_name); | ||||
| } | } | ||||
| @@ -2,6 +2,7 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.Utils; | |||||
| using Tensorflow.Layers; | using Tensorflow.Layers; | ||||
| namespace Tensorflow.Keras.Layers | namespace Tensorflow.Keras.Layers | ||||
| @@ -25,6 +26,7 @@ namespace Tensorflow.Keras.Layers | |||||
| private RefVariable gamma; | private RefVariable gamma; | ||||
| private RefVariable beta; | private RefVariable beta; | ||||
| private RefVariable moving_mean; | private RefVariable moving_mean; | ||||
| private RefVariable moving_variance; | |||||
| public BatchNormalization(int axis = -1, | public BatchNormalization(int axis = -1, | ||||
| float momentum = 0.99f, | float momentum = 0.99f, | ||||
| @@ -103,7 +105,56 @@ namespace Tensorflow.Keras.Layers | |||||
| moving_mean = add_weight("moving_mean", | moving_mean = add_weight("moving_mean", | ||||
| param_shape, | param_shape, | ||||
| dtype: param_dtype); | |||||
| dtype: param_dtype, | |||||
| initializer: moving_mean_initializer, | |||||
| synchronization: VariableSynchronization.ON_READ, | |||||
| trainable: false, | |||||
| aggregation: VariableAggregation.MEAN); | |||||
| moving_variance = add_weight("moving_variance", | |||||
| shape: param_shape, | |||||
| dtype: param_dtype, | |||||
| initializer: moving_variance_initializer, | |||||
| synchronization: VariableSynchronization.ON_READ, | |||||
| trainable: false, | |||||
| aggregation: VariableAggregation.MEAN); | |||||
| if (renorm) | |||||
| throw new NotImplementedException("build when renorm is true"); | |||||
| built = true; | |||||
| } | |||||
| protected override Tensor call(Tensor inputs, Tensor training = null) | |||||
| { | |||||
| Tensor outputs = null; | |||||
| if (fused) | |||||
| { | |||||
| outputs = _fused_batch_norm(inputs, training: training); | |||||
| } | |||||
| throw new NotImplementedException("BatchNormalization call"); | |||||
| } | |||||
| private Tensor _fused_batch_norm(Tensor inputs, Tensor training) | |||||
| { | |||||
| var beta = this.beta; | |||||
| var gamma = this.gamma; | |||||
| Action _fused_batch_norm_training = () => | |||||
| { | |||||
| }; | |||||
| Action _fused_batch_norm_inference = () => | |||||
| { | |||||
| }; | |||||
| tf_utils.smart_cond(training, _fused_batch_norm_training, _fused_batch_norm_inference); | |||||
| throw new NotImplementedException("_fused_batch_norm"); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -91,7 +91,7 @@ namespace Tensorflow.Keras.Layers | |||||
| built = true; | built = true; | ||||
| } | } | ||||
| protected override Tensor call(Tensor inputs) | |||||
| protected override Tensor call(Tensor inputs, Tensor training = null) | |||||
| { | { | ||||
| var outputs = _convolution_op.__call__(inputs, kernel); | var outputs = _convolution_op.__call__(inputs, kernel); | ||||
| if (use_bias) | if (use_bias) | ||||
| @@ -0,0 +1,20 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.Utils | |||||
| { | |||||
| public class generic_utils | |||||
| { | |||||
| public static string to_snake_case(string name) | |||||
| { | |||||
| return string.Concat(name.Select((x, i) => | |||||
| { | |||||
| return i > 0 && char.IsUpper(x) && !Char.IsDigit(name[i - 1]) ? | |||||
| "_" + x.ToString() : | |||||
| x.ToString(); | |||||
| })).ToLower(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -2,6 +2,7 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Framework; | |||||
| namespace Tensorflow.Keras.Utils | namespace Tensorflow.Keras.Utils | ||||
| { | { | ||||
| @@ -16,5 +17,13 @@ namespace Tensorflow.Keras.Utils | |||||
| { | { | ||||
| return true; | return true; | ||||
| } | } | ||||
| public static object smart_cond(Tensor pred, Action true_fn = null, Action false_fn = null, string name = null) | |||||
| { | |||||
| return smart_module.smart_cond(pred, | |||||
| true_fn: true_fn, | |||||
| false_fn: false_fn, | |||||
| name: name); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -29,10 +29,11 @@ namespace Tensorflow.Layers | |||||
| public virtual Tensor apply(Tensor inputs, Tensor training = null) | public virtual Tensor apply(Tensor inputs, Tensor training = null) | ||||
| { | { | ||||
| return __call__(inputs); | |||||
| return __call__(inputs, training: training); | |||||
| } | } | ||||
| public Tensor __call__(Tensor inputs, | public Tensor __call__(Tensor inputs, | ||||
| Tensor training = null, | |||||
| VariableScope scope = null) | VariableScope scope = null) | ||||
| { | { | ||||
| _set_scope(scope); | _set_scope(scope); | ||||
| @@ -51,7 +52,7 @@ namespace Tensorflow.Layers | |||||
| Python.with(scope_context_manager, scope2 => _current_scope = scope2); | Python.with(scope_context_manager, scope2 => _current_scope = scope2); | ||||
| // Actually call layer | // Actually call layer | ||||
| var outputs = base.__call__(inputs); | |||||
| var outputs = base.__call__(inputs, training: training); | |||||
| // Update global default collections. | // Update global default collections. | ||||
| //_add_elements_to_collection(updates, ops.GraphKeys.UPDATE_OPS); | //_add_elements_to_collection(updates, ops.GraphKeys.UPDATE_OPS); | ||||
| @@ -63,7 +64,9 @@ namespace Tensorflow.Layers | |||||
| int[] shape, | int[] shape, | ||||
| TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
| IInitializer initializer = null, | IInitializer initializer = null, | ||||
| bool? trainable = null) | |||||
| bool? trainable = null, | |||||
| VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||||
| VariableAggregation aggregation = VariableAggregation.NONE) | |||||
| { | { | ||||
| var default_graph = ops.get_default_graph(); | var default_graph = ops.get_default_graph(); | ||||
| Graph init_graph = null; | Graph init_graph = null; | ||||
| @@ -135,5 +135,53 @@ namespace Tensorflow | |||||
| else | else | ||||
| return gen_array_ops.identity(data, name: name); | return gen_array_ops.identity(data, name: name); | ||||
| } | } | ||||
| public static (Tensor, Tensor) cond(Tensor pred, | |||||
| Action true_fn = null, | |||||
| Action false_fn = null, | |||||
| bool strict = false, | |||||
| string name = null) | |||||
| { | |||||
| return with(ops.name_scope(name, "cond", new { pred }), delegate | |||||
| { | |||||
| // Add the Switch to the graph. | |||||
| var (p_2, p_1) = @switch(pred, pred); | |||||
| var pivot_1 = array_ops.identity(p_1, name: "switch_t"); | |||||
| var pivot_2 = array_ops.identity(p_2, name: "switch_f"); | |||||
| pred = array_ops.identity(pred, name: "pred_id"); | |||||
| // Disable the fetching of tensors that are only on one branch of cond. | |||||
| foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred }) | |||||
| tensor.op.graph.prevent_fetching(tensor.op); | |||||
| return (p_2, p_1); | |||||
| }); | |||||
| } | |||||
| /// <summary> | |||||
| /// Forwards `data` to an output determined by `pred`. | |||||
| /// </summary> | |||||
| /// <param name="data"></param> | |||||
| /// <param name="pred"></param> | |||||
| /// <param name="dtype"></param> | |||||
| /// <param name="name"></param> | |||||
| public static (Tensor, Tensor) @switch(Tensor data, | |||||
| Tensor pred, | |||||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||||
| string name = null) | |||||
| { | |||||
| return with(ops.name_scope(name, "Switch", new { data, pred }), scope => | |||||
| { | |||||
| name = scope; | |||||
| data = ops.internal_convert_to_tensor_or_indexed_slices(data, | |||||
| dtype: dtype, | |||||
| name: "data", | |||||
| as_ref: true); | |||||
| pred = ops.convert_to_tensor(pred, name: "pred"); | |||||
| return gen_control_flow_ops.@switch(data, pred, name: name); | |||||
| }); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -42,7 +42,7 @@ namespace Tensorflow | |||||
| public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null) | public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null) | ||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("Placeholder", args: new { dtype, shape }); | |||||
| var _op = _op_def_lib._apply_op_helper("Placeholder", name: name, args: new { dtype, shape }); | |||||
| var _result = _op.outputs; | var _result = _op.outputs; | ||||
| var _inputs_flat = _op.inputs; | var _inputs_flat = _op.inputs; | ||||
| @@ -14,5 +14,12 @@ namespace Tensorflow | |||||
| return _op; | return _op; | ||||
| } | } | ||||
| public static (Tensor, Tensor) @switch(Tensor data, Tensor pred, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("Switch", name, new { data, pred }); | |||||
| return (_op.outputs[0], _op.outputs[1]); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,6 +4,9 @@ using System.Text; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| /// <summary> | |||||
| /// Indicates when a distributed variable will be synced. | |||||
| /// </summary> | |||||
| public enum VariableSynchronization | public enum VariableSynchronization | ||||
| { | { | ||||
| AUTO = 0, | AUTO = 0, | ||||