@@ -3,6 +3,7 @@ using System.Collections.Generic;
 using System.Text;
 using Tensorflow.Operations;
 using Tensorflow.Operations.Activation;
+using static Tensorflow.Python;
 
 namespace Tensorflow
 {
@@ -101,6 +102,25 @@ namespace Tensorflow
             Tensor logits = null, string name = null)
             => nn_ops.sparse_softmax_cross_entropy_with_logits(labels: labels, logits: logits, name: name);
 
+        /// <summary>
+        /// Computes softmax cross entropy between `logits` and `labels`.
+        /// </summary>
+        /// <param name="labels">Each row of `labels` must be a valid probability distribution (e.g. one-hot).</param>
+        /// <param name="logits">Unscaled log probabilities.</param>
+        /// <param name="dim">The class dimension. Defaults to -1, the last dimension.</param>
+        /// <param name="name">A name for the operation (optional).</param>
+        /// <returns>A `Tensor` containing the softmax cross entropy loss.</returns>
+        public static Tensor softmax_cross_entropy_with_logits(Tensor labels, Tensor logits, int dim = -1, string name = null)
+        {
+            with(ops.name_scope(name, "softmax_cross_entropy_with_logits_sg", new { logits, labels }), scope =>
+            {
+                name = scope;
+                labels = array_ops.stop_gradient(labels, name: "labels_stop_gradient");
+            });
+
+            return softmax_cross_entropy_with_logits_v2(labels, logits, axis: dim, name: name);
+        }
+
         public static Tensor softmax_cross_entropy_with_logits_v2(Tensor labels, Tensor logits, int axis = -1, string name = null)
             => nn_ops.softmax_cross_entropy_with_logits_v2_helper(labels, logits, axis: axis, name: name);
     }
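
For reference, a minimal usage sketch of the new overload, mirroring the MNIST example later in this diff (`y` and `output_logits` stand in for any one-hot labels and raw logits already in scope):

    // Per-example cross entropy from one-hot labels and unscaled logits,
    // reduced to a scalar training loss.
    var xent = tf.nn.softmax_cross_entropy_with_logits(labels: y, logits: output_logits);
    var loss = tf.reduce_mean(xent, name: "loss");

Because the new overload wraps `labels` in `stop_gradient`, gradients flow only into `logits`; callers who also need gradients with respect to the labels should call `softmax_cross_entropy_with_logits_v2` directly.
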
@@ -94,14 +94,28 @@ namespace Tensorflow
                         if (attrs.ContainsKey(input_arg.TypeAttr))
                             dtype = (DataType)attrs[input_arg.TypeAttr];
                         else
-                            if (values is Tensor[] values1)
-                                dtype = values1[0].dtype.as_datatype_enum();
+                            switch (values)
+                            {
+                                case Tensor[] values1:
+                                    dtype = values1[0].dtype.as_datatype_enum();
+                                    break;
+                                case object[] values1:
+                                    foreach (var t in values1)
+                                        if (t is Tensor tensor)
+                                        {
+                                            dtype = tensor.dtype.as_datatype_enum();
+                                            break;
+                                        }
+                                    break;
+                                default:
+                                    throw new NotImplementedException($"can't infer the dtype for {values.GetType()}");
+                            }
 
                         if (dtype == DataType.DtInvalid && default_type_attr_map.ContainsKey(input_arg.TypeAttr))
                             default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr];
                     }
 
-                    if (input_arg.IsRef && dtype != DataType.DtInvalid)
+                    if (!input_arg.IsRef && dtype != DataType.DtInvalid)
                         dtype = dtype.as_base_dtype();
 
                     values = ops.internal_convert_n_to_tensor(values,
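
A standalone sketch of the dtype-inference rule introduced above (`InferDtype` is a hypothetical name used only for illustration; the real logic lives inline in `_apply_op_helper`): a `Tensor[]` takes the first element's dtype, a heterogeneous `object[]` takes the dtype of its first `Tensor` element, and anything else is rejected.

    static DataType InferDtype(object values)
    {
        switch (values)
        {
            case Tensor[] tensors:              // homogeneous tensor input
                return tensors[0].dtype.as_datatype_enum();
            case object[] items:                // mixed constants and tensors
                foreach (var item in items)
                    if (item is Tensor tensor)
                        return tensor.dtype.as_datatype_enum();
                return DataType.DtInvalid;      // no tensor found: leave invalid
            default:
                throw new NotImplementedException($"can't infer the dtype for {values.GetType()}");
        }
    }
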
@@ -17,9 +17,7 @@ namespace Tensorflow
         private Tensor[] _outputs;
         public Tensor[] outputs => _outputs;
 
-#if GRAPH_SERIALIZE
-        [JsonIgnore]
-#endif
         public Tensor output => _outputs.FirstOrDefault();
 
         public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle);
@@ -1,7 +1,4 @@
 using Google.Protobuf.Collections;
-#if GRAPH_SERIALIZE
-using Newtonsoft.Json;
-#endif
 using System;
 using System.Collections.Generic;
 using System.Linq;
@@ -37,21 +34,11 @@ namespace Tensorflow
         private Graph _graph;
 
         public string type => OpType;
 
-#if GRAPH_SERIALIZE
-        [JsonIgnore]
         public Graph graph => _graph;
-        [JsonIgnore]
         public int _id => _id_value;
-        [JsonIgnore]
         public int _id_value;
-        [JsonIgnore]
         public Operation op => this;
-#else
-        public Graph graph => _graph;
-        public int _id => _id_value;
-        public int _id_value;
-        public Operation op => this;
-#endif
         public TF_DataType dtype => TF_DataType.DtInvalid;
 
         private Status status = new Status();
@@ -60,9 +47,6 @@ namespace Tensorflow
         public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle));
 
         private NodeDef _node_def;
-#if GRAPH_SERIALIZE
-        [JsonIgnore]
-#endif
         public NodeDef node_def
         {
             get
@@ -492,13 +492,18 @@ namespace Tensorflow
             {
                 return with(ops.name_scope(name), scope => {
                     var t = ops.convert_to_tensor(axis, name: "concat_dim", dtype: TF_DataType.TF_INT32);
-                    return identity(values[0], name = scope);
+                    return identity(values[0], name: scope);
                 });
             }
 
             return gen_array_ops.concat_v2(values, axis, name: name);
         }
 
+        public static Tensor concat(object[] values, int axis, string name = "concat")
+        {
+            return gen_array_ops.concat_v2(values, axis, name: name);
+        }
+
         public static Tensor gather(Tensor @params, Tensor indices, string name = null, int axis = 0)
             => gen_array_ops.gather_v2(@params, indices, axis, name: name);
@@ -19,7 +19,7 @@ namespace Tensorflow
         /// <param name="axis"></param>
         /// <param name="name"></param>
         /// <returns></returns>
-        public static Tensor concat_v2(Tensor[] values, int axis, string name = null)
+        public static Tensor concat_v2<T>(T[] values, int axis, string name = null)
         {
             var _op = _op_def_lib._apply_op_helper("ConcatV2", name: name, args: new { values, axis });
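
Making `concat_v2` generic is what allows the new `object[]` overload of `concat` to compile; both paths funnel into the same `ConcatV2` op. A hedged usage sketch (`t1`, `t2`, and `last_dim_size` are placeholders for tensors already in scope; the mixed-array form is exactly how `_flatten_outer_dims` builds its target shape below):

    // Tensor[] path: concatenate two tensors along axis 0.
    var joined = array_ops.concat(new[] { t1, t2 }, 0);
    // object[] path: mix an int[] constant with a Tensor to build a shape.
    var new_shape = array_ops.concat(new[] { new[] { -1 }, (object)last_dim_size }, 0);
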
@@ -1,5 +1,6 @@
 using System;
 using System.Collections.Generic;
+using System.Linq;
 using System.Text;
 using Tensorflow.Operations;
 using static Tensorflow.Python;
@@ -159,8 +160,9 @@ namespace Tensorflow
             int axis = -1,
             string name = null)
         {
-            return Python.with(ops.name_scope(name, "softmax_cross_entropy_with_logits", new { }), scope =>
+            return with(ops.name_scope(name, "softmax_cross_entropy_with_logits", new { logits, labels }), scope =>
             {
+                name = scope;
                 var precise_logits = logits;
                 var input_rank = array_ops.rank(precise_logits);
                 var shape = logits.TensorShape;
@@ -170,6 +172,10 @@ namespace Tensorflow
                 var input_shape = array_ops.shape(precise_logits);
 
+                // Make precise_logits and labels into matrices.
+                precise_logits = _flatten_outer_dims(precise_logits);
+                labels = _flatten_outer_dims(labels);
+
                 // Do the actual op computation.
                 // The second output tensor contains the gradients. We use it in
                 // _CrossEntropyGrad() in nn_grad but not here.
@@ -186,5 +192,50 @@ namespace Tensorflow
                 return cost;
             });
         }
+
+        /// <summary>
+        /// Flattens logits' outer dimensions and keeps its last dimension.
+        /// </summary>
+        /// <param name="logits">A tensor of rank >= 1.</param>
+        /// <returns>A rank-2 tensor whose last dimension matches that of `logits`.</returns>
+        private static Tensor _flatten_outer_dims(Tensor logits)
+        {
+            var rank = array_ops.rank(logits);
+            var last_dim_size = array_ops.slice(array_ops.shape(logits),
+                new[] { math_ops.subtract(rank, 1) },
+                new[] { 1 });
+
+            var ops = array_ops.concat(new[] { new[] { -1 }, (object)last_dim_size }, 0);
+            var output = array_ops.reshape(logits, ops);
+
+            // Set output shape if known.
+            // if not context.executing_eagerly():
+            var shape = logits.TensorShape;
+            if (shape != null && shape.NDim > 0)
+            {
+                var product = 1;
+                var product_valid = true;
+                foreach (var d in shape.Dimensions.Take(shape.NDim - 1))
+                {
+                    if (d == -1)
+                    {
+                        product_valid = false;
+                        break;
+                    }
+                    else
+                    {
+                        product *= d;
+                    }
+                }
+
+                if (product_valid)
+                {
+                    var output_shape = new[] { product };
+                    throw new NotImplementedException("_flatten_outer_dims product_valid");
+                }
+            }
+
+            return output;
+        }
     }
 }
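
The shape arithmetic in `_flatten_outer_dims` reduces to: fold every dimension except the last into one. A plain-C# sketch of the static-shape case (`FlattenOuterDims` is a hypothetical helper, for illustration only):

    // [2, 3, 10] -> [6, 10]; a rank-2 shape is returned unchanged.
    static int[] FlattenOuterDims(int[] shape)
    {
        var product = 1;
        for (int i = 0; i < shape.Length - 1; i++)
            product *= shape[i];                // fold all outer dims
        return new[] { product, shape[shape.Length - 1] };
    }

The tensor version above does the same thing symbolically, via `reshape(logits, concat({-1}, last_dim_size))`, so it also covers dimensions that are unknown until run time.
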
@@ -22,21 +22,11 @@ namespace Tensorflow
         private int _id;
         private Operation _op;
 
-#if GRAPH_SERIALIZE
-        [JsonIgnore]
-        public int Id => _id;
-        [JsonIgnore]
-        public Graph graph => op?.graph;
-        [JsonIgnore]
-        public Operation op => _op;
-        [JsonIgnore]
-        public Tensor[] outputs => op.outputs;
-#else
         public int Id => _id;
         public Graph graph => op?.graph;
         public Operation op => _op;
         public Tensor[] outputs => op.outputs;
-#endif
 
         /// <summary>
         /// The string name of this tensor.
@@ -50,18 +40,12 @@ namespace Tensorflow
         private TF_DataType _dtype = TF_DataType.DtInvalid;
         public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle);
 
-#if GRAPH_SERIALIZE
-        [JsonIgnore]
-#endif
         public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle);
-#if GRAPH_SERIALIZE
-        [JsonIgnore]
-#endif
         public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype);
         public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize;
 
-#if GRAPH_SERIALIZE
-        [JsonIgnore]
-#endif
         public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle);
 
         public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out);
@@ -70,9 +54,6 @@ namespace Tensorflow
         /// <summary>
         /// used to keep another pointer when doing implicit operations
         /// </summary>
-#if GRAPH_SERIALIZE
-        [JsonIgnore]
-#endif
         public object Tag { get; set; }
 
         public int[] shape
@@ -140,9 +121,7 @@ namespace Tensorflow
                 }
             }
         }
 
-#if GRAPH_SERIALIZE
-        [JsonIgnore]
-#endif
         public int NDims => rank;
 
         public string Device => op.Device;
@@ -110,7 +110,7 @@ namespace Tensorflow.Train
                 var update_beta2 = beta2_power.assign(beta2_power * _beta2_t, use_locking: _use_locking);
 
                 operations.Add(update_beta1);
-                operations.Add(update_beta1);
+                operations.Add(update_beta2);
             });
 
             return control_flow_ops.group(operations.ToArray(), name: name_scope);
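
This one-line fix matters because Adam's bias correction relies on both running powers: with `update_beta1` queued twice and `update_beta2` never queued, `beta2_power` stayed at its initial value. A minimal sketch of the correction that depends on it (the standard Adam formula; `BiasCorrectedLr` is a hypothetical illustrative helper):

    // alpha_t = lr * sqrt(1 - beta2^t) / (1 - beta1^t)
    static double BiasCorrectedLr(double lr, double beta1Power, double beta2Power)
        => lr * Math.Sqrt(1 - beta2Power) / (1 - beta1Power);
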
@@ -49,8 +49,6 @@ namespace TensorFlowNET.Examples.ImageProcess
         public Graph BuildGraph()
         {
-            var g = tf.Graph();
-
             // Placeholders for inputs (x) and outputs(y)
             x = tf.placeholder(tf.float32, shape: (-1, img_size_flat), name: "X");
             y = tf.placeholder(tf.float32, shape: (-1, n_classes), name: "Y");
@@ -60,7 +58,8 @@ namespace TensorFlowNET.Examples.ImageProcess
             // Create a fully-connected layer with n_classes nodes as output layer
             var output_logits = fc_layer(fc1, n_classes, "OUT", use_relu: false);
             // Define the loss function, optimizer, and accuracy
-            loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels: y, logits: output_logits), name: "loss");
+            var logits = tf.nn.softmax_cross_entropy_with_logits(labels: y, logits: output_logits);
+            loss = tf.reduce_mean(logits, name: "loss");
             optimizer = tf.train.AdamOptimizer(learning_rate: learning_rate, name: "Adam-op").minimize(loss);
             var correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name: "correct_pred");
             accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name: "accuracy");
@@ -68,7 +67,7 @@ namespace TensorFlowNET.Examples.ImageProcess
             // Network predictions
             var cls_prediction = tf.argmax(output_logits, axis: 1, name: "predictions");
 
-            return g;
+            return tf.get_default_graph();
         }
 
         private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true)
@@ -93,16 +92,10 @@ namespace TensorFlowNET.Examples.ImageProcess
             return layer;
         }
 
-        public Graph ImportGraph()
-        {
-            throw new NotImplementedException();
-        }
-
-        public bool Predict()
-        {
-            throw new NotImplementedException();
-        }
+        public Graph ImportGraph() => throw new NotImplementedException();
+
+        public bool Predict() => throw new NotImplementedException();
 
         public void PrepareData()
         {
             mnist = MnistDataSet.read_data_sets("mnist", one_hot: true);
@@ -112,7 +105,6 @@ namespace TensorFlowNET.Examples.ImageProcess
         {
             // Number of training iterations in each epoch
             var num_tr_iter = mnist.train.labels.len / batch_size;
-
             return with(tf.Session(), sess =>
             {
                 var init = tf.global_variables_initializer();
@@ -153,10 +145,9 @@ namespace TensorFlowNET.Examples.ImageProcess
                     print("---------------------------------------------------------");
                     print($"Epoch: {epoch + 1}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}");
                     print("---------------------------------------------------------");
                 }
 
-                return accuracy_val > 0.9;
+                return accuracy_val > 0.95;
             });
         }
@@ -0,0 +1,164 @@
+# imports
+import tensorflow as tf
+import numpy as np
+import matplotlib.pyplot as plt
+
+img_h = img_w = 28              # MNIST images are 28x28
+img_size_flat = img_h * img_w   # 28x28=784, the total number of pixels
+n_classes = 10                  # Number of classes, one class per digit
+
+
+def load_data(mode='train'):
+    """
+    Function to (download and) load the MNIST data
+    :param mode: train or test
+    :return: images and the corresponding labels
+    """
+    from tensorflow.examples.tutorials.mnist import input_data
+    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
+    if mode == 'train':
+        x_train, y_train, x_valid, y_valid = mnist.train.images, mnist.train.labels, \
+                                             mnist.validation.images, mnist.validation.labels
+        return x_train, y_train, x_valid, y_valid
+    elif mode == 'test':
+        x_test, y_test = mnist.test.images, mnist.test.labels
+        return x_test, y_test
+
+
+def randomize(x, y):
+    """ Randomizes the order of data samples and their corresponding labels"""
+    permutation = np.random.permutation(y.shape[0])
+    shuffled_x = x[permutation, :]
+    shuffled_y = y[permutation]
+    return shuffled_x, shuffled_y
+
+
+def get_next_batch(x, y, start, end):
+    x_batch = x[start:end]
+    y_batch = y[start:end]
+    return x_batch, y_batch
+
+
+# Load MNIST data
+x_train, y_train, x_valid, y_valid = load_data(mode='train')
+print("Size of:")
+print("- Training-set:\t\t{}".format(len(y_train)))
+print("- Validation-set:\t{}".format(len(y_valid)))
+print('x_train:\t{}'.format(x_train.shape))
+print('y_train:\t{}'.format(y_train.shape))
+print('x_valid:\t{}'.format(x_valid.shape))
+print('y_valid:\t{}'.format(y_valid.shape))
+print(y_valid[:5, :])
+
+# Hyper-parameters
+epochs = 10             # Total number of training epochs
+batch_size = 100        # Training batch size
+display_freq = 100      # Frequency of displaying the training results
+learning_rate = 0.001   # The optimization initial learning rate
+h1 = 200                # number of nodes in the 1st hidden layer
+
+# weight and bias wrappers
+def weight_variable(name, shape):
+    """
+    Create a weight variable with appropriate initialization
+    :param name: weight name
+    :param shape: weight shape
+    :return: initialized weight variable
+    """
+    initer = tf.truncated_normal_initializer(stddev=0.01)
+    return tf.get_variable('W_' + name,
+                           dtype=tf.float32,
+                           shape=shape,
+                           initializer=initer)
+
+
+def bias_variable(name, shape):
+    """
+    Create a bias variable with appropriate initialization
+    :param name: bias variable name
+    :param shape: bias variable shape
+    :return: initialized bias variable
+    """
+    initial = tf.constant(0., shape=shape, dtype=tf.float32)
+    return tf.get_variable('b_' + name,
+                           dtype=tf.float32,
+                           initializer=initial)
+
+
+def fc_layer(x, num_units, name, use_relu=True):
+    """
+    Create a fully-connected layer
+    :param x: input from previous layer
+    :param num_units: number of hidden units in the fully-connected layer
+    :param name: layer name
+    :param use_relu: boolean to add ReLU non-linearity (or not)
+    :return: The output array
+    """
+    in_dim = x.get_shape()[1]
+    W = weight_variable(name, shape=[in_dim, num_units])
+    b = bias_variable(name, [num_units])
+    layer = tf.matmul(x, W)
+    layer += b
+    if use_relu:
+        layer = tf.nn.relu(layer)
+    return layer
+
+
+# Create the graph for the linear model
+# Placeholders for inputs (x) and outputs(y)
+x = tf.placeholder(tf.float32, shape=[None, img_size_flat], name='X')
+y = tf.placeholder(tf.float32, shape=[None, n_classes], name='Y')
+
+# Create a fully-connected layer with h1 nodes as hidden layer
+fc1 = fc_layer(x, h1, 'FC1', use_relu=True)
+# Create a fully-connected layer with n_classes nodes as output layer
+output_logits = fc_layer(fc1, n_classes, 'OUT', use_relu=False)
+# Define the loss function, optimizer, and accuracy
+logits = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=output_logits)
+loss = tf.reduce_mean(logits, name='loss')
+optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, name='Adam-op').minimize(loss)
+correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name='correct_pred')
+accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy')
+# Network predictions
+cls_prediction = tf.argmax(output_logits, axis=1, name='predictions')
+
+# export graph
+#tf.train.export_meta_graph(filename='neural_network.meta', graph=tf.get_default_graph(), clear_extraneous_savers= True, as_text = True)
+
+# Create the op for initializing all variables
+init = tf.global_variables_initializer()
+
+# Create an interactive session (to keep the session in the other cells)
+sess = tf.InteractiveSession()
+# Initialize all variables
+sess.run(init)
+
+# Number of training iterations in each epoch
+num_tr_iter = int(len(y_train) / batch_size)
+for epoch in range(epochs):
+    print('Training epoch: {}'.format(epoch + 1))
+    # Randomly shuffle the training data at the beginning of each epoch
+    x_train, y_train = randomize(x_train, y_train)
+    for iteration in range(num_tr_iter):
+        start = iteration * batch_size
+        end = (iteration + 1) * batch_size
+        x_batch, y_batch = get_next_batch(x_train, y_train, start, end)
+
+        # Run optimization op (backprop)
+        feed_dict_batch = {x: x_batch, y: y_batch}
+        sess.run(optimizer, feed_dict=feed_dict_batch)
+
+        if iteration % display_freq == 0:
+            # Calculate and display the batch loss and accuracy
+            loss_batch, acc_batch = sess.run([loss, accuracy],
+                                             feed_dict=feed_dict_batch)
+            print("iter {0:3d}:\t Loss={1:.2f},\tTraining Accuracy={2:.01%}".
+                  format(iteration, loss_batch, acc_batch))
+
+    # Run validation after every epoch
+    feed_dict_valid = {x: x_valid[:1000], y: y_valid[:1000]}
+    loss_valid, acc_valid = sess.run([loss, accuracy], feed_dict=feed_dict_valid)
+    print('---------------------------------------------------------')
+    print("Epoch: {0}, validation loss: {1:.2f}, validation accuracy: {2:.01%}".
+          format(epoch + 1, loss_valid, acc_valid))
+    print('---------------------------------------------------------')