| @@ -19,5 +19,16 @@ namespace Tensorflow | |||
| /// </returns> | |||
| public static Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1) | |||
| => array_ops.expand_dims(input, axis, name, dim); | |||
| /// <summary> | |||
| /// Transposes `a`. Permutes the dimensions according to `perm`. | |||
| /// </summary> | |||
| /// <param name="a"></param> | |||
| /// <param name="perm"></param> | |||
| /// <param name="name"></param> | |||
| /// <param name="conjugate"></param> | |||
| /// <returns></returns> | |||
| public static Tensor transpose(Tensor a, int[] perm = null, string name = "transpose", bool conjugate = false) | |||
| => array_ops.transpose(a, perm, name, conjugate); | |||
| } | |||
| } | |||
| @@ -46,6 +46,45 @@ namespace Tensorflow | |||
| return layer.apply(inputs); | |||
| } | |||
| /// <summary> | |||
| /// Functional interface for the batch normalization layer. | |||
| /// http://arxiv.org/abs/1502.03167 | |||
| /// </summary> | |||
| /// <param name="inputs"></param> | |||
| /// <param name="axis"></param> | |||
| /// <param name="momentum"></param> | |||
| /// <param name="epsilon"></param> | |||
| /// <param name="center"></param> | |||
| /// <param name="scale"></param> | |||
| /// <param name="beta_initializer"></param> | |||
| /// <param name="gamma_initializer"></param> | |||
| /// <param name="moving_mean_initializer"></param> | |||
| /// <param name="moving_variance_initializer"></param> | |||
| /// <param name="training"></param> | |||
| /// <param name="trainable"></param> | |||
| /// <param name="name"></param> | |||
| /// <param name="renorm"></param> | |||
| /// <param name="renorm_momentum"></param> | |||
| /// <returns></returns> | |||
| public static Tensor batch_normalization(Tensor inputs, | |||
| int axis = -1, | |||
| float momentum = 0.99f, | |||
| float epsilon = 0.001f, | |||
| bool center = true, | |||
| bool scale = true, | |||
| IInitializer beta_initializer = null, | |||
| IInitializer gamma_initializer = null, | |||
| IInitializer moving_mean_initializer = null, | |||
| IInitializer moving_variance_initializer = null, | |||
| Tensor training = null, | |||
| bool trainable = true, | |||
| string name = null, | |||
| bool renorm = false, | |||
| float renorm_momentum = 0.99f) | |||
| { | |||
| throw new NotImplementedException("batch_normalization"); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -30,6 +30,7 @@ namespace Tensorflow.Keras.Engine | |||
| VariableScope scope = null) | |||
| { | |||
| var input_list = new Tensor[] { inputs }; | |||
| Tensor outputs = null; | |||
| // We will attempt to build a TF graph if & only if all inputs are symbolic. | |||
| // This is always the case in graph mode. It can also be the case in eager | |||
| @@ -45,9 +46,42 @@ namespace Tensorflow.Keras.Engine | |||
| _maybe_build(inputs); | |||
| built = true; | |||
| } | |||
| if (build_graph) | |||
| { | |||
| // Symbolic execution on symbolic tensors. We will attempt to build | |||
| // the corresponding TF subgraph inside `backend.get_graph()` | |||
| var graph = backend.get_graph(); | |||
| outputs = call(inputs); | |||
| _handle_activity_regularization(inputs, outputs); | |||
| _set_mask_metadata(inputs, outputs, null); | |||
| } | |||
| }); | |||
| throw new NotImplementedException(""); | |||
| return outputs; | |||
| } | |||
| private void _handle_activity_regularization(Tensor inputs, Tensor outputs) | |||
| { | |||
| //if(_activity_regularizer != null) | |||
| { | |||
| } | |||
| } | |||
| private void _set_mask_metadata(Tensor inputs, Tensor outputs, Tensor previous_mask) | |||
| { | |||
| } | |||
| private Tensor compute_mask(Tensor inputs, Tensor mask = null) | |||
| { | |||
| return null; | |||
| } | |||
| protected virtual Tensor call(Tensor inputs) | |||
| { | |||
| throw new NotImplementedException("Layer.call"); | |||
| } | |||
| protected virtual string _name_scope() | |||
| @@ -90,5 +90,26 @@ namespace Tensorflow.Keras.Layers | |||
| built = true; | |||
| } | |||
| protected override Tensor call(Tensor inputs) | |||
| { | |||
| var outputs = _convolution_op.__call__(inputs, kernel); | |||
| if (use_bias) | |||
| { | |||
| if (data_format == "channels_first") | |||
| { | |||
| throw new NotImplementedException("call channels_first"); | |||
| } | |||
| else | |||
| { | |||
| outputs = nn_ops.bias_add(outputs, bias, data_format: "NHWC"); | |||
| } | |||
| } | |||
| if (activation != null) | |||
| return activation.Activate(outputs); | |||
| return outputs; | |||
| } | |||
| } | |||
| } | |||
| @@ -10,5 +10,10 @@ namespace Tensorflow.Keras | |||
| { | |||
| } | |||
| public static Graph get_graph() | |||
| { | |||
| return ops.get_default_graph(); | |||
| } | |||
| } | |||
| } | |||
| @@ -65,7 +65,10 @@ namespace Tensorflow.Layers | |||
| // Actually call layer | |||
| var outputs = base.__call__(inputs); | |||
| throw new NotImplementedException(""); | |||
| // Update global default collections. | |||
| //_add_elements_to_collection(updates, ops.GraphKeys.UPDATE_OPS); | |||
| return outputs; | |||
| } | |||
| protected virtual RefVariable add_weight(string name, | |||
| @@ -6,6 +6,6 @@ namespace Tensorflow.Operations.Activation | |||
| { | |||
| public interface IActivation | |||
| { | |||
| Tensor Activate(Tensor features, string name = null); | |||
| } | |||
| } | |||
| @@ -6,6 +6,16 @@ namespace Tensorflow.Operations.Activation | |||
| { | |||
| public class relu : IActivation | |||
| { | |||
| public Tensor Activate(Tensor features, string name = null) | |||
| { | |||
| OpDefLibrary _op_def_lib = new OpDefLibrary(); | |||
| var _op = _op_def_lib._apply_op_helper("Relu", name: name, args: new | |||
| { | |||
| features | |||
| }); | |||
| return _op.outputs[0]; | |||
| } | |||
| } | |||
| } | |||
| @@ -62,5 +62,10 @@ namespace Tensorflow.Operations | |||
| strides: strides, | |||
| name: name); | |||
| } | |||
| public Tensor __call__(Tensor inp, RefVariable filter) | |||
| { | |||
| return conv_op.__call__(inp, filter); | |||
| } | |||
| } | |||
| } | |||
| @@ -52,5 +52,18 @@ namespace Tensorflow.Operations | |||
| throw new NotImplementedException("_NonAtrousConvolution conv_dims 3"); | |||
| } | |||
| } | |||
| public Tensor __call__(Tensor inp, RefVariable filter) | |||
| { | |||
| return conv_op(new | |||
| { | |||
| input = inp, | |||
| filter, | |||
| strides, | |||
| padding, | |||
| data_format, | |||
| name | |||
| }); | |||
| } | |||
| } | |||
| } | |||
| @@ -51,5 +51,10 @@ namespace Tensorflow.Operations | |||
| } | |||
| } | |||
| } | |||
| public Tensor __call__(Tensor inp, RefVariable filter) | |||
| { | |||
| return call.__call__(inp, filter); | |||
| } | |||
| } | |||
| } | |||
| @@ -6,9 +6,51 @@ namespace Tensorflow.Operations | |||
| { | |||
| public class gen_nn_ops | |||
| { | |||
| public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | |||
| public static Tensor conv2d(object parameters) | |||
| { | |||
| throw new NotImplementedException("gen_nn_op.conv2d"); | |||
| var args = Python.ConvertToDict(parameters); | |||
| var input = args["input"]; | |||
| var filter = args["filter"]; | |||
| var strides = args["strides"]; | |||
| var padding = args["padding"]; | |||
| var name = args["name"]; | |||
| var data_format = args.ContainsKey("data_format") ? args["data_format"] : "NHWC"; | |||
| var use_cudnn_on_gpu = args.ContainsKey("use_cudnn_on_gpu") ? args["use_cudnn_on_gpu"] : true; | |||
| var dilations = args.ContainsKey("dilations") ? args["dilations"] : new int[] { 1, 1, 1, 1 }; | |||
| var _op = _op_def_lib._apply_op_helper("Conv2D", name: name?.ToString(), args: new | |||
| { | |||
| input, | |||
| filter, | |||
| strides, | |||
| padding, | |||
| use_cudnn_on_gpu, | |||
| data_format, | |||
| dilations | |||
| }); | |||
| return _op.outputs[0]; | |||
| } | |||
| public static Tensor bias_add(Tensor value, | |||
| Tensor bias, | |||
| string data_format = null, | |||
| string name = null) | |||
| { | |||
| if (data_format == null) | |||
| data_format = "NHWC"; | |||
| var _op = _op_def_lib._apply_op_helper("BiasAdd", name: name, args: new | |||
| { | |||
| value, | |||
| bias, | |||
| data_format | |||
| }); | |||
| return _op.outputs[0]; | |||
| } | |||
| } | |||
| } | |||
| @@ -272,5 +272,14 @@ namespace Tensorflow | |||
| { | |||
| return gen_array_ops.gather_v2(@params, indices, axis, name: name); | |||
| } | |||
| public static Tensor transpose(Tensor a, int[] perm = null, string name = "transpose", bool conjugate = false) | |||
| { | |||
| return with(ops.name_scope(name, "transpose", new { a }), scope => | |||
| { | |||
| name = scope; | |||
| return gen_array_ops.transpose(a, perm, name); | |||
| }); | |||
| } | |||
| } | |||
| } | |||
| @@ -157,6 +157,12 @@ namespace Tensorflow | |||
| return _op.outputs[0]; | |||
| } | |||
| public static Tensor transpose(Tensor x, int[] perm, string name = null) | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("Transpose", name, new { x, perm }); | |||
| return _op.outputs[0]; | |||
| } | |||
| public static Tensor zeros_like(Tensor x, string name = null) | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("ZerosLike", name, new { x }); | |||
| @@ -20,5 +20,26 @@ namespace Tensorflow | |||
| dilation_rate, | |||
| name: name, | |||
| data_format: data_format); | |||
| /// <summary> | |||
| /// Adds `bias` to `value`. | |||
| /// </summary> | |||
| /// <param name="value"></param> | |||
| /// <param name="bias"></param> | |||
| /// <param name="data_format"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public static Tensor bias_add(Tensor value, | |||
| RefVariable bias, | |||
| string data_format = null, | |||
| string name = null) | |||
| { | |||
| return Python.with(ops.name_scope(name, "BiasAdd", new { value, bias }), scope => | |||
| { | |||
| value = ops.convert_to_tensor(value, name: "input"); | |||
| var bias_tensor = ops.convert_to_tensor(bias, dtype: value.dtype, name: "bias"); | |||
| return gen_nn_ops.bias_add(value, bias_tensor, data_format: data_format, name: name); | |||
| }); | |||
| } | |||
| } | |||
| } | |||
| @@ -40,6 +40,10 @@ namespace Tensorflow | |||
| /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. | |||
| /// </summary> | |||
| public static string SAVEABLE_OBJECTS = "saveable_objects"; | |||
| /// <summary> | |||
| /// Key to collect update_ops | |||
| /// </summary> | |||
| public static string UPDATE_OPS = "update_ops"; | |||
| } | |||
| } | |||
| } | |||
| @@ -1,5 +1,6 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow; | |||
| @@ -43,14 +44,61 @@ namespace TensorFlowNET.Examples.TextClassification | |||
| x_expanded = tf.expand_dims(x_emb, -1); | |||
| }); | |||
| Tensor conv0 = null; | |||
| Tensor conv1 = null; | |||
| // First Convolution Layer | |||
| with(tf.variable_scope("conv-0"), delegate | |||
| { | |||
| var conv0 = tf.layers.conv2d(x_expanded, | |||
| conv0 = tf.layers.conv2d(x_expanded, | |||
| filters: num_filters[0], | |||
| kernel_size: new int[] { filter_sizes[0], embedding_size }, | |||
| kernel_initializer: cnn_initializer, | |||
| activation: tf.nn.relu); | |||
| conv0 = tf.transpose(conv0, new int[] { 0, 1, 3, 2 }); | |||
| }); | |||
| with(tf.name_scope("conv-block-1"), delegate { | |||
| conv1 = conv_block(conv0, 1); | |||
| }); | |||
| } | |||
| private Tensor conv_block(Tensor input, int i, bool max_pool = true) | |||
| { | |||
| return with(tf.variable_scope($"conv-block-{i}"), delegate | |||
| { | |||
| Tensor conv = null; | |||
| // Two "conv-batch_norm-relu" layers. | |||
| foreach (var j in Enumerable.Range(0, 2)) | |||
| { | |||
| with(tf.variable_scope($"conv-{j}"), delegate | |||
| { | |||
| // convolution | |||
| conv = tf.layers.conv2d( | |||
| input, | |||
| filters: num_filters[i], | |||
| kernel_size: new int[] { filter_sizes[i], num_filters[i - 1] }, | |||
| kernel_initializer: cnn_initializer, | |||
| activation: null); | |||
| // batch normalization | |||
| conv = tf.layers.batch_normalization(conv, training: is_training); | |||
| // relu | |||
| conv = tf.nn.relu.Activate(conv); | |||
| conv = tf.transpose(conv, new int[] { 0, 1, 3, 2 }); | |||
| }); | |||
| } | |||
| if (max_pool) | |||
| { | |||
| // Max pooling | |||
| throw new NotImplementedException("conv_block"); | |||
| } | |||
| else | |||
| { | |||
| return conv; | |||
| } | |||
| }); | |||
| } | |||
| } | |||