| @@ -112,46 +112,40 @@ Run this example in [Jupyter Notebook](https://github.com/SciSharp/SciSharpCube) | |||||
| Toy version of `ResNet` in `Keras` functional API: | Toy version of `ResNet` in `Keras` functional API: | ||||
| ```csharp | ```csharp | ||||
| var layers = new LayersApi(); | |||||
| // input layer | // input layer | ||||
| var inputs = keras.Input(shape: (32, 32, 3), name: "img"); | var inputs = keras.Input(shape: (32, 32, 3), name: "img"); | ||||
| // convolutional layer | // convolutional layer | ||||
| var x = layers.Conv2D(32, 3, activation: "relu").Apply(inputs); | var x = layers.Conv2D(32, 3, activation: "relu").Apply(inputs); | ||||
| x = layers.Conv2D(64, 3, activation: "relu").Apply(x); | x = layers.Conv2D(64, 3, activation: "relu").Apply(x); | ||||
| var block_1_output = layers.MaxPooling2D(3).Apply(x); | var block_1_output = layers.MaxPooling2D(3).Apply(x); | ||||
| x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(block_1_output); | x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(block_1_output); | ||||
| x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(x); | x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(x); | ||||
| var block_2_output = layers.add(x, block_1_output); | |||||
| var block_2_output = layers.Add().Apply(new Tensors(x, block_1_output)); | |||||
| x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(block_2_output); | x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(block_2_output); | ||||
| x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(x); | x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(x); | ||||
| var block_3_output = layers.add(x, block_2_output); | |||||
| var block_3_output = layers.Add().Apply(new Tensors(x, block_2_output)); | |||||
| x = layers.Conv2D(64, 3, activation: "relu").Apply(block_3_output); | x = layers.Conv2D(64, 3, activation: "relu").Apply(block_3_output); | ||||
| x = layers.GlobalAveragePooling2D().Apply(x); | x = layers.GlobalAveragePooling2D().Apply(x); | ||||
| x = layers.Dense(256, activation: "relu").Apply(x); | x = layers.Dense(256, activation: "relu").Apply(x); | ||||
| x = layers.Dropout(0.5f).Apply(x); | x = layers.Dropout(0.5f).Apply(x); | ||||
| // output layer | // output layer | ||||
| var outputs = layers.Dense(10).Apply(x); | var outputs = layers.Dense(10).Apply(x); | ||||
| // build keras model | // build keras model | ||||
| model = keras.Model(inputs, outputs, name: "toy_resnet"); | |||||
| var model = keras.Model(inputs, outputs, name: "toy_resnet"); | |||||
| model.summary(); | model.summary(); | ||||
| // compile keras model in tensorflow static graph | // compile keras model in tensorflow static graph | ||||
| model.compile(optimizer: keras.optimizers.RMSprop(1e-3f), | model.compile(optimizer: keras.optimizers.RMSprop(1e-3f), | ||||
| loss: keras.losses.CategoricalCrossentropy(from_logits: true), | |||||
| metrics: new[] { "acc" }); | |||||
| loss: keras.losses.CategoricalCrossentropy(from_logits: true), | |||||
| metrics: new[] { "acc" }); | |||||
| // prepare dataset | // prepare dataset | ||||
| var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data(); | var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data(); | ||||
| x_train = x_train / 255.0f; | |||||
| y_train = np_utils.to_categorical(y_train, 10); | |||||
| // training | // training | ||||
| model.fit(x_train[new Slice(0, 1000)], y_train[new Slice(0, 1000)], | |||||
| batch_size: 64, | |||||
| epochs: 10, | |||||
| model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)], | |||||
| batch_size: 64, | |||||
| epochs: 10, | |||||
| validation_split: 0.2f); | validation_split: 0.2f); | ||||
| ``` | ``` | ||||
| @@ -260,4 +254,4 @@ WeChat Sponsor 微信打赏: | |||||
| TensorFlow.NET is a part of [SciSharp STACK](https://scisharp.github.io/SciSharp/) | TensorFlow.NET is a part of [SciSharp STACK](https://scisharp.github.io/SciSharp/) | ||||
| <br> | <br> | ||||
| <a href="http://scisharpstack.org"><img src="https://github.com/SciSharp/SciSharp/blob/master/art/scisharp-stack.png" width="391" height="100" /></a> | |||||
| <a href="http://scisharpstack.org"><img src="https://github.com/SciSharp/SciSharp/blob/master/art/scisharp-stack.png" width="391" height="100" /></a> | |||||