From fb059e5bda0bae3c927676261196cc3e6f1aa9ce Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Thu, 21 Dec 2023 10:32:52 +0800 Subject: [PATCH] [MNT] improve MNIST performance --- examples/mnist_add/mnist_add.ipynb | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/mnist_add/mnist_add.ipynb b/examples/mnist_add/mnist_add.ipynb index 6f407aa..a69ab22 100644 --- a/examples/mnist_add/mnist_add.ipynb +++ b/examples/mnist_add/mnist_add.ipynb @@ -181,7 +181,7 @@ "source": [ "cls = LeNet5(num_classes=10)\n", "loss_fn = nn.CrossEntropyLoss()\n", - "optimizer = torch.optim.Adam(cls.parameters(), lr=0.001)\n", + "optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, alpha=0.9)\n", "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "base_model = BasicNN(\n", @@ -357,7 +357,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -390,7 +390,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -409,7 +409,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -455,7 +455,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.8.18" }, "orig_nbformat": 4, "vscode": {