| @@ -0,0 +1,23 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Framework | |||
| { | |||
| public class smart_module | |||
| { | |||
| public static object smart_cond(Tensor pred, Action true_fn = null, Action false_fn = null, string name = null) | |||
| { | |||
| return control_flow_ops.cond(pred, | |||
| true_fn: true_fn, | |||
| false_fn: false_fn, | |||
| name: name); | |||
| } | |||
| public static bool smart_constant_value(Tensor pred) | |||
| { | |||
| var pred_value = tensor_util.constant_value(pred); | |||
| return pred_value; | |||
| } | |||
| } | |||
| } | |||
| @@ -20,7 +20,7 @@ namespace Tensorflow | |||
| private Dictionary<string, int> _names_in_use; | |||
| public int _version; | |||
| private int _next_id_counter; | |||
| private List<String> _unfetchable_ops = new List<string>(); | |||
| private List<Operation> _unfetchable_ops = new List<Operation>(); | |||
| private List<Tensor> _unfeedable_tensors = new List<Tensor>(); | |||
| public string _name_stack = ""; | |||
| @@ -228,13 +228,13 @@ namespace Tensorflow | |||
| public bool is_fetchable<T>(T tensor_or_op) | |||
| { | |||
| if (tensor_or_op is Tensor) | |||
| if (tensor_or_op is Tensor tensor) | |||
| { | |||
| return !_unfetchable_ops.Contains((tensor_or_op as Tensor).name); ; | |||
| return !_unfetchable_ops.Contains(tensor); ; | |||
| } | |||
| else if (tensor_or_op is Operation) | |||
| else if (tensor_or_op is Operation op) | |||
| { | |||
| return !_unfetchable_ops.Contains((tensor_or_op as Operation).name); | |||
| return !_unfetchable_ops.Contains(op); | |||
| } | |||
| return false; | |||
| @@ -372,6 +372,11 @@ namespace Tensorflow | |||
| _unfeedable_tensors.Add(tensor); | |||
| } | |||
| public void prevent_fetching(Operation op) | |||
| { | |||
| _unfetchable_ops.Add(op); | |||
| } | |||
| public void Dispose() | |||
| { | |||
| c_api.TF_DeleteGraph(_handle); | |||
| @@ -48,6 +48,7 @@ namespace Tensorflow.Keras.Engine | |||
| } | |||
| public Tensor __call__(Tensor inputs, | |||
| Tensor training = null, | |||
| VariableScope scope = null) | |||
| { | |||
| var input_list = new Tensor[] { inputs }; | |||
| @@ -73,7 +74,7 @@ namespace Tensorflow.Keras.Engine | |||
| // Symbolic execution on symbolic tensors. We will attempt to build | |||
| // the corresponding TF subgraph inside `backend.get_graph()` | |||
| var graph = backend.get_graph(); | |||
| outputs = call(inputs); | |||
| outputs = call(inputs, training: training); | |||
| _handle_activity_regularization(inputs, outputs); | |||
| _set_mask_metadata(inputs, outputs, null); | |||
| } | |||
| @@ -100,7 +101,7 @@ namespace Tensorflow.Keras.Engine | |||
| return null; | |||
| } | |||
| protected virtual Tensor call(Tensor inputs) | |||
| protected virtual Tensor call(Tensor inputs, Tensor training = null) | |||
| { | |||
| throw new NotImplementedException("Layer.call"); | |||
| } | |||
| @@ -143,13 +144,15 @@ namespace Tensorflow.Keras.Engine | |||
| protected virtual void _init_set_name(string name) | |||
| { | |||
| if (string.IsNullOrEmpty(name)) | |||
| (_name, _base_name) = _make_unique_name(); | |||
| string base_name = name; | |||
| if (name == null) | |||
| (_name, base_name) = _make_unique_name(); | |||
| _base_name = base_name; | |||
| } | |||
| protected virtual (string, string) _make_unique_name() | |||
| { | |||
| string base_name = "conv2d"; | |||
| string base_name = generic_utils.to_snake_case(this.GetType().Name); | |||
| string name = base_layer_utils.unique_layer_name(base_name); | |||
| return (name, base_name); | |||
| } | |||
| @@ -2,6 +2,7 @@ | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Keras.Utils; | |||
| using Tensorflow.Layers; | |||
| namespace Tensorflow.Keras.Layers | |||
| @@ -25,6 +26,7 @@ namespace Tensorflow.Keras.Layers | |||
| private RefVariable gamma; | |||
| private RefVariable beta; | |||
| private RefVariable moving_mean; | |||
| private RefVariable moving_variance; | |||
| public BatchNormalization(int axis = -1, | |||
| float momentum = 0.99f, | |||
| @@ -103,7 +105,56 @@ namespace Tensorflow.Keras.Layers | |||
| moving_mean = add_weight("moving_mean", | |||
| param_shape, | |||
| dtype: param_dtype); | |||
| dtype: param_dtype, | |||
| initializer: moving_mean_initializer, | |||
| synchronization: VariableSynchronization.ON_READ, | |||
| trainable: false, | |||
| aggregation: VariableAggregation.MEAN); | |||
| moving_variance = add_weight("moving_variance", | |||
| shape: param_shape, | |||
| dtype: param_dtype, | |||
| initializer: moving_variance_initializer, | |||
| synchronization: VariableSynchronization.ON_READ, | |||
| trainable: false, | |||
| aggregation: VariableAggregation.MEAN); | |||
| if (renorm) | |||
| throw new NotImplementedException("build when renorm is true"); | |||
| built = true; | |||
| } | |||
| protected override Tensor call(Tensor inputs, Tensor training = null) | |||
| { | |||
| Tensor outputs = null; | |||
| if (fused) | |||
| { | |||
| outputs = _fused_batch_norm(inputs, training: training); | |||
| } | |||
| throw new NotImplementedException("BatchNormalization call"); | |||
| } | |||
| private Tensor _fused_batch_norm(Tensor inputs, Tensor training) | |||
| { | |||
| var beta = this.beta; | |||
| var gamma = this.gamma; | |||
| Action _fused_batch_norm_training = () => | |||
| { | |||
| }; | |||
| Action _fused_batch_norm_inference = () => | |||
| { | |||
| }; | |||
| tf_utils.smart_cond(training, _fused_batch_norm_training, _fused_batch_norm_inference); | |||
| throw new NotImplementedException("_fused_batch_norm"); | |||
| } | |||
| } | |||
| } | |||
| @@ -91,7 +91,7 @@ namespace Tensorflow.Keras.Layers | |||
| built = true; | |||
| } | |||
| protected override Tensor call(Tensor inputs) | |||
| protected override Tensor call(Tensor inputs, Tensor training = null) | |||
| { | |||
| var outputs = _convolution_op.__call__(inputs, kernel); | |||
| if (use_bias) | |||
| @@ -0,0 +1,20 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.Utils | |||
| { | |||
| public class generic_utils | |||
| { | |||
| public static string to_snake_case(string name) | |||
| { | |||
| return string.Concat(name.Select((x, i) => | |||
| { | |||
| return i > 0 && char.IsUpper(x) && !Char.IsDigit(name[i - 1]) ? | |||
| "_" + x.ToString() : | |||
| x.ToString(); | |||
| })).ToLower(); | |||
| } | |||
| } | |||
| } | |||
| @@ -2,6 +2,7 @@ | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Framework; | |||
| namespace Tensorflow.Keras.Utils | |||
| { | |||
| @@ -16,5 +17,13 @@ namespace Tensorflow.Keras.Utils | |||
| { | |||
| return true; | |||
| } | |||
| public static object smart_cond(Tensor pred, Action true_fn = null, Action false_fn = null, string name = null) | |||
| { | |||
| return smart_module.smart_cond(pred, | |||
| true_fn: true_fn, | |||
| false_fn: false_fn, | |||
| name: name); | |||
| } | |||
| } | |||
| } | |||
| @@ -29,10 +29,11 @@ namespace Tensorflow.Layers | |||
| public virtual Tensor apply(Tensor inputs, Tensor training = null) | |||
| { | |||
| return __call__(inputs); | |||
| return __call__(inputs, training: training); | |||
| } | |||
| public Tensor __call__(Tensor inputs, | |||
| Tensor training = null, | |||
| VariableScope scope = null) | |||
| { | |||
| _set_scope(scope); | |||
| @@ -51,7 +52,7 @@ namespace Tensorflow.Layers | |||
| Python.with(scope_context_manager, scope2 => _current_scope = scope2); | |||
| // Actually call layer | |||
| var outputs = base.__call__(inputs); | |||
| var outputs = base.__call__(inputs, training: training); | |||
| // Update global default collections. | |||
| //_add_elements_to_collection(updates, ops.GraphKeys.UPDATE_OPS); | |||
| @@ -63,7 +64,9 @@ namespace Tensorflow.Layers | |||
| int[] shape, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| IInitializer initializer = null, | |||
| bool? trainable = null) | |||
| bool? trainable = null, | |||
| VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||
| VariableAggregation aggregation = VariableAggregation.NONE) | |||
| { | |||
| var default_graph = ops.get_default_graph(); | |||
| Graph init_graph = null; | |||
| @@ -135,5 +135,53 @@ namespace Tensorflow | |||
| else | |||
| return gen_array_ops.identity(data, name: name); | |||
| } | |||
| public static (Tensor, Tensor) cond(Tensor pred, | |||
| Action true_fn = null, | |||
| Action false_fn = null, | |||
| bool strict = false, | |||
| string name = null) | |||
| { | |||
| return with(ops.name_scope(name, "cond", new { pred }), delegate | |||
| { | |||
| // Add the Switch to the graph. | |||
| var (p_2, p_1) = @switch(pred, pred); | |||
| var pivot_1 = array_ops.identity(p_1, name: "switch_t"); | |||
| var pivot_2 = array_ops.identity(p_2, name: "switch_f"); | |||
| pred = array_ops.identity(pred, name: "pred_id"); | |||
| // Disable the fetching of tensors that are only on one branch of cond. | |||
| foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred }) | |||
| tensor.op.graph.prevent_fetching(tensor.op); | |||
| return (p_2, p_1); | |||
| }); | |||
| } | |||
| /// <summary> | |||
| /// Forwards `data` to an output determined by `pred`. | |||
| /// </summary> | |||
| /// <param name="data"></param> | |||
| /// <param name="pred"></param> | |||
| /// <param name="dtype"></param> | |||
| /// <param name="name"></param> | |||
| public static (Tensor, Tensor) @switch(Tensor data, | |||
| Tensor pred, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| string name = null) | |||
| { | |||
| return with(ops.name_scope(name, "Switch", new { data, pred }), scope => | |||
| { | |||
| name = scope; | |||
| data = ops.internal_convert_to_tensor_or_indexed_slices(data, | |||
| dtype: dtype, | |||
| name: "data", | |||
| as_ref: true); | |||
| pred = ops.convert_to_tensor(pred, name: "pred"); | |||
| return gen_control_flow_ops.@switch(data, pred, name: name); | |||
| }); | |||
| } | |||
| } | |||
| } | |||
| @@ -42,7 +42,7 @@ namespace Tensorflow | |||
| public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null) | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("Placeholder", args: new { dtype, shape }); | |||
| var _op = _op_def_lib._apply_op_helper("Placeholder", name: name, args: new { dtype, shape }); | |||
| var _result = _op.outputs; | |||
| var _inputs_flat = _op.inputs; | |||
| @@ -14,5 +14,12 @@ namespace Tensorflow | |||
| return _op; | |||
| } | |||
| public static (Tensor, Tensor) @switch(Tensor data, Tensor pred, string name = null) | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("Switch", name, new { data, pred }); | |||
| return (_op.outputs[0], _op.outputs[1]); | |||
| } | |||
| } | |||
| } | |||
| @@ -4,6 +4,9 @@ using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| /// Indicates when a distributed variable will be synced. | |||
| /// </summary> | |||
| public enum VariableSynchronization | |||
| { | |||
| AUTO = 0, | |||