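{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# kNN Classification Algorithm\n", "\n", "\n", "The k-Nearest Neighbor (kNN) classification algorithm is a theoretically mature method and one of the simplest machine learning algorithms. Its idea is: ***if most of the k samples most similar to a given sample (i.e. its nearest neighbors in feature space) belong to a certain class, then the sample is assigned to that class as well***.\n", "\n", "Although kNN also relies on [limit theorems](https://baike.baidu.com/item/%E6%9E%81%E9%99%90%E5%AE%9A%E7%90%86/13672616) in principle, the class decision depends on only a very small number of neighboring samples. Because kNN determines the class mainly from a limited set of nearby samples rather than from discriminating class regions, it is better suited than other methods to sample sets whose class regions intersect or overlap heavily.\n", "\n", "kNN can be used not only for classification but also for regression. By finding the `k` nearest neighbors of a sample and assigning it the average of those neighbors' attribute values, we obtain the attribute of the sample. A more useful approach gives neighbors at different distances different weights, e.g. a combination function whose weights are inversely proportional to distance, so that closer neighbors have more influence.\n", "\n", "Problems with the algorithm:\n", "1. When the samples are imbalanced, e.g. one class has many samples while the others have few, the K neighbors of a new input sample may be dominated by the large class, which can lead to misclassification. We therefore need to reduce the effect of class size on the result, for instance by weighting (neighbors closer to the sample get larger weights).\n", "2. The computational cost is high, because for every sample to be classified the distance to all known samples must be computed in order to find its K nearest neighbors. A common remedy is to edit the known samples in advance, removing those that contribute little to classification. The algorithm is better suited to automatic classification of classes with relatively large sample sizes; classes with few samples are more prone to misclassification.\n", "\n", "kNN is arguably the most direct way to classify unknown data. The figure and description below essentially explain what kNN does.\n", "![knn](images/knn.png)\n", "\n", "In short, kNN can be viewed as follows: **you have a set of data whose classes are already known; when a new data point arrives, compute its distance to every point in the training data, pick the K nearest training points, look at which classes they belong to, and assign the new point to a class by majority vote**.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Algorithm steps\n", "\n", "Input:\n", "* Training data: $T=\\{(x_1,y_1),(x_2,y_2), ..., (x_N,y_N)\\}$, where $x_i \\in X=R^n$, $y_i \\in Y = \\{0, 1, ..., K-1\\}$, $i=1,2,...,N$\n", "* User input data: $x_u$\n", "\n", "Output: the predicted best class $y_{pred}$\n", "\n", "\n", "1. Prepare the data and preprocess it;\n", "2. Compute the **distance** between the test sample and each training sample;\n", "3. Sort by increasing distance;\n", "4. Select the `k` points with the smallest distances;\n", "5. Count how often each class occurs among these `k` points;\n", "6. 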
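Return the most frequent class among these `k` points as the predicted class of the test sample.\n", "\n", "\n", "\n", "**Further thinking:**\n", "* What are the difficult parts of the procedure above?\n", "* How can each step be expressed in a programming language?\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.1 Distance computation\n", "\n", "There are several ways to measure the distance between points in a space, such as the common Manhattan distance and Euclidean distance. kNN usually uses the Euclidean distance. As a brief illustration, in the 2D plane the Euclidean distance between two points is:\n", "$$\n", "d = \\sqrt{(x_2 - x_1)^2 + (y_2 - y_1)^2}\n", "$$\n", "\n", "In 2D this is simply the distance between $(x_1,y_1)$ and $(x_2, y_2)$. Extended to n-dimensional space, the formula becomes:\n", "$$\n", "d(p, q) = \\sqrt{ (p_1-q_1)^2 + (p_2-q_2)^2 + ... + (p_n-q_n)^2 } = \\sqrt{ \\sum_{i=1}^{n} (p_i-q_i)^2}\n", "$$\n", "\n", "Now we know how to compute distances. The simplest brute-force kNN computes the distance between the `query point` and `all points`, stores and sorts the results, and then takes the first k to see which class is most common." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## 2. A mental model of machine learning\n", "\n", "![machine learning - methodology](images/ml_methodology.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Generating data" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "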
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "# data generation\n", "np.random.seed(314)\n", "\n", "data_size1 = 10000\n", "x1 = np.random.randn(data_size1, 2) + np.array([4,4])\n", "y1 = [0 for _ in range(data_size1)]\n", "\n", "data_size2 = 10000\n", "x2 = np.random.randn(data_size2, 2)*2 + np.array([10,10])\n", "y2 = [1 for _ in range(data_size2)]\n", "\n", "\n", "# all sample data\n", "x = np.concatenate((x1, x2), axis=0)\n", "y = np.concatenate((y1, y2), axis=0)\n", "\n", "data_size_all = data_size1 + data_size2\n", "shuffled_index = np.random.permutation(data_size_all)\n", "x = x[shuffled_index]\n", "y = y[shuffled_index]\n", "\n", "# split train & test\n", "split_index = int(data_size_all*0.7)\n", "x_train = x[:split_index]\n", "y_train = y[:split_index]\n", "x_test = x[split_index:]\n", "y_test = y[split_index:]\n", "\n", "\n", "# plot data\n", "plt.scatter(x_train[:,0], x_train[:,1], c=y_train, marker='.')\n", "plt.title(\"train data\")\n", "plt.show()\n", "plt.scatter(x_test[:,0], x_test[:,1], c=y_test, marker='.')\n", "plt.title(\"test data\")\n", "plt.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. 
A minimal implementation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import operator\n", "\n", "def knn_distance(v1, v2):\n", "    # squared Euclidean distance; the square root is omitted because it does not change the ordering\n", "    return np.sum(np.square(v1-v2))\n", "\n", "def knn_vote(ys):\n", "    method = 1\n", "\n", "    # method 1: count the votes, then sort the classes by vote count\n", "    if method == 1:\n", "        vote_dict = {}\n", "        for y in ys:\n", "            if y not in vote_dict.keys():\n", "                vote_dict[y] = 1\n", "            else:\n", "                vote_dict[y] += 1\n", "        # alternative key: operator.itemgetter(1)\n", "        sorted_vote_dict = sorted(vote_dict.items(),\n", "                                  key=lambda x: x[1],\n", "                                  reverse=True)\n", "\n", "        return sorted_vote_dict[0][0]\n", "\n", "    # method 2: count the votes, then scan for the class with the most votes\n", "    if method == 2:\n", "        maxv = 0\n", "        maxk = 0\n", "\n", "        vote_dict = {}\n", "        for y in ys:\n", "            if y not in vote_dict.keys():\n", "                vote_dict[y] = 1\n", "            else:\n", "                vote_dict[y] += 1\n", "\n", "        for y in np.unique(ys):\n", "            if maxv < vote_dict[y]:\n", "                maxv = vote_dict[y]\n", "                maxk = y\n", "        return maxk\n", "\n", "def knn_predict(x, train_x, train_y, k=3):\n", "    # distance from x to every training sample\n", "    dist_arr = [knn_distance(x, train_x[j]) for j in range(len(train_x))]\n", "    # indices of the k nearest training samples\n", "    sorted_index = np.argsort(dist_arr)\n", "    top_k_index = sorted_index[:k]\n", "    ys = train_y[top_k_index]\n", "    return knn_vote(ys)\n", "\n", "\n", "#a = knn_predict(x_train[0], x_train, y_train)\n", "\n", "y_train_est = [knn_predict(x_train[i], x_train, y_train) for i in range(len(x_train))]\n", "print(y_train_est)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train Accuracy: 100.000000%\n" ] } ], "source": [ "n_correct = 0\n", "for i in range(len(x_train)):\n", "    if y_train_est[i] == y_train[i]:\n", "        n_correct += 1\n", "accuracy = n_correct / len(x_train) * 100.0\n", "print(\"Train Accuracy: %f%%\" % accuracy)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test Accuracy: 96.666667%\n", "58 60\n" ] 
} ], "source": [ "y_test_est = [knn_predict(x_test[i], x_train, y_train, 3) for i in range(len(x_test))]\n", "n_correct = 0\n", "for i in range(len(x_test)):\n", "    if y_test_est[i] == y_test[i]:\n", "        n_correct += 1\n", "accuracy = n_correct / len(x_test) * 100.0\n", "print(\"Test Accuracy: %f%%\" % accuracy)\n", "print(n_correct, len(x_test))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Implementing kNN as a class" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import operator\n", "\n", "class KNN(object):\n", "\n", "    def __init__(self, k=3):\n", "        self.k = k\n", "\n", "    def fit(self, x, y):\n", "        # kNN is a lazy learner: fitting only stores the training data\n", "        self.x = x\n", "        self.y = y\n", "        return self\n", "\n", "    def _square_distance(self, v1, v2):\n", "        return np.sum(np.square(v1-v2))\n", "\n", "    def _vote(self, ys):\n", "        # count the votes and return the class with the most votes\n", "        vote_dict = {}\n", "        for y in ys:\n", "            if y not in vote_dict.keys():\n", "                vote_dict[y] = 1\n", "            else:\n", "                vote_dict[y] += 1\n", "        sorted_vote_dict = sorted(vote_dict.items(), key=operator.itemgetter(1), reverse=True)\n", "        return sorted_vote_dict[0][0]\n", "\n", "    def predict(self, x):\n", "        y_pred = []\n", "        for i in range(len(x)):\n", "            dist_arr = [self._square_distance(x[i], self.x[j]) for j in range(len(self.x))]\n", "            sorted_index = np.argsort(dist_arr)\n", "            top_k_index = sorted_index[:self.k]\n", "            y_pred.append(self._vote(ys=self.y[top_k_index]))\n", "        return np.array(y_pred)\n", "\n", "    def score(self, y_true=None, y_pred=None):\n", "        # with no arguments, evaluate on the training data\n", "        if y_true is None and y_pred is None:\n", "            y_pred = self.predict(self.x)\n", "            y_true = self.y\n", "        score = 0.0\n", "        for i in range(len(y_true)):\n", "            if y_true[i] == y_pred[i]:\n", "                score += 1\n", "        score /= len(y_true)\n", "        return score" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train accuracy: 100.000000 %\n", "test accuracy: 96.666667 %\n" ] } ], 
"source": [ "# data preprocessing\n", "#x_train = (x_train - np.min(x_train, axis=0)) / (np.max(x_train, axis=0) - np.min(x_train, axis=0))\n", "#x_test = (x_test - np.min(x_test, axis=0)) / (np.max(x_test, axis=0) - np.min(x_test, axis=0))\n", "\n", "# knn classifier\n", "clf = KNN(k=3)\n", "train_acc = clf.fit(x_train, y_train).score() * 100.0\n", "\n", "y_test_pred = clf.predict(x_test)\n", "test_acc = clf.score(y_test, y_test_pred) * 100.0\n", "\n", "print('train accuracy: %f %%' % train_acc)\n", "print('test accuracy: %f %%' % test_acc)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. sklearn program" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Feature dimensions: (1797, 64)\n", "Label dimensions: (1797,)\n" ] } ], "source": [ "#% matplotlib inline\n", "\n", "import matplotlib.pyplot as plt\n", "from sklearn import datasets, neighbors, linear_model\n", "\n", "# load data\n", "digits = datasets.load_digits()\n", "X_digits = digits.data\n", "y_digits = digits.target\n", "\n", "print(\"Feature dimensions: \", X_digits.shape)\n", "print(\"Label dimensions: \", y_digits.shape)\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAWoAAABLCAYAAABZX83EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAUfElEQVR4nO2deZwV1ZXHv6c3GrqhBRob2QTEFiFRVELUBHEZI8aZgJpPNJqYMSoJDE6Mmo0xH0liJBMT0bgQMUjc4pL5BJ24O1FQFJeOEAhKE1lkX5q19+2d+aNev6r7pJvmvdevKnK+n8/79L3v1rv161u3blWdOudeUVUMwzCM6JITtgDDMAyjY2ygNgzDiDg2UBuGYUQcG6gNwzAijg3UhmEYEccGasMwjIhjA7VhGEbEicRALSJ9RGSBiNSKyEciclkIGqaLSIWINIrI77O9/4CObiIyL94O1SKyTETOD0nLIyKyVUT2i8hqEbk6DB0BPceKSIOIPBLS/hfG918T/1SGoSOu5VIR+SB+zqwRkfFZ3n9N0qdVRO7KpoaAlqEi8pyI7BGRbSJyt4jkhaDjeBF5RUT2iciHInJhpuqOxEAN3AM0AWXA5cAcERmdZQ1bgFuAB7K832TygI3ABKAEuAl4UkSGhqBlFjBUVXsBXwJuEZFTQtDRxj3AuyHuH2C6qhbHP8eFIUBEzgX+G7gS6AmcAazNpoZAGxQD/YF64I/Z1BDgXmAHcBQwBu/cmZZNAfELw9PAM0AfYArwiIiUZ6L+0AdqESkCLgZ+rKo1qroY+F/g69nUoap/UtWngF3Z3O8BdNSq6kxVXa+qMVV9BlgHZH2AVNWVqtrYlo1/jsm2DvDuIIG9wF/C2H/E+AnwU1V9K95HNqvq5hD1XIw3UL4e0v6HAU+qaoOqbgNeALJ9ozcSGADMVtVWVX0FeIMMjWOhD9RAOdCiqqsD3/2N7Dd0JBGRMrw2WhnS/u8VkTpgFbAVeC4EDb2AnwLXZ3vfB2CWiFSJyBsicma2dy4iucBYoF/88XpT/FG/e7a1BPgG8JCGNx/FHcClItJDRAYC5+MN1mEjwKcyUVEUBupiYH/Sd/vwHukOa0QkH3gUeFBVV4WhQVWn4R2L8cCfgMaOf9El/AyYp6qbQth3kB8Aw4GBwFzgzyKS7SeMMiAf+DLeMRkDnIRnIss6InI0nqnhwTD2H+c1vBu7/cAmoAJ4KssaKvGeKr4nIvki8gW8dumRicqjMFDXAL2SvusFVIegJTKISA7wMJ7tfnqYWuKPcouBQcDUbO5bRMYA/wLMzuZ+D4Sqvq2q1araqKoP4j3afjHLMurjf+9S1a2qWgXcHoKONr4OLFbVdWHsPH6evIB3E1EElAK98Wz4WUNVm4HJwAXANuAG4Em8C0faRGGgXg3kicixge9OJKRH/SggIgLMw7t7ujjeCaJAHtm3UZ8JDAU2iMg24EbgYhF5L8s6DoTiPd5mb4eqe/BO/qCZIcwpMK8g3LvpPsAQ4O74BXQXMJ8QLlyqulxVJ6hqX1U9D+/p651M1B36QK2qtXhXw5+KSJGIfA6YhHc3mTVEJE9ECoFcIFdECsNw8YkzBzge+DdVrT/Yxl2BiBwZdwErFpFcETkP+CrZf5k3F+/iMCb++S3wLHBeNkWIyBEicl5bvxCRy/G8LcKwhc4Hro0fo97Ad/G8DbKKiJyOZwYKy9uD+BPFOmBq/LgcgWczX55tLSJyQrx/9BCRG/G8UH6fkcpVNfQP3lXxKaAW2ABcFoKGmfieDW2fmSHoODq+7wY8s1Db5/Is6+gHLMLztNgPrACuiUBfmQk8EsJ+++G5BlbH2+Qt4NyQ2iAfzyVtL95j9m+AwhB03Ac8HIE+MQZYCOwBqvBMDmUh6LgtrqEGeB4Ykam6Jb4DwzAMI6KEbvowDMMwOsYGasMwjIhjA7VhGEbE6dRALSITRaQyHgn1w64WZTpMh+kwHZ9UHalw0JeJ8ZDV1cC5eP6b7wJ
fVdX32/tNgXTTQooOWNZS6n7fv//uRHpz7RFOWeEm331YValp2U0PihFyqKOaQorIJZcGamnSxo/5s3ak42PbjvSvWd1yWpyyvdv9IElVpWH31i7TETvC327o4O1O2bZmPy5IVdlTuTdjOpoGut9/qu/ORHp3LNcp21Xpb9vVx0XyfA/J2HD3vkJWN/k6UGrZnzEdwf4AUNtckEjnr2loV2+mdXSkK7mfVr/vl2VaR9MA93sNdInSnm5s2lF5fvuoKisqmxg6LI+8PKhcqXTP7UWu5FHfWk1TrP6QdDQOdQP9Bhf748fGfX2dssKtfhCtqlLTmrl+quUFTj54LJpWxQ74m4PRng7wAhgOxjjgQ1VdCyAij+P5Obc7UBdSxGflnAOWVV18mpP/3g2PJ9I//uskp6z8+q2J9J6mbfxj12JOjs/muC4eUT1MRvK2Hti1tyMdyQx40B+Mj+2xwyl76vazE+maHevZ+eyTXaaj7uzPJtLz7rjdKZu1dWIivXPFDt6+uiJjOtZd6x6Xd74xJ5F+vLq3U/bwhHGJdFcfl9zSIxPp+nvd6SwKzv0okd6ru1jL+xnTEewPAO9sHpJID7q4/VisTOvoSFdyP110gt8+mdax4VunO/mmEn8wuuqcV52yGaX+7K9LKur5wW07+d0j3iD6pVFeOw4vPoUlVQd2v+5Ix+qbxzr5X473x48bnvmaU3bcL/wJBfc2beMfu9/IWHs03Xu0kx/a079gbDk1taDq9nRA50wfA/Gm3WxjU/w7BxGZIt58zhXNXTAdRGNrDYX4HbGQ7jTy8ViQrtbRXLcvEjrqdtRFQkdUjksj9aYjgjo2b2ul/wD/9rswt5iGWG3WdTTEaiPRHqmSsZeJqjpXVceq6th8umWqWtNhOkyH6TjsdCTTGdPHZmBwID8o/l1KBE0dAJf23JNI33FEjVP27HsvJtJLKuo5f2oRVRd4j+gN962iG+6jcDqsr+6TSM8f4k6re/8Z/uIZjQOKaH7VvxI3UJ+WjtiEk5z86/fcl0ivTprhY1LfpYl05YhaVpCejtVzfBPGrLPd4/KpO/151//+nXudsrvGD02km6ug4ZVlaenoiHVTRyTSTX93bX8j8E0f3ehOQ5rtESTY1pDUJ7a42z5VW5xIV75Xy6++kjkde/7dNUm9OMQ3SR3zxLedshG8lUhnuj2SKdjn3+M9f/OZTtnL00Ym0vvrt7B93SJmbfUi/uv2vUsB0Fq/A1XXxt4ZzhzV/qI6v/5Xd+Gfp0/zz62cFXlsvDq99sgd7a8T8eroJ9rfMKl/3Frlri8RNFF1ls7cUb8LHCsiw0SkALgUb2L/rPKZMYU07quicf8uYq0tbGcj/Tgq2zIoGDaIemqo11piGgtNx4gTekRCR3GfwZHQ0YvekdARleMSlfboeVx/qjfup3pLNa3NraHp6Ht8aSTaI1UOeketqi0iMh14EW/CogdUNesz2+XlCYM+fxFrn5uLqjKYQRRLSbZlILm5HMcYlvI6ijKAoaHoyM2TSOiQnGi0R47kcJyGryMqxyUq7SG5OYy78VT+7z9fQmNKGUPCaY+8nEgcl1Tp1Oxwqvocaazs0XK2v4rUpT2XOWXnT7w0kS5Z7s6N/5XF7hvXlknHUDbp+wAMm5re7IHJJof7yu8O5FyXnF4rXFecUjmK0gxdjddOdu1gwcekeX85yylbc8lvnfwcGZGWjpFz/PUaHv7JOKfspkWPJdLJXh/Ff3zbzWewPXLLjnTyX7/IfxP+xHy3PwQfRQHKOI4yzgSgdWV6686+X+++L59c5Ne3utl9GfZfyy938kf330kZJ3o6trueGYfK5Otfabds+FMdv+zKZD8dMvPNdss+nH2qk7+qzD2PF/+ynNPxlg5slfTaY+H77jF/p6R9b5y7PnInNrzqousZw2QAeixw+3BnaC5tfw2AKzf45tGghxDAz0942skvYgSHikUmGoZhRBwbqA3DMCKODdSGYRgRJysrmDT
09Xdz045PO2WxJLt0kHdXZHbVpw0z/eiqp6+8zSkrz28/hHfgS7ucfGsGNQWjpwCe2ODbYZ+/ztV41srLnHxBwD0tFZy2P2GkUxZ0m/zKWtc2nNff7TYt29xQ93QIuuMB3FGyIJFeNNt1a/rgATdKLWefr2vEd9PT8fJ2tz2C0XbJfSW2wn0p1bo9c+/aR3V3PWGD7zByFi1N3jyj1F3oR8luOaP9Fceev+jXHdbzxGV+/+k/Oz0b9YgH3bPv5cceTaSvfGu8U/Z+U5mT77l6byKdyjmcv6p9r+Ttk/y+Oe7pDU7ZqILk88Ns1IZhGJ84bKA2DMOIONkxffT2rwePLnEjrco7WKQ3r6TJybfsK2hny84RdDG6bs6FTtlzS19q93fJbjnpXt2CLmiVPxzulF11TvsTs3T/mjs3QSZNMMkmqAtO9teOPemFpFCrpOVcl04ckEinYgYJRt99MMWNghy9ZEoiPShpYfp1E3/n5E+8bRqZIjjhE8D4C7+VSFed6M4mmKz5eHwdHbm1dYbkx+and/lupRtmumbEYX9MMtGl6aIYNBUMmebOGHhf+R/a/d1V113v5PsvSK8NgjT0aX8MSI4o/uK5lzj5dNsj6GqZHG0YHD+GvXC1U/ajo9wTJuhW2llNdkdtGIYRcWygNgzDiDidMn2IyHqgGu9pu0VVx3b8i65h84xZ5BR2gxxhjzZ0eg7fTLNYnyOXPARByDnsdSzc8RB5ko8goLHQdFT+7mfk5HdDcnL4SOtD07F29s/IKfB0bA1RR1T6h+lIn0OxUZ+lqlWp7KRwjz/r2Wc+vcYp2xcU0991p7lk1F+d/G0xGHjVVHKLihn2oyWpSEmJHScnzXa1EE5hAgWS2jSIH8zyQ0zXTfxtu9uNm3Gjk++9/eP/czo6OiJoaw7aoAF2PeBOqt9cspC+M64lt7iI8hRC+7vt8/tHcnj2ytN896tbl7t2wWRya1o4rXQyBTnd0w5VTiYYclzKZzvYEjRX6XfTNeT2LKL8mxVp7fd/9p3s5IN22Fsvcv/HGVNce2fR0AJOOmUaBQVFKbnyBe2nBee6ZeVbfBfFcTOmOmW9F2S2nwanewjOLgnuDIKFQ9wJ+y9/zG37RSfnM+6Yb1KQ1yNte3XyDHivTrgykS5f5O73vAe+4+SH3uGvnpTcru1hpg/DMIyI09k7agVeEhEF7lPVuckbiMgUYApAIe1PXpIWImyZNxdEyNcjGSTDD7BJFnQAS3kdFAYy3HQI7LjzfhChh5aFpkMEKnb/GUEYoINDbA9hx6/mhd4eiLBs+XwEYaD2P+z7qQAV6/8Q7x8DQm2PQ6WzA/XnVXWziBwJvCwiq1T1teAG8cF7LkAv6dPxirkpMvDb08krKaGlpppNt8ymSHvSW/o522RDx1jOolC606QNvMfrh72Oshunkde7hNb9NWz6/p2h6RjX50IKc4tpbK2jYueC8Npjxrf99vjOXaHpOOWka+jWrYSmphqWvXnvYd9Pxw2/gsL8XjS21FKxan5oOlKhs9Ocbo7/3SEiC/AWvH2t41/59Kr0LdE3D3rGKbtiiu9zmT95Jx1x7C/89XRzGMB+dtObfh38omsoFM8+VSCF9NND1xEMg711rGt3DYYqv3PrHKfsrMvdxX/rHx2QWMOj3/w1abVHcLUXgAGv+CHDQT94gIdGuYvuTt47FWgir6SAfikcl6D999oFn3PKgvbJex662ykL+lgDDKpaSSt15EFKOoIkr6wStKOP+EG76zoDMHRx22pBJWxNU8fDf3JfeAXt0Mlh7l8uec/Jb76kzV+/G/3eTE/H6qRw/dXNbyTSpc+7752S/fvTPV+CodvJ7zCCUzA0j3Snpp3xmGuHnje1bdrg3vS7LrPjR/AdQHJbvXjOnU4+6Gfe2WkgDmqjFpEiEenZlga+APy9U7VnkKa6FlrUW5uqVVvYzXaKyP7E37HmxkjoaK1vprXZC0JobW4MTUddXYxYvTcvcqyhKbz
j0tAUieNSWxfzj0tLeO1RVxcj1uDpiDWG1z9aNRrnbW1E2iNVOnNHXQYsEJG27f+gqi90/JPMU7u7kQreBgVF6c9gSqV/tmXQUltDBQtD19G0p5bKZ707TNUYAzkqFB27q2JsmXm/p6M1xpCQdLTsrY3Ecdm+s5UVr3mRihqLMSCk9ti1M8bW39zjZWLhHZdGGljOkkgclyi0R6p0ZimutRBfriJFguHJl8y5wSm76QZ/JZE71riPee+OCYbq9uJU6aQvSydIXnnjrJW+WeHV0e6KDC2f9003OeRx6vz0dAQfkzpy82m5abdbFtQ1Gobd7rv9DEvTDSx/rxsWfe0tj7ezJUx+03XHOmPjMj/T/iRrKZFfVZdIJ89a1+eR4kCumOEZ7B87z3BXFk4OVw8yeom7wsvp+wOh7mm2x7A5H7r5IX54cvIj9bdWu7Mrji/3F2zO2Z7eTHvXjHXDs792s+86eiC30TZ6SDGnkt5xCZ6ryf/jq0v9cyLZLJI82+TZMX8KhnTdN5PNG8FFdyf0cNvqP66Y7uR7LDr01WXMPc8wDCPi2EBtGIYRcWygNgzDiDiimnlXQRHZCdQCKYWcJ1HaiXqOVtWP+dmYjkjr+KiTdZgO0/FJ0NEZLQfUAYCqdskHqIhCPaYjmjqsDqvjcKoj3XrM9GEYhhFxbKA2DMOIOF05UH9s4qaQ6jEdmf19JuuxOqyOw6WOtOrpkpeJhmEYRuYw04dhGEbEsYHaMAwj4nTJQC0iE0WkUkQ+FJEfplHPehFZISLLROSQJ7MwHabDdJiOf3YdQOb9qIFcYA0wHCgA/gaMSrGu9UCp6TAdpsN0HI462j5dcUc9DvhQVdeqahPwODDpIL/pCkyH6TAdpuOfXQfQNaaPgcDGQH5T/LtUaFur8a/xtcxMh+kwHabjcNIBdH7NxLA46FqNpsN0mA7T8UnX0RV31JuBwYH8oPh3h4wG1moE2tZqNB2mw3SYjsNFR6KSjH7w7tLXAsPwjfCjU6inCOgZSL8JTDQdpsN0mI7DRUfbJ+OmD1VtEZHpwIt4b04fUNWVB/nZgUhrrUbTYTpMh+n4Z9fRhoWQG4ZhRByLTDQMw4g4NlAbhmFEHBuoDcMwIo4N1IZhGBHHBmrDMIyIYwO1YRhGxLGB2jAMI+L8P6Rg4NXcREyMAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# plot sample images\n", "nplot = 10\n", "fig, axes = plt.subplots(nrows=1, ncols=nplot)\n", "\n", "for i in range(nplot):\n", " img = X_digits[i].reshape(8, 8)\n", " axes[i].imshow(img)\n", " axes[i].set_title(y_digits[i])\n" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# split train / test data\n", "n_samples = len(X_digits)\n", "n_train = int(0.4 * n_samples)\n", "\n", "X_train = X_digits[:n_train]\n", "y_train = y_digits[:n_train]\n", "X_test = X_digits[n_train:]\n", "y_test = y_digits[n_train:]\n" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "KNN score: 0.953661\n", "LogisticRegression score: 0.927711\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/bushuhui/anaconda3/envs/dl/lib/python3.7/site-packages/sklearn/linear_model/_logistic.py:765: ConvergenceWarning: lbfgs failed to converge (status=1):\n", "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", "\n", "Increase the number of iterations (max_iter) or scale the data as shown in:\n", " https://scikit-learn.org/stable/modules/preprocessing.html\n", "Please also refer to the documentation for alternative solver options:\n", " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", " extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)\n" ] } ], "source": [ "# do KNN classification\n", "knn = neighbors.KNeighborsClassifier()\n", "logistic = linear_model.LogisticRegression()\n", "\n", "print('KNN score: %f' % knn.fit(X_train, y_train).score(X_test, y_test))\n", "print('LogisticRegression score: %f' % logistic.fit(X_train, y_train).score(X_test, y_test))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 7. 
Further thinking\n", "\n", "* If there is a very large amount of input data, how can distances be computed quickly?\n", "    - [kd-tree](https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KDTree.html#sklearn.neighbors.KDTree) \n", "    - Fast Library for Approximate Nearest Neighbors (FLANN)\n", "    - [PyNNDescent for fast Approximate Nearest Neighbors](https://pynndescent.readthedocs.io/en/latest/)\n", "* How should the best `k` be chosen?\n", "    - https://zhuanlan.zhihu.com/p/143092725\n", "* What problems does kNN have?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## References\n", "* [Digits Classification Exercise](http://scikit-learn.org/stable/auto_examples/exercises/plot_digits_classification_exercise.html)\n", "* [knn算法的原理与实现](https://zhuanlan.zhihu.com/p/36549000)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.9" } }, "nbformat": 4, "nbformat_minor": 2 }