You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

consolidated-new-metrics.ipynb 31 kB

first commit Former-commit-id: 08bc23ba02cffbce3cf63962390a65459a132e48 [formerly 0795edd4834b9b7dc66db8d10d4cbaf42bbf82cb] [formerly b5010b42541add7e2ea2578bf2da537efc457757 [formerly a7ca09c2c34c4fc8b3d8e01fcfa08eeeb2cae99d]] [formerly 615058473a2177ca5b89e9edbb797f4c2a59c7e5 [formerly 743d8dfc6843c4c205051a8ab309fbb2116c895e] [formerly bb0ea98b1e14154ef464e2f7a16738705894e54b [formerly 960a69da74b81ef8093820e003f2d6c59a34974c]]] [formerly 2fa3be52c1b44665bc81a7cc7d4cea4bbf0d91d5 [formerly 2054589f0898627e0a17132fd9d4cc78efc91867] [formerly 3b53730e8a895e803dfdd6ca72bc05e17a4164c1 [formerly 8a2fa8ab7baf6686d21af1f322df46fd58c60e69]] [formerly 87d1e3a07a19d03c7d7c94d93ab4fa9f58dada7c [formerly f331916385a5afac1234854ee8d7f160f34b668f] [formerly 69fb3c78a483343f5071da4f7e2891b83a49dd18 [formerly 386086f05aa9487f65bce2ee54438acbdce57650]]]] Former-commit-id: a00aed8c934a6460c4d9ac902b9a74a3d6864697 [formerly 26fdeca29c2f07916d837883983ca2982056c78e] [formerly 0e3170d41a2f99ecf5c918183d361d4399d793bf [formerly 3c12ad4c88ac5192e0f5606ac0d88dd5bf8602dc]] [formerly d5894f84f2fd2e77a6913efdc5ae388cf1be0495 [formerly ad3e7bc670ff92c992730d29c9d3aa1598d844e8] [formerly 69fb3c78a483343f5071da4f7e2891b83a49dd18]] Former-commit-id: 3c19c9fae64f6106415fbc948a4dc613b9ee12f8 [formerly 467ddc0549c74bb007e8f01773bb6dc9103b417d] [formerly 5fa518345d958e2760e443b366883295de6d991c [formerly 3530e130b9fdb7280f638dbc2e785d2165ba82aa]] Former-commit-id: 9f5d473d42a435ec0d60149939d09be1acc25d92 [formerly be0b25c4ec2cde052a041baf0e11f774a158105d] Former-commit-id: 9eca71cb73ba9edccd70ac06a3b636b8d4093b04
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "# Hamming loss"
  8. ]
  9. },
  10. {
  11. "cell_type": "code",
  12. "execution_count": 1,
  13. "metadata": {},
  14. "outputs": [],
  15. "source": [
  16. "import pandas as pd\n",
  17. "from io import StringIO\n",
  18. "from sklearn.metrics import hamming_loss"
  19. ]
  20. },
  21. {
  22. "cell_type": "code",
  23. "execution_count": 2,
  24. "metadata": {},
  25. "outputs": [],
  26. "source": [
  27. "def hammingLoss(y_true, y_pred):\n",
  28. " \"\"\"\n",
  29. " Computes the hamming loss. Used for multiclass and multilabel classification.\n",
  30. " \"\"\"\n",
  31. " from sklearn import preprocessing\n",
  32. " from sklearn.metrics import hamming_loss\n",
  33. " import numpy as np\n",
  34. " from itertools import chain\n",
  35. "\n",
  36. " def to_2d_array(df):\n",
  37. " MULTI_CLASS_CONDITION=True\n",
  38. " _dict = {}\n",
  39. " for [index,value] in df.as_matrix():\n",
  40. " if index in _dict.keys():\n",
  41. " _dict[index].append(value)\n",
  42. " MULTI_CLASS_CONDITION = False\n",
  43. " else:\n",
  44. " _dict[index]=[value]\n",
  45. " return list(_dict.keys()), list(_dict.values()), MULTI_CLASS_CONDITION\n",
  46. " \n",
  47. " (y_true_keys, y_true_mat, b_multiclass) = to_2d_array(y_true)\n",
  48. " (y_pred_keys, y_pred_mat, b_multiclass) = to_2d_array(y_pred)\n",
  49. " \n",
  50. " assert y_true_keys==y_pred_keys\n",
  51. " \n",
  52. " if b_multiclass:\n",
  53. "# print('this is a multiclass case')\n",
  54. " y_true_label_encoded = np.array(y_true_mat).ravel()\n",
  55. " y_pred_label_encoded = np.array(y_pred_mat).ravel()\n",
  56. " else: # MULTI_LABEL_CONDITION\n",
  57. "# print('this is a multilabel case')\n",
  58. " y_true_classes=(set(list(chain.from_iterable(y_true_mat))))\n",
  59. " y_pred_classes=(set(list(chain.from_iterable(y_pred_mat))))\n",
  60. " all_classes = list(y_true_classes.union(y_pred_classes))\n",
  61. " lb = preprocessing.MultiLabelBinarizer(classes=all_classes)\n",
  62. " y_true_label_encoded = lb.fit_transform(y_true_mat)\n",
  63. " y_pred_label_encoded = lb.transform(y_pred_mat)\n",
  64. " return hamming_loss(y_true_label_encoded, y_pred_label_encoded)"
  65. ]
  66. },
  67. {
  68. "cell_type": "code",
  69. "execution_count": 3,
  70. "metadata": {},
  71. "outputs": [
  72. {
  73. "name": "stderr",
  74. "output_type": "stream",
  75. "text": [
  76. "/home/svattam/miniconda3/envs/automl/lib/python3.6/site-packages/ipykernel_launcher.py:13: FutureWarning: Method .as_matrix will be removed in a future version. Use .values instead.\n",
  77. " del sys.path[0]\n"
  78. ]
  79. },
  80. {
  81. "data": {
  82. "text/plain": [
  83. "0.26666666666666666"
  84. ]
  85. },
  86. "execution_count": 3,
  87. "metadata": {},
  88. "output_type": "execute_result"
  89. }
  90. ],
  91. "source": [
  92. "# Testcase 1: MultiLabel, typical\n",
  93. "y_true = pd.read_csv(StringIO(\"\"\"\n",
  94. "d3mIndex,class_label\n",
  95. "3,happy-pleased\n",
  96. "3,relaxing-calm\n",
  97. "7,amazed-suprised\n",
  98. "7,happy-pleased\n",
  99. "13,quiet-still\n",
  100. "13,sad-lonely\n",
  101. "\"\"\"))\n",
  102. "\n",
  103. "y_pred = pd.read_csv(StringIO(\"\"\"\n",
  104. "d3mIndex,class_label\n",
  105. "3,happy-pleased\n",
  106. "3,sad-lonely\n",
  107. "7,amazed-suprised\n",
  108. "7,happy-pleased\n",
  109. "13,quiet-still\n",
  110. "13,happy-pleased\n",
  111. "\"\"\"))\n",
  112. "\n",
  113. "hammingLoss(y_true, y_pred)"
  114. ]
  115. },
  116. {
  117. "cell_type": "code",
  118. "execution_count": 4,
  119. "metadata": {},
  120. "outputs": [
  121. {
  122. "name": "stderr",
  123. "output_type": "stream",
  124. "text": [
  125. "/home/svattam/miniconda3/envs/automl/lib/python3.6/site-packages/ipykernel_launcher.py:13: FutureWarning: Method .as_matrix will be removed in a future version. Use .values instead.\n",
  126. " del sys.path[0]\n"
  127. ]
  128. },
  129. {
  130. "data": {
  131. "text/plain": [
  132. "0.0"
  133. ]
  134. },
  135. "execution_count": 4,
  136. "metadata": {},
  137. "output_type": "execute_result"
  138. }
  139. ],
  140. "source": [
  141. "# Testcase 2: MultiLabel, Zero loss\n",
  142. "y_true = pd.read_csv(StringIO(\"\"\"\n",
  143. "d3mIndex,class_label\n",
  144. "3,happy-pleased\n",
  145. "3,relaxing-calm\n",
  146. "7,amazed-suprised\n",
  147. "7,happy-pleased\n",
  148. "13,quiet-still\n",
  149. "13,sad-lonely\n",
  150. "\"\"\"))\n",
  151. "\n",
  152. "y_pred = pd.read_csv(StringIO(\"\"\"\n",
  153. "d3mIndex,class_label\n",
  154. "3,happy-pleased\n",
  155. "3,relaxing-calm\n",
  156. "7,amazed-suprised\n",
  157. "7,happy-pleased\n",
  158. "13,quiet-still\n",
  159. "13,sad-lonely\n",
  160. "\"\"\"))\n",
  161. "\n",
  162. "hammingLoss(y_true, y_pred)"
  163. ]
  164. },
  165. {
  166. "cell_type": "code",
  167. "execution_count": 5,
  168. "metadata": {},
  169. "outputs": [
  170. {
  171. "name": "stderr",
  172. "output_type": "stream",
  173. "text": [
  174. "/home/svattam/miniconda3/envs/automl/lib/python3.6/site-packages/ipykernel_launcher.py:13: FutureWarning: Method .as_matrix will be removed in a future version. Use .values instead.\n",
  175. " del sys.path[0]\n"
  176. ]
  177. },
  178. {
  179. "data": {
  180. "text/plain": [
  181. "1.0"
  182. ]
  183. },
  184. "execution_count": 5,
  185. "metadata": {},
  186. "output_type": "execute_result"
  187. }
  188. ],
  189. "source": [
  190. "# Testcase 3: MultiLabel, Complete loss\n",
  191. "y_true = pd.read_csv(StringIO(\"\"\"\n",
  192. "d3mIndex,class_label\n",
  193. "3,happy-pleased\n",
  194. "3,relaxing-calm\n",
  195. "7,amazed-suprised\n",
  196. "7,happy-pleased\n",
  197. "13,quiet-still\n",
  198. "13,sad-lonely\n",
  199. "\"\"\"))\n",
  200. "\n",
  201. "y_pred = pd.read_csv(StringIO(\"\"\"\n",
  202. "d3mIndex,class_label\n",
  203. "3,ecstatic\n",
  204. "3,sad-lonely\n",
  205. "3,quiet-still\n",
  206. "3,amazed-suprised\n",
  207. "7,ecstatic\n",
  208. "7,sad-lonely\n",
  209. "7,relaxing-calm\n",
  210. "7,quiet-still\n",
  211. "13,ecstatic\n",
  212. "13,happy-pleased\n",
  213. "13,relaxing-calm\n",
  214. "13,amazed-suprised\n",
  215. "\"\"\"))\n",
  216. "\n",
  217. "hammingLoss(y_true, y_pred)"
  218. ]
  219. },
  220. {
  221. "cell_type": "code",
  222. "execution_count": 6,
  223. "metadata": {},
  224. "outputs": [
  225. {
  226. "name": "stderr",
  227. "output_type": "stream",
  228. "text": [
  229. "/home/svattam/miniconda3/envs/automl/lib/python3.6/site-packages/ipykernel_launcher.py:13: FutureWarning: Method .as_matrix will be removed in a future version. Use .values instead.\n",
  230. " del sys.path[0]\n"
  231. ]
  232. },
  233. {
  234. "data": {
  235. "text/plain": [
  236. "0.2"
  237. ]
  238. },
  239. "execution_count": 6,
  240. "metadata": {},
  241. "output_type": "execute_result"
  242. }
  243. ],
  244. "source": [
  245. "# Testcase 4: Multiclass, Typical\n",
  246. "y_true = pd.read_csv(StringIO(\"\"\"\n",
  247. "d3mIndex,species\n",
  248. "1,versicolor\n",
  249. "2,versicolor\n",
  250. "16,virginica\n",
  251. "17,setosa\n",
  252. "22,versicolor\n",
  253. "26,versicolor\n",
  254. "30,versicolor\n",
  255. "31,virginica\n",
  256. "33,versicolor\n",
  257. "37,virginica\n",
  258. "\"\"\"))\n",
  259. "\n",
  260. "y_pred = pd.read_csv(StringIO(\"\"\"\n",
  261. "d3mIndex,species\n",
  262. "1,setosa\n",
  263. "2,versicolor\n",
  264. "16,virginica\n",
  265. "17,setosa\n",
  266. "22,versicolor\n",
  267. "26,virginica\n",
  268. "30,versicolor\n",
  269. "31,virginica\n",
  270. "33,versicolor\n",
  271. "37,virginica\n",
  272. "\"\"\"))\n",
  273. "\n",
  274. "hammingLoss(y_true, y_pred)"
  275. ]
  276. },
  277. {
  278. "cell_type": "code",
  279. "execution_count": 7,
  280. "metadata": {},
  281. "outputs": [
  282. {
  283. "name": "stderr",
  284. "output_type": "stream",
  285. "text": [
  286. "/home/svattam/miniconda3/envs/automl/lib/python3.6/site-packages/ipykernel_launcher.py:13: FutureWarning: Method .as_matrix will be removed in a future version. Use .values instead.\n",
  287. " del sys.path[0]\n"
  288. ]
  289. },
  290. {
  291. "data": {
  292. "text/plain": [
  293. "0.0"
  294. ]
  295. },
  296. "execution_count": 7,
  297. "metadata": {},
  298. "output_type": "execute_result"
  299. }
  300. ],
  301. "source": [
  302. "# Testcase 5: Multiclass, Zero loss\n",
  303. "y_true = pd.read_csv(StringIO(\"\"\"\n",
  304. "d3mIndex,species\n",
  305. "1,versicolor\n",
  306. "2,versicolor\n",
  307. "16,virginica\n",
  308. "17,setosa\n",
  309. "22,versicolor\n",
  310. "26,versicolor\n",
  311. "30,versicolor\n",
  312. "31,virginica\n",
  313. "33,versicolor\n",
  314. "37,virginica\n",
  315. "\"\"\"))\n",
  316. "\n",
  317. "y_pred = pd.read_csv(StringIO(\"\"\"\n",
  318. "d3mIndex,species\n",
  319. "1,versicolor\n",
  320. "2,versicolor\n",
  321. "16,virginica\n",
  322. "17,setosa\n",
  323. "22,versicolor\n",
  324. "26,versicolor\n",
  325. "30,versicolor\n",
  326. "31,virginica\n",
  327. "33,versicolor\n",
  328. "37,virginica\n",
  329. "\"\"\"))\n",
  330. "\n",
  331. "hammingLoss(y_true, y_pred)"
  332. ]
  333. },
  334. {
  335. "cell_type": "code",
  336. "execution_count": 8,
  337. "metadata": {},
  338. "outputs": [
  339. {
  340. "name": "stderr",
  341. "output_type": "stream",
  342. "text": [
  343. "/home/svattam/miniconda3/envs/automl/lib/python3.6/site-packages/ipykernel_launcher.py:13: FutureWarning: Method .as_matrix will be removed in a future version. Use .values instead.\n",
  344. " del sys.path[0]\n"
  345. ]
  346. },
  347. {
  348. "data": {
  349. "text/plain": [
  350. "1.0"
  351. ]
  352. },
  353. "execution_count": 8,
  354. "metadata": {},
  355. "output_type": "execute_result"
  356. }
  357. ],
  358. "source": [
  359. "# Testcase 6: Multiclass, Complete loss\n",
  360. "y_true = pd.read_csv(StringIO(\"\"\"\n",
  361. "d3mIndex,species\n",
  362. "1,versicolor\n",
  363. "2,versicolor\n",
  364. "16,versicolor\n",
  365. "17,virginica\n",
  366. "22,versicolor\n",
  367. "26,versicolor\n",
  368. "30,versicolor\n",
  369. "31,virginica\n",
  370. "33,versicolor\n",
  371. "37,virginica\n",
  372. "\"\"\"))\n",
  373. "\n",
  374. "y_pred = pd.read_csv(StringIO(\"\"\"\n",
  375. "d3mIndex,species\n",
  376. "1,setosa\n",
  377. "2,setosa\n",
  378. "16,setosa\n",
  379. "17,setosa\n",
  380. "22,setosa\n",
  381. "26,setosa\n",
  382. "30,setosa\n",
  383. "31,setosa\n",
  384. "33,setosa\n",
  385. "37,setosa\n",
  386. "\"\"\"))\n",
  387. "\n",
  388. "hammingLoss(y_true, y_pred)"
  389. ]
  390. },
  391. {
  392. "cell_type": "code",
  393. "execution_count": null,
  394. "metadata": {},
  395. "outputs": [],
  396. "source": []
  397. },
  398. {
  399. "cell_type": "code",
  400. "execution_count": null,
  401. "metadata": {},
  402. "outputs": [],
  403. "source": []
  404. },
  405. {
  406. "cell_type": "code",
  407. "execution_count": null,
  408. "metadata": {},
  409. "outputs": [],
  410. "source": []
  411. },
  412. {
  413. "cell_type": "markdown",
  414. "metadata": {},
  415. "source": [
  416. "# RMSE"
  417. ]
  418. },
  419. {
  420. "cell_type": "code",
  421. "execution_count": 9,
  422. "metadata": {},
  423. "outputs": [],
  424. "source": [
  425. "def rootMeanSquaredError(y_true, y_pred):\n",
  426. " \"\"\"\n",
  427. " Computes the root mean squared error, for both univariate and multivariate cases\n",
  428. " \"\"\"\n",
  429. " import numpy as np\n",
  430. " from sklearn.metrics import mean_squared_error\n",
  431. " from math import sqrt\n",
  432. " \n",
  433. " rmse = None\n",
  434. " \n",
  435. " # perform some checks\n",
  436. " assert 'd3mIndex' in y_true.columns\n",
  437. " assert 'd3mIndex' in y_pred.columns\n",
  438. " assert y_true.shape == y_pred.shape\n",
  439. " \n",
  440. " # preprocessing\n",
  441. " y_true.set_index('d3mIndex', inplace=True)\n",
  442. " y_pred.set_index('d3mIndex', inplace=True)\n",
  443. " \n",
  444. " # determine the dimension\n",
  445. " y_true_dim=y_true.shape[1]\n",
  446. " \n",
  447. " # univariate case\n",
  448. " if y_true_dim == 1: \n",
  449. " y_true_array = y_true.as_matrix().ravel()\n",
  450. " y_pred_array = y_pred.as_matrix().ravel()\n",
  451. " mse = mean_squared_error(y_true, y_pred)\n",
  452. " rmse = sqrt(mse)\n",
  453. " \n",
  454. " # multivariate case\n",
  455. " elif y_true_dim > 1:\n",
  456. " y_true_array = y_true.as_matrix()\n",
  457. " y_pred_array = y_pred.as_matrix()\n",
  458. " mse = mean_squared_error(y_true_array, y_pred_array, multioutput='uniform_average')\n",
  459. " rmse = sqrt(mse)\n",
  460. " \n",
  461. " return rmse"
  462. ]
  463. },
  464. {
  465. "cell_type": "code",
  466. "execution_count": 10,
  467. "metadata": {},
  468. "outputs": [
  469. {
  470. "name": "stderr",
  471. "output_type": "stream",
  472. "text": [
  473. "/home/svattam/miniconda3/envs/automl/lib/python3.6/site-packages/ipykernel_launcher.py:25: FutureWarning: Method .as_matrix will be removed in a future version. Use .values instead.\n",
  474. "/home/svattam/miniconda3/envs/automl/lib/python3.6/site-packages/ipykernel_launcher.py:26: FutureWarning: Method .as_matrix will be removed in a future version. Use .values instead.\n"
  475. ]
  476. },
  477. {
  478. "data": {
  479. "text/plain": [
  480. "0.8381527307120105"
  481. ]
  482. },
  483. "execution_count": 10,
  484. "metadata": {},
  485. "output_type": "execute_result"
  486. }
  487. ],
  488. "source": [
  489. "# test case 1\n",
  490. "# y_true_uni=[3, -1., 2, 7]\n",
  491. "# y_pred_uni=[2.1, 0.0, 2, 8]\n",
  492. "# expected rmse = 0.8381527307120105\n",
  493. "\n",
  494. "y_true = pd.read_csv(StringIO(\"\"\"\n",
  495. "d3mIndex,value\n",
  496. "1,3\n",
  497. "2,-1.0\n",
  498. "16,2\n",
  499. "17,7\n",
  500. "\"\"\"))\n",
  501. "y_pred = pd.read_csv(StringIO(\"\"\"\n",
  502. "d3mIndex,value\n",
  503. "1,2.1\n",
  504. "2,0.0\n",
  505. "16,2\n",
  506. "17,8\n",
  507. "\"\"\"))\n",
  508. "rootMeanSquaredError(y_true, y_pred)"
  509. ]
  510. },
  511. {
  512. "cell_type": "code",
  513. "execution_count": 11,
  514. "metadata": {},
  515. "outputs": [
  516. {
  517. "name": "stderr",
  518. "output_type": "stream",
  519. "text": [
  520. "/home/svattam/miniconda3/envs/automl/lib/python3.6/site-packages/ipykernel_launcher.py:32: FutureWarning: Method .as_matrix will be removed in a future version. Use .values instead.\n",
  521. "/home/svattam/miniconda3/envs/automl/lib/python3.6/site-packages/ipykernel_launcher.py:33: FutureWarning: Method .as_matrix will be removed in a future version. Use .values instead.\n"
  522. ]
  523. },
  524. {
  525. "data": {
  526. "text/plain": [
  527. "0.8416254115301732"
  528. ]
  529. },
  530. "execution_count": 11,
  531. "metadata": {},
  532. "output_type": "execute_result"
  533. }
  534. ],
  535. "source": [
  536. "# test case 2\n",
  537. "# y_true_multi=[[0.5, 1],[-1, 1],[7, -6]]\n",
  538. "# y_pred_multi=[[0, 2],[-1, 2],[8, -5]]\n",
  539. "# expected rmse = 0.8416254115301732\n",
  540. "\n",
  541. "y_true = pd.read_csv(StringIO(\"\"\"\n",
  542. "d3mIndex,value1, value2\n",
  543. "1,0.5,1\n",
  544. "2,-1,1\n",
  545. "16,7,-6\n",
  546. "\"\"\"))\n",
  547. "y_pred = pd.read_csv(StringIO(\"\"\"\n",
  548. "d3mIndex,value1,value2\n",
  549. "1,0,2\n",
  550. "2,-1,2\n",
  551. "16,8,-5\n",
  552. "\"\"\"))\n",
  553. "rootMeanSquaredError(y_true, y_pred)"
  554. ]
  555. },
  556. {
  557. "cell_type": "code",
  558. "execution_count": null,
  559. "metadata": {},
  560. "outputs": [],
  561. "source": []
  562. },
  563. {
  564. "cell_type": "markdown",
  565. "metadata": {},
  566. "source": [
  567. "# Object detection average precision"
  568. ]
  569. },
  570. {
  571. "cell_type": "code",
  572. "execution_count": 12,
  573. "metadata": {},
  574. "outputs": [
  575. {
  576. "name": "stdout",
  577. "output_type": "stream",
  578. "text": [
  579. "TEST CASE 1 --- AP: 0.6666666666666666\n",
  580. "TEST CASE 2 --- AP: 0.125\n",
  581. "TEST CASE 3 --- AP: 0.4444444444444444\n",
  582. "TEST CASE 4 --- AP: 0.4444444444444444\n"
  583. ]
  584. }
  585. ],
  586. "source": [
  587. "def group_gt_boxes_by_image_name(gt_boxes):\n",
  588. " gt_dict: typing.Dict = {}\n",
  589. "\n",
  590. " for box in gt_boxes:\n",
  591. " image_name = box[0]\n",
  592. " bounding_polygon = box[1:]\n",
  593. " bbox = convert_bouding_polygon_to_box_coords(bounding_polygon)\n",
  594. "\n",
  595. " if image_name not in gt_dict.keys():\n",
  596. " gt_dict[image_name] = []\n",
  597. "\n",
  598. " gt_dict[image_name].append({'bbox': bbox})\n",
  599. "\n",
  600. " return gt_dict\n",
  601. "\n",
  602. "\n",
  603. "def convert_bouding_polygon_to_box_coords(bounding_polygon):\n",
  604. " # box_coords = [x_min, y_min, x_max, y_max]\n",
  605. " box_coords = [bounding_polygon[0], bounding_polygon[1],\n",
  606. " bounding_polygon[4], bounding_polygon[5]]\n",
  607. " return box_coords\n",
  608. "\n",
  609. "\n",
  610. "def voc_ap(rec, prec):\n",
  611. " import numpy\n",
  612. "\n",
  613. " # First append sentinel values at the end.\n",
  614. " mrec = numpy.concatenate(([0.], rec, [1.]))\n",
  615. " mpre = numpy.concatenate(([0.], prec, [0.]))\n",
  616. "\n",
  617. " # Compute the precision envelope.\n",
  618. " for i in range(mpre.size - 1, 0, -1):\n",
  619. " mpre[i - 1] = numpy.maximum(mpre[i - 1], mpre[i])\n",
  620. "\n",
  621. " # To calculate area under PR curve, look for points\n",
  622. " # where X axis (recall) changes value.\n",
  623. " i = numpy.where(mrec[1:] != mrec[:-1])[0]\n",
  624. "\n",
  625. " # And sum (\\Delta recall) * prec.\n",
  626. " ap = numpy.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])\n",
  627. "\n",
  628. " return float(ap)\n",
  629. "\n",
  630. "\n",
  631. "def object_detection_average_precision(y_true, y_pred):\n",
  632. " \"\"\"\n",
  633. " This function takes a list of ground truth bounding polygons (rectangles in this case)\n",
  634. " and a list of detected bounding polygons (also rectangles) for a given class and\n",
  635. " computes the average precision of the detections with respect to the ground truth polygons.\n",
  636. " Parameters:\n",
  637. " -----------\n",
  638. " y_true: list\n",
  639. " List of ground truth polygons. Each polygon is represented as a list of\n",
  640. " vertices, starting in the upper-left corner going counter-clockwise.\n",
  641. " Since in this case, the polygons are rectangles, they will have the\n",
  642. " following format:\n",
  643. " [image_name, x_min, y_min, x_min, y_max, x_max, y_max, x_max, y_min].\n",
  644. " y_pred: list\n",
  645. " List of bounding box polygons with their corresponding confidence scores. Each\n",
  646. " polygon is represented as a list of vertices, starting in the upper-left corner\n",
  647. " going counter-clockwise. Since in this case, the polygons are rectangles, they\n",
  648. " will have the following format:\n",
  649. " [image_name, x_min, y_min, x_min, y_max, x_max, y_max, x_max, y_min, confidence_score].\n",
  650. " Returns:\n",
  651. " --------\n",
  652. " ap: float\n",
  653. " Average precision between detected polygons (rectangles) and the ground truth polygons (rectangles).\n",
  654. " (it is also the area under the precision-recall curve).\n",
  655. " Example 1:\n",
  656. " >> predictions_list_1 = [['img_00001.png', 110, 110, 110, 210, 210, 210, 210, 110, 0.6],\n",
  657. " ['img_00002.png', 5, 10, 5, 20, 20, 20, 20, 10, 0.9],\n",
  658. " ['img_00002.png', 120, 130, 120, 200, 200, 200, 200, 130, 0.6]]\n",
  659. " >> ground_truth_list_1 = [['img_00001.png', 100, 100, 100, 200, 200, 200, 200, 100],\n",
  660. " ['img_00002.png', 10, 10, 10, 20, 20, 20, 20, 10],\n",
  661. " ['img_00002.png', 70, 80, 70, 150, 140, 150, 140, 80]]\n",
  662. " >> ap_1 = object_detection_average_precision(ground_truth_list_1, predictions_list_1)\n",
  663. " >> print(ap_1)\n",
  664. " 0.667\n",
  665. " Example 2:\n",
  666. " >> predictions_list_2 = [['img_00285.png', 330, 463, 330, 505, 387, 505, 387, 463, 0.0739],\n",
  667. " ['img_00285.png', 420, 433, 420, 498, 451, 498, 451, 433, 0.0910],\n",
  668. " ['img_00285.png', 328, 465, 328, 540, 403, 540, 403, 465, 0.1008],\n",
  669. " ['img_00285.png', 480, 477, 480, 522, 508, 522, 508, 477, 0.1012],\n",
  670. " ['img_00285.png', 357, 460, 357, 537, 417, 537, 417, 460, 0.1058],\n",
  671. " ['img_00285.png', 356, 456, 356, 521, 391, 521, 391, 456, 0.0843],\n",
  672. " ['img_00225.png', 345, 460, 345, 547, 415, 547, 415, 460, 0.0539],\n",
  673. " ['img_00225.png', 381, 362, 381, 513, 455, 513, 455, 362, 0.0542],\n",
  674. " ['img_00225.png', 382, 366, 382, 422, 416, 422, 416, 366, 0.0559],\n",
  675. " ['img_00225.png', 730, 463, 730, 583, 763, 583, 763, 463, 0.0588]]\n",
  676. " >> ground_truth_list_2 = [['img_00285.png', 480, 457, 480, 529, 515, 529, 515, 457],\n",
  677. " ['img_00285.png', 480, 457, 480, 529, 515, 529, 515, 457],\n",
  678. " ['img_00225.png', 522, 540, 522, 660, 576, 660, 576, 540],\n",
  679. " ['img_00225.png', 739, 460, 739, 545, 768, 545, 768, 460]]\n",
  680. " >> ap_2 = object_detection_average_precision(ground_truth_list_2, predictions_list_2)\n",
  681. " >> print(ap_2)\n",
  682. " 0.125\n",
  683. " Example 3:\n",
  684. " >> predictions_list_3 = [['img_00001.png', 110, 110, 110, 210, 210, 210, 210, 110, 0.6],\n",
  685. " ['img_00002.png', 120, 130, 120, 200, 200, 200, 200, 130, 0.6],\n",
  686. " ['img_00002.png', 5, 8, 5, 16, 15, 16, 15, 8, 0.9],\n",
  687. " ['img_00002.png', 11, 12, 11, 18, 21, 18, 21, 12, 0.9]]\n",
  688. " >> ground_truth_list_3 = [['img_00001.png', 100, 100, 100, 200, 200, 200, 200, 100],\n",
  689. " ['img_00002.png', 10, 10, 10, 20, 20, 20, 20, 10],\n",
  690. " ['img_00002.png', 70, 80, 70, 150, 140, 150, 140, 80]]\n",
  691. " >> ap_3 = object_detection_average_precision(ground_truth_list_3, predictions_list_3)\n",
  692. " >> print(ap_3)\n",
  693. " 0.444\n",
  694. " Example 4:\n",
  695. " (Same as example 3 except the last two box predictions in img_00002.png are switched)\n",
  696. " >> predictions_list_4 = [['img_00001.png', 110, 110, 110, 210, 210, 210, 210, 110, 0.6],\n",
  697. " ['img_00002.png', 120, 130, 120, 200, 200, 200, 200, 130, 0.6],\n",
  698. " ['img_00002.png', 11, 12, 11, 18, 21, 18, 21, 12, 0.9],\n",
  699. " ['img_00002.png', 5, 8, 5, 16, 15, 16, 15, 8, 0.9]]\n",
  700. " >> ground_truth_list_4 = [['img_00001.png', 100, 100, 100, 200, 200, 200, 200, 100],\n",
  701. " ['img_00002.png', 10, 10, 10, 20, 20, 20, 20, 10],\n",
  702. " ['img_00002.png', 70, 80, 70, 150, 140, 150, 140, 80]]\n",
  703. " >> ap_4 = object_detection_average_precision(ground_truth_list_4, predictions_list_4)\n",
  704. " >> print(ap_4)\n",
  705. " 0.444\n",
  706. " \"\"\"\n",
  707. "\n",
  708. " \"\"\"\n",
  709. " This function is different from others because ``y_true`` and ``y_pred`` are not vectors but arrays.\n",
  710. " \"\"\"\n",
  711. " import numpy\n",
  712. " ovthresh = 0.5\n",
  713. "\n",
  714. " # y_true = typing.cast(Truth, unvectorize(y_true))\n",
  715. " # y_pred = typing.cast(Predictions, unvectorize(y_pred))\n",
  716. "\n",
  717. " # Load ground truth.\n",
  718. " gt_dict = group_gt_boxes_by_image_name(y_true)\n",
  719. "\n",
  720. " # Extract gt objects for this class.\n",
  721. " recs = {}\n",
  722. " npos = 0\n",
  723. "\n",
  724. " imagenames = sorted(gt_dict.keys())\n",
  725. " for imagename in imagenames:\n",
  726. " Rlist = [obj for obj in gt_dict[imagename]]\n",
  727. " bbox = numpy.array([x['bbox'] for x in Rlist])\n",
  728. " det = [False] * len(Rlist)\n",
  729. " npos = npos + len(Rlist)\n",
  730. " recs[imagename] = {'bbox': bbox, 'det': det}\n",
  731. "\n",
  732. " # Load detections.\n",
  733. " det_length = len(y_pred[0])\n",
  734. "\n",
  735. " # Check that all boxes are the same size.\n",
  736. " for det in y_pred:\n",
  737. " assert len(det) == det_length, 'Not all boxes have the same dimensions.'\n",
  738. "\n",
  739. " image_ids = [x[0] for x in y_pred]\n",
  740. " BP = numpy.array([[float(z) for z in x[1:-1]] for x in y_pred])\n",
  741. " BB = numpy.array([convert_bouding_polygon_to_box_coords(x) for x in BP])\n",
  742. "\n",
  743. " confidence = numpy.array([float(x[-1]) for x in y_pred])\n",
  744. " boxes_w_confidences_list = numpy.hstack((BB, -1 * confidence[:, None]))\n",
  745. " boxes_w_confidences = numpy.empty((boxes_w_confidences_list.shape[0],),\n",
  746. " dtype=[('x_min', float), ('y_min', float),\n",
  747. " ('x_max', float), ('y_max', float),\n",
  748. " ('confidence', float)])\n",
  749. " boxes_w_confidences[:] = [tuple(i) for i in boxes_w_confidences_list]\n",
  750. "\n",
  751. " # Sort by confidence.\n",
  752. " #sorted_ind = numpy.argsort(-confidence)\n",
  753. " sorted_ind = numpy.argsort(\n",
  754. " boxes_w_confidences, kind='mergesort',\n",
  755. " order=('confidence', 'x_min', 'y_min'))\n",
  756. " BB = BB[sorted_ind, :]\n",
  757. " image_ids = [image_ids[x] for x in sorted_ind]\n",
  758. "\n",
  759. " # Go down y_pred and mark TPs and FPs.\n",
  760. " nd = len(image_ids)\n",
  761. " tp = numpy.zeros(nd)\n",
  762. " fp = numpy.zeros(nd)\n",
  763. " for d in range(nd):\n",
  764. " R = recs[image_ids[d]]\n",
  765. " bb = BB[d, :].astype(float)\n",
  766. " ovmax = -numpy.inf\n",
  767. " BBGT = R['bbox'].astype(float)\n",
  768. "\n",
  769. " if BBGT.size > 0:\n",
  770. " # Compute overlaps.\n",
  771. " # Intersection.\n",
  772. " ixmin = numpy.maximum(BBGT[:, 0], bb[0])\n",
  773. " iymin = numpy.maximum(BBGT[:, 1], bb[1])\n",
  774. " ixmax = numpy.minimum(BBGT[:, 2], bb[2])\n",
  775. " iymax = numpy.minimum(BBGT[:, 3], bb[3])\n",
  776. " iw = numpy.maximum(ixmax - ixmin + 1., 0.)\n",
  777. " ih = numpy.maximum(iymax - iymin + 1., 0.)\n",
  778. " inters = iw * ih\n",
  779. "\n",
  780. " # Union.\n",
  781. " uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) +\n",
  782. " (BBGT[:, 2] - BBGT[:, 0] + 1.) *\n",
  783. " (BBGT[:, 3] - BBGT[:, 1] + 1.) - inters)\n",
  784. "\n",
  785. " overlaps = inters / uni\n",
  786. " ovmax = numpy.max(overlaps)\n",
  787. " jmax = numpy.argmax(overlaps)\n",
  788. "\n",
  789. " if ovmax > ovthresh:\n",
  790. " if not R['det'][jmax]:\n",
  791. " tp[d] = 1.\n",
  792. " R['det'][jmax] = 1\n",
  793. " else:\n",
  794. " fp[d] = 1.\n",
  795. " else:\n",
  796. " fp[d] = 1.\n",
  797. "\n",
  798. " # Compute precision recall.\n",
  799. " fp = numpy.cumsum(fp)\n",
  800. " tp = numpy.cumsum(tp)\n",
  801. " rec = tp / float(npos)\n",
  802. " # Avoid divide by zero in case the first detection matches a difficult ground truth.\n",
  803. " prec = tp / numpy.maximum(tp + fp, numpy.finfo(numpy.float64).eps)\n",
  804. " ap = voc_ap(rec, prec)\n",
  805. "\n",
  806. " return ap\n",
  807. "\n",
  808. "\n",
  809. "if __name__ == \"__main__\":\n",
  810. " predictions_list_1 = [\n",
  811. " ['img_00001.png', 110, 110, 110, 210, 210, 210, 210, 110, 0.6],\n",
  812. " ['img_00002.png', 5, 10, 5, 20, 20, 20, 20, 10, 0.9],\n",
  813. " ['img_00002.png', 120, 130, 120, 200, 200, 200, 200, 130, 0.6]\n",
  814. " ]\n",
  815. " ground_truth_list_1 = [\n",
  816. " ['img_00001.png', 100, 100, 100, 200, 200, 200, 200, 100],\n",
  817. " ['img_00002.png', 10, 10, 10, 20, 20, 20, 20, 10],\n",
  818. " ['img_00002.png', 70, 80, 70, 150, 140, 150, 140, 80]\n",
  819. " ]\n",
  820. " ap_1 = object_detection_average_precision(\n",
  821. " ground_truth_list_1, predictions_list_1)\n",
  822. " print('TEST CASE 1 --- AP: ', ap_1)\n",
  823. "\n",
  824. " predictions_list_2 = [\n",
  825. " ['img_00285.png', 330, 463, 330, 505, 387, 505, 387, 463, 0.0739],\n",
  826. " ['img_00285.png', 420, 433, 420, 498, 451, 498, 451, 433, 0.0910],\n",
  827. " ['img_00285.png', 328, 465, 328, 540, 403, 540, 403, 465, 0.1008],\n",
  828. " ['img_00285.png', 480, 477, 480, 522, 508, 522, 508, 477, 0.1012],\n",
  829. " ['img_00285.png', 357, 460, 357, 537, 417, 537, 417, 460, 0.1058],\n",
  830. " ['img_00285.png', 356, 456, 356, 521, 391, 521, 391, 456, 0.0843],\n",
  831. " ['img_00225.png', 345, 460, 345, 547, 415, 547, 415, 460, 0.0539],\n",
  832. " ['img_00225.png', 381, 362, 381, 513, 455, 513, 455, 362, 0.0542],\n",
  833. " ['img_00225.png', 382, 366, 382, 422, 416, 422, 416, 366, 0.0559],\n",
  834. " ['img_00225.png', 730, 463, 730, 583, 763, 583, 763, 463, 0.0588],\n",
  835. " ]\n",
  836. " ground_truth_list_2 = [\n",
  837. " ['img_00285.png', 480, 457, 480, 529, 515, 529, 515, 457],\n",
  838. " ['img_00285.png', 480, 457, 480, 529, 515, 529, 515, 457],\n",
  839. " ['img_00225.png', 522, 540, 522, 660, 576, 660, 576, 540],\n",
  840. " ['img_00225.png', 739, 460, 739, 545, 768, 545, 768, 460],\n",
  841. " ]\n",
  842. " ap_2 = object_detection_average_precision(\n",
  843. " ground_truth_list_2, predictions_list_2)\n",
  844. " print('TEST CASE 2 --- AP: ', ap_2)\n",
  845. "\n",
  846. " predictions_list_3 = [\n",
  847. " ['img_00001.png', 110, 110, 110, 210, 210, 210, 210, 110, 0.6],\n",
  848. " ['img_00002.png', 120, 130, 120, 200, 200, 200, 200, 130, 0.6],\n",
  849. " ['img_00002.png', 5, 8, 5, 16, 15, 16, 15, 8, 0.9],\n",
  850. " ['img_00002.png', 11, 12, 11, 18, 21, 18, 21, 12, 0.9]\n",
  851. " ]\n",
  852. " ground_truth_list_3 = [\n",
  853. " ['img_00001.png', 100, 100, 100, 200, 200, 200, 200, 100],\n",
  854. " ['img_00002.png', 10, 10, 10, 20, 20, 20, 20, 10],\n",
  855. " ['img_00002.png', 70, 80, 70, 150, 140, 150, 140, 80]\n",
  856. " ]\n",
  857. " ap_3 = object_detection_average_precision(\n",
  858. " ground_truth_list_3, predictions_list_3)\n",
  859. " print('TEST CASE 3 --- AP: ', ap_3)\n",
  860. "\n",
  861. " predictions_list_4 = [\n",
  862. " ['img_00001.png', 110, 110, 110, 210, 210, 210, 210, 110, 0.6],\n",
  863. " ['img_00002.png', 120, 130, 120, 200, 200, 200, 200, 130, 0.6],\n",
  864. " ['img_00002.png', 11, 12, 11, 18, 21, 18, 21, 12, 0.9],\n",
  865. " ['img_00002.png', 5, 8, 5, 16, 15, 16, 15, 8, 0.9]\n",
  866. " ]\n",
  867. " ground_truth_list_4 = [\n",
  868. " ['img_00001.png', 100, 100, 100, 200, 200, 200, 200, 100],\n",
  869. " ['img_00002.png', 10, 10, 10, 20, 20, 20, 20, 10],\n",
  870. " ['img_00002.png', 70, 80, 70, 150, 140, 150, 140, 80]\n",
  871. " ]\n",
  872. " ap_4 = object_detection_average_precision(\n",
  873. " ground_truth_list_4, predictions_list_4)\n",
  874. " print('TEST CASE 4 --- AP: ', ap_4)"
  875. ]
  876. },
  877. {
  878. "cell_type": "code",
  879. "execution_count": null,
  880. "metadata": {},
  881. "outputs": [],
  882. "source": []
  883. }
  884. ],
  885. "metadata": {
  886. "kernelspec": {
  887. "display_name": "Python 3",
  888. "language": "python",
  889. "name": "python3"
  890. },
  891. "language_info": {
  892. "codemirror_mode": {
  893. "name": "ipython",
  894. "version": 3
  895. },
  896. "file_extension": ".py",
  897. "mimetype": "text/x-python",
  898. "name": "python",
  899. "nbconvert_exporter": "python",
  900. "pygments_lexer": "ipython3",
  901. "version": "3.6.5"
  902. }
  903. },
  904. "nbformat": 4,
  905. "nbformat_minor": 2
  906. }

全栈的自动化机器学习系统,主要针对多变量时间序列数据的异常检测。TODS提供了详尽的用于构建基于机器学习的异常检测系统的模块,它们包括:数据处理(data processing),时间序列处理( time series processing),特征分析(feature analysis),检测算法(detection algorithms),和强化模块( reinforcement module)。这些模块所提供的功能包括常见的数据预处理、时间序列数据的平滑或变换,从时域或频域中抽取特征、多种多样的检测算