- {
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# fastNLP开发进阶教程\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 组织数据部分\n",
- "## DataSet & Instance\n",
- "fastNLP用DataSet和Instance保存和处理数据。每个DataSet表示一个数据集,每个Instance表示一个数据样本。一个DataSet存有多个Instance,每个Instance可以自定义存哪些内容。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "# 声明部件\n",
- "import torch\n",
- "import fastNLP\n",
- "from fastNLP import DataSet\n",
- "from fastNLP import Instance\n",
- "from fastNLP import Vocabulary\n",
- "from fastNLP import Trainer\n",
- "from fastNLP import Tester"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Instance\n",
- "Instance表示一个样本,由一个或者多个field(域、属性、特征)组成,每个field具有自己的名字以及值\n",
- "在初始化Instance的时候可以定义它包含的field,使用\"field_name=field_value\"的写法"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "{'premise': an premise example . type=str,\n",
- "'hypothesis': an hypothesis example. type=str,\n",
- "'label': 1 type=int}"
- ]
- },
- "execution_count": 2,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# 组织一个Instance,这个Instance由premise、hypothesis、label三个field组成\n",
- "instance = Instance(premise='an premise example .', hypothesis='an hypothesis example.', label=1)\n",
- "instance"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "DataSet({'premise': an premise example . type=str,\n",
- "'hypothesis': an hypothesis example. type=str,\n",
- "'label': 1 type=int},\n",
- "{'premise': an premise example . type=str,\n",
- "'hypothesis': an hypothesis example. type=str,\n",
- "'label': 1 type=int})"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "data_set = DataSet([instance] * 5)\n",
- "data_set.append(instance)\n",
- "data_set[-2: ]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "DataSet({'premise': an premise example . type=str,\n",
- "'hypothesis': an hypothesis example. type=str,\n",
- "'label': 1 type=int},\n",
- "{'premise': the second premise example . type=str,\n",
- "'hypothesis': the second hypothesis example. type=str,\n",
- "'label': 1 type=str})"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# 如果某一个field的类型与dataset对应的field类型不一样仍可被加入dataset中\n",
- "instance2 = Instance(premise='the second premise example .', hypothesis='the second hypothesis example.', label='1')\n",
- "try:\n",
- " data_set.append(instance2)\n",
- "except:\n",
- " pass\n",
- "data_set[-2: ]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "cannot append instance\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "DataSet({'premise': an premise example . type=str,\n",
- "'hypothesis': an hypothesis example. type=str,\n",
- "'label': 1 type=int},\n",
- "{'premise': the second premise example . type=str,\n",
- "'hypothesis': the second hypothesis example. type=str,\n",
- "'label': 1 type=str})"
- ]
- },
- "execution_count": 5,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# 如果某一个field的名字不对,则该instance不能被append到dataset中\n",
- "instance3 = Instance(premises='the third premise example .', hypothesis='the third hypothesis example.', label=1)\n",
- "try:\n",
- " data_set.append(instance3)\n",
- "except:\n",
- " print('cannot append instance')\n",
- " pass\n",
- "data_set[-2: ]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "DataSet({'image': tensor([[ 0.3582, -1.0358, 1.4785, -1.5288, -0.9982],\n",
- " [-0.3973, -0.4294, 0.9215, -1.9631, -1.6556],\n",
- " [ 0.3313, -1.7714, 0.8729, 0.6976, -1.3172],\n",
- " [-0.6403, 0.5023, -0.9919, 1.1178, -0.3710],\n",
- " [-0.3692, 1.8631, -1.3646, -0.7290, -1.0774]]) type=torch.Tensor,\n",
- "'label': 0 type=int})"
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# 除了文本以外,还可以将tensor作为其中一个field的value\n",
- "import torch\n",
- "tensor_ins = Instance(image=torch.randn(5, 5), label=0)\n",
- "ds = DataSet()\n",
- "ds.append(tensor_ins)\n",
- "ds"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### DataSet\n",
- "### 使用现有代码读取并组织DataSet\n",
- "在DataSet类当中有一些read_* 方法,可以从文件中读取数据并组织DataSet"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "77"
- ]
- },
- "execution_count": 7,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "import fastNLP\n",
- "from fastNLP import DataSet\n",
- "from fastNLP import Instance\n",
- "\n",
- "# 从csv读取数据到DataSet\n",
- "# 类csv文件,即每一行为一个example的文件,都可以使用这种方法进行数据读取\n",
- "dataset = DataSet.read_csv('tutorial_sample_dataset.csv', headers=('raw_sentence', 'label'), sep='\\t')\n",
- "# 查看DataSet的大小\n",
- "len(dataset)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "{'raw_sentence': A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story . type=str,\n",
- "'label': 1 type=str}"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# 使用数字索引[k],获取第k个样本\n",
- "dataset[0]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "fastNLP.core.instance.Instance"
- ]
- },
- "execution_count": 9,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# 获取的样本是一个Instance\n",
- "type(dataset[0])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "DataSet({'raw_sentence': A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story . type=str,\n",
- "'label': 1 type=str},\n",
- "{'raw_sentence': This quiet , introspective and entertaining independent is worth seeking . type=str,\n",
- "'label': 4 type=str},\n",
- "{'raw_sentence': Even fans of Ismail Merchant 's work , I suspect , would have a hard time sitting through this one . type=str,\n",
- "'label': 1 type=str})"
- ]
- },
- "execution_count": 10,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# 使用数字索引[a: b],获取第a到第b个样本\n",
- "dataset[0: 3]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "{'raw_sentence': A film that clearly means to preach exclusively to the converted . type=str,\n",
- "'label': 2 type=str}"
- ]
- },
- "execution_count": 11,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# 索引也可以是负数\n",
- "dataset[-1]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 自行读取并组织DataSet\n",
- "以SNLI数据集为例,\n",
- "SNLI数据集的训练、验证、测试集分别三个文件组成:第一个文件每一行是一句话,代表一个example当中的premise;第二个文件每一行也是一句话,代表一个example当中的hypothesis;第三个文件每一行是一个label"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "{'premise': A person on a horse jumps over a broken down airplane . type=str,\n",
- "'hypothesis': A person is training his horse for a competition . type=str,\n",
- "'truth': 1\n",
- " type=str}"
- ]
- },
- "execution_count": 12,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "data_path = ['premise', 'hypothesis', 'label']\n",
- "\n",
- "# 读入文件\n",
- "with open(data_path[0]) as f:\n",
- " premise = f.readlines()\n",
- "\n",
- "with open(data_path[1]) as f:\n",
- " hypothesis = f.readlines()\n",
- "\n",
- "with open(data_path[2]) as f:\n",
- " label = f.readlines()\n",
- "\n",
- "assert len(premise) == len(hypothesis) and len(hypothesis) == len(label)\n",
- "\n",
- "# 组织DataSet\n",
- "data_set = DataSet()\n",
- "for p, h, l in zip(premise, hypothesis, label):\n",
- " p = p.strip() # 将行末空格去除\n",
- " h = h.strip() # 将行末空格去除\n",
- " data_set.append(Instance(premise=p, hypothesis=h, truth=l))\n",
- "\n",
- "data_set[0]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### DataSet的其他操作\n",
- "在构建完毕DataSet后,仍然可以对DataSet的内容进行操作,函数接口为DataSet.apply()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "DataSet({'premise': a woman is walking across the street eating a banana , while a man is following with his briefcase . type=str,\n",
- "'hypothesis': An actress and her favorite assistant talk a walk in the city . type=str,\n",
- "'truth': 1\n",
- " type=str},\n",
- "{'premise': a woman is walking across the street eating a banana , while a man is following with his briefcase . type=str,\n",
- "'hypothesis': a woman eating a banana crosses a street type=str,\n",
- "'truth': 0\n",
- " type=str})"
- ]
- },
- "execution_count": 13,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# 将premise域的所有文本转成小写\n",
- "data_set.apply(lambda x: x['premise'].lower(), new_field_name='premise')\n",
- "data_set[-2: ]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "DataSet({'premise': a woman is walking across the street eating a banana , while a man is following with his briefcase . type=str,\n",
- "'hypothesis': An actress and her favorite assistant talk a walk in the city . type=str,\n",
- "'truth': 1 type=int},\n",
- "{'premise': a woman is walking across the street eating a banana , while a man is following with his briefcase . type=str,\n",
- "'hypothesis': a woman eating a banana crosses a street type=str,\n",
- "'truth': 0 type=int})"
- ]
- },
- "execution_count": 14,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# label转int\n",
- "data_set.apply(lambda x: int(x['truth']), new_field_name='truth')\n",
- "data_set[-2: ]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "DataSet({'premise': ['a', 'woman', 'is', 'walking', 'across', 'the', 'street', 'eating', 'a', 'banana', ',', 'while', 'a', 'man', 'is', 'following', 'with', 'his', 'briefcase', '.'] type=list,\n",
- "'hypothesis': ['An', 'actress', 'and', 'her', 'favorite', 'assistant', 'talk', 'a', 'walk', 'in', 'the', 'city', '.'] type=list,\n",
- "'truth': 1 type=int},\n",
- "{'premise': ['a', 'woman', 'is', 'walking', 'across', 'the', 'street', 'eating', 'a', 'banana', ',', 'while', 'a', 'man', 'is', 'following', 'with', 'his', 'briefcase', '.'] type=list,\n",
- "'hypothesis': ['a', 'woman', 'eating', 'a', 'banana', 'crosses', 'a', 'street'] type=list,\n",
- "'truth': 0 type=int})"
- ]
- },
- "execution_count": 15,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# 使用空格分割句子\n",
- "def split_sent(ins):\n",
- " return ins['premise'].split()\n",
- "data_set.apply(split_sent, new_field_name='premise')\n",
- "data_set.apply(lambda x: x['hypothesis'].split(), new_field_name='hypothesis')\n",
- "data_set[-2:]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "(100, 97)"
- ]
- },
- "execution_count": 16,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# 筛选数据\n",
- "origin_data_set_len = len(data_set)\n",
- "data_set.drop(lambda x: len(x['premise']) <= 6)\n",
- "origin_data_set_len, len(data_set)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "{'premise': ['a', 'woman', 'is', 'walking', 'across', 'the', 'street', 'eating', 'a', 'banana', ',', 'while', 'a', 'man', 'is', 'following', 'with', 'his', 'briefcase', '.'] type=list,\n",
- "'hypothesis': ['a', 'woman', 'eating', 'a', 'banana', 'crosses', 'a', 'street'] type=list,\n",
- "'truth': 0 type=int,\n",
- "'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
- "'hypothesis_len': [1, 1, 1, 1, 1, 1, 1, 1] type=list}"
- ]
- },
- "execution_count": 17,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# 增加长度信息\n",
- "data_set.apply(lambda x: [1] * len(x['premise']), new_field_name='premise_len')\n",
- "data_set.apply(lambda x: [1] * len(x['hypothesis']), new_field_name='hypothesis_len')\n",
- "data_set[-1]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 18,
- "metadata": {},
- "outputs": [],
- "source": [
- "# 设定特征域、标签域\n",
- "data_set.rename_field(\"premise\",\"words1\")\n",
- "data_set.rename_field(\"premise_len\",\"seq_len1\")\n",
- "data_set.rename_field(\"hypothesis\",\"words2\")\n",
- "data_set.rename_field(\"hypothesis_len\",\"seq_len2\")\n",
- "data_set.set_input(\"words1\", \"seq_len1\", \"words2\", \"seq_len2\")\n",
- "data_set.set_target(\"truth\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 19,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "{'words1': ['a', 'woman', 'is', 'walking', 'across', 'the', 'street', 'eating', 'a', 'banana', ',', 'while', 'a', 'man', 'is', 'following', 'with', 'his', 'briefcase', '.'] type=list,\n",
- "'seq_len1': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
- "'words2': ['a', 'woman', 'eating', 'a', 'banana', 'crosses', 'a', 'street'] type=list,\n",
- "'seq_len2': [1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
- "'label': 0 type=int}"
- ]
- },
- "execution_count": 19,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# 重命名field\n",
- "data_set.rename_field('truth', 'label')\n",
- "data_set[-1]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 20,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "(49, 29, 19)"
- ]
- },
- "execution_count": 20,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# 切分训练、验证集、测试集\n",
- "train_data, vad_data = data_set.split(0.5)\n",
- "dev_data, test_data = vad_data.split(0.4)\n",
- "len(train_data), len(dev_data), len(test_data)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 21,
- "metadata": {},
- "outputs": [],
- "source": [
- "# 深拷贝一个数据集\n",
- "import copy\n",
- "train_data_2, dev_data_2 = copy.deepcopy(train_data), copy.deepcopy(dev_data)\n",
- "del copy"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### DataSet的总结:\n",
- "将DataSet的ield设置为input和target后,这些field在接下来的代码中将被使用。其中被设置为input的field会被传递给<strong>Model.forward</strong>,这个过程是通过键匹配的方式进行的。举例如下: \n",
- "假设DataSet中有'x1', 'x2', 'x3'被设置为了input,而 \n",
- "   (1)函数是Model.forward(self, x1, x3), 那么DataSet中'x1', 'x3'会被传递给forward函数。多余的'x2'会被忽略 \n",
- "   (2)函数是Model.forward(self, x1, x4), 这里多需要了一个'x4', 但是DataSet的input field中没有这个field,会报错。 \n",
- "   (3)函数是Model.forward(self, x1, kwargs), 会把'x1', 'x2', 'x3'都传入。但如果是Model.forward(self, x4, kwargs)就会发生报错,因为没有'x4'。 \n",
- "   (4)对于设置为target的field的名称,我们建议取名为'target'(如果只有一个需要predict的值),但是不强制。如果这个名称不是target,那么在加载loss函数的时候需要手动添加名称转换map"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Vocabulary\n",
- "fastNLP中的Vocabulary轻松构建词表,并将词转成数字。构建词表有两种方式:根据数据集构建词表;载入现有词表\n",
- "### 根据数据集构建词表"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 23,
- "metadata": {},
- "outputs": [],
- "source": [
- "# 初始化词表,该词表最大的vocab_size为10000,词表中每个词出现的最低频率为2,'<unk>'表示未知词语,'<pad>'表示padding词语\n",
- "# Vocabulary默认初始化参数为max_size=None, min_freq=None, unknown='<unk>', padding='<pad>'\n",
- "vocab = Vocabulary(max_size=10000, min_freq=2, unknown='<unk>', padding='<pad>')\n",
- "\n",
- "# 构建词表\n",
- "train_data.apply(lambda x: [vocab.add(word) for word in x['words1']])\n",
- "train_data.apply(lambda x: [vocab.add(word) for word in x['words2']])\n",
- "vocab.build_vocab()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 24,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "({'words1': [2, 9, 4, 2, 75, 85, 7, 86, 76, 77, 87, 88, 89, 2, 90, 3] type=list,\n",
- " 'seq_len1': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
- " 'words2': [18, 9, 10, 1, 3] type=list,\n",
- " 'seq_len2': [1, 1, 1, 1, 1] type=list,\n",
- " 'label': 1 type=int},\n",
- " {'words1': [22, 32, 5, 110, 81, 111, 112, 5, 82, 3] type=list,\n",
- " 'seq_len1': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
- " 'words2': [64, 32, 82, 133, 84, 3] type=list,\n",
- " 'seq_len2': [1, 1, 1, 1, 1, 1] type=list,\n",
- " 'label': 0 type=int},\n",
- " {'words1': [2, 9, 97, 1, 20, 7, 54, 5, 1, 1, 70, 2, 11, 110, 2, 62, 3] type=list,\n",
- " 'seq_len1': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
- " 'words2': [23, 1, 58, 10, 12, 1, 70, 133, 84, 3] type=list,\n",
- " 'seq_len2': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
- " 'label': 1 type=int})"
- ]
- },
- "execution_count": 24,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# 根据词表index句子\n",
- "train_data.apply(lambda x: [vocab.to_index(word) for word in x['words1']], new_field_name='words1')\n",
- "train_data.apply(lambda x: [vocab.to_index(word) for word in x['words2']], new_field_name='words2')\n",
- "dev_data.apply(lambda x: [vocab.to_index(word) for word in x['words1']], new_field_name='words1')\n",
- "dev_data.apply(lambda x: [vocab.to_index(word) for word in x['words2']], new_field_name='words2')\n",
- "test_data.apply(lambda x: [vocab.to_index(word) for word in x['words1']], new_field_name='words1')\n",
- "test_data.apply(lambda x: [vocab.to_index(word) for word in x['words2']], new_field_name='words2')\n",
- "train_data[-1], dev_data[-1], test_data[-1]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 载入现有词表\n",
- "以BERT pretrained model为例,词表由一个vocab.txt文件来保存\n",
- "用以下方法可以载入现有词表,并保证词表顺序不变"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 25,
- "metadata": {},
- "outputs": [],
- "source": [
- "# 读入vocab文件\n",
- "with open('vocab.txt') as f:\n",
- " lines = f.readlines()\n",
- "vocabs = []\n",
- "for line in lines:\n",
- " vocabs.append(line.strip())\n",
- "\n",
- "# 实例化Vocabulary\n",
- "vocab_bert = Vocabulary(unknown=None, padding=None)\n",
- "# 将vocabs列表加入Vocabulary\n",
- "vocab_bert.add_word_lst(vocabs)\n",
- "# 构建词表\n",
- "vocab_bert.build_vocab()\n",
- "# 更新unknown与padding的token文本\n",
- "vocab_bert.unknown = '[UNK]'\n",
- "vocab_bert.padding = '[PAD]'"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 26,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "({'words1': [1037, 2450, 1999, 1037, 2665, 6598, 1998, 7415, 2058, 2014, 2132, 2559, 2875, 1037, 3028, 1012] type=list,\n",
- " 'seq_len1': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
- " 'words2': [100, 2450, 2003, 3147, 1012] type=list,\n",
- " 'seq_len2': [1, 1, 1, 1, 1] type=list,\n",
- " 'label': 1 type=int},\n",
- " {'words1': [2048, 2308, 1010, 3173, 2833, 100, 16143, 1010, 8549, 1012] type=list,\n",
- " 'seq_len1': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
- " 'words2': [100, 2308, 8549, 2169, 2060, 1012] type=list,\n",
- " 'seq_len2': [1, 1, 1, 1, 1, 1] type=list,\n",
- " 'label': 0 type=int})"
- ]
- },
- "execution_count": 26,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# 根据词表index句子\n",
- "train_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['words1']], new_field_name='words1')\n",
- "train_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['words2']], new_field_name='words2')\n",
- "dev_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['words1']], new_field_name='words1')\n",
- "dev_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['words2']], new_field_name='words2')\n",
- "train_data_2[-1], dev_data_2[-1]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 模型部分\n",
- "## Model\n",
- "模型部分fastNLP提供两种使用方式:调用fastNLP现有模型;开发者自行搭建模型\n",
- "### 调用fastNLP现有模型"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 27,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "{'embed_dim': 300,\n",
- " 'hidden_size': 300,\n",
- " 'batch_first': True,\n",
- " 'dropout': 0.3,\n",
- " 'num_classes': 3,\n",
- " 'gpu': True,\n",
- " 'batch_size': 32,\n",
- " 'vocab_size': 143}"
- ]
- },
- "execution_count": 27,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# step 1:加载模型参数(非必选)\n",
- "from fastNLP.io.config_io import ConfigSection, ConfigLoader\n",
- "args = ConfigSection()\n",
- "ConfigLoader().load_config(\"./data/config\", {\"esim_model\": args})\n",
- "args[\"vocab_size\"] = len(vocab)\n",
- "args.data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 28,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "ESIM(\n",
- " (drop): Dropout(p=0.3)\n",
- " (embedding): Embedding(\n",
- " 143, 300\n",
- " (dropout): Dropout(p=0.3)\n",
- " )\n",
- " (embedding_layer): Linear(in_features=300, out_features=300, bias=True)\n",
- " (encoder): LSTM(\n",
- " (lstm): LSTM(300, 300, batch_first=True, bidirectional=True)\n",
- " )\n",
- " (bi_attention): BiAttention()\n",
- " (mean_pooling): MeanPoolWithMask()\n",
- " (max_pooling): MaxPoolWithMask()\n",
- " (inference_layer): Linear(in_features=1200, out_features=300, bias=True)\n",
- " (decoder): LSTM(\n",
- " (lstm): LSTM(300, 300, batch_first=True, bidirectional=True)\n",
- " )\n",
- " (output): MLP(\n",
- " (hiddens): ModuleList(\n",
- " (0): Linear(in_features=1200, out_features=300, bias=True)\n",
- " )\n",
- " (output): Linear(in_features=300, out_features=3, bias=True)\n",
- " (dropout): Dropout(p=0.3)\n",
- " )\n",
- ")"
- ]
- },
- "execution_count": 28,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# step 2:加载ESIM模型\n",
- "from fastNLP.models import ESIM\n",
- "model = ESIM(args[\"vocab_size\"], args[\"embed_dim\"], args[\"hidden_size\"], args[\"dropout\"], args[\"num_classes\"])\n",
- "model"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "这是上述模型的forward方法。如果你不知道什么是forward方法,请参考我们的PyTorch教程。\n",
- "\n",
- "注意两点:\n",
- "1. forward参数名字叫**word_seq**,请记住。\n",
- "2. forward的返回值是一个**dict**,其中有个key的名字叫**pred**。\n",
- "\n",
- "```Python\n",
- " def forward(self, word_seq):\n",
- " \"\"\"\n",
- "\n",
- " :param word_seq: torch.LongTensor, [batch_size, seq_len]\n",
- " :return output: dict of torch.LongTensor, [batch_size, num_classes]\n",
- " \"\"\"\n",
- " x = self.embed(word_seq) # [N,L] -> [N,L,C]\n",
- " x = self.conv_pool(x) # [N,L,C] -> [N,C]\n",
- " x = self.dropout(x)\n",
- " x = self.fc(x) # [N,C] -> [N, N_class]\n",
- " return {'pred': x}\n",
- "```"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 自行搭载模型\n",
- "自行搭载的模型必须是nn.Module的子类, \n",
- "(1)必须实现forward方法,并且forward方法不能出现**\\*arg**这种参数,例如\n",
- "```Python\n",
- " def forword(self, word_seq, *args): # 这是不允许的\n",
- " xxx\n",
- "```\n",
- "forward函数的返回值必须是一个**dict**。 \n",
- "dict当中模型预测的值所对应的key建议用**'pred'**,这里不做强制限制,但是如果不是pred的话,在加载loss函数的时候需要手动添加名称转换map \n",
- "(2)如果实现了predict方法,在做evaluation的时候将调用predict方法而不是forward。如果没有predict方法,则在evaluation时调用forward方法。predict方法也不能使用\\*args这种参数形式,同时结果也必须返回一个dict,同样推荐key为'pred'。 \n",
- "(3)forward函数可以计算loss并返回结果,在dict中的key为'loss',如: \n",
- "```Python\n",
- " def forword(self, word_seq, *args): \n",
- " xxx\n",
- " return {'pred': pred, 'loss': loss}\n",
- "```\n",
- "当loss函数没有在trainer里被定义的时候,trainer将会根据forward函数返回的dict中key为'loss'的值来进行反向传播,具体见loss部分"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 训练测试部分\n",
- "## Loss\n",
- "### 键映射\n",
- "根据上文DataSet与Model部分可以知道,fastNLP并不限制Model.forward()的返回值,也不限制DataSet中target field的key。因此在计算loss的时候,需要通过键映射的方式来完成取值。 \n",
- "这里以CrossEntropyLoss为例,我们的交叉熵函数部分如下:\n",
- "```Python\n",
- " def get_loss(self, pred, target):\n",
- " return F.cross_entropy(input=pred, target=target,\n",
- " ignore_index=self.padding_idx)\n",
- "```\n",
- "这里接收的两个参数名字分别为pred和target,其中pred是从model的forward函数返回值中取得,target是从DataSet的is_target的field当中取得。在没有设置键映射的基础上,pred从model的forward函数返回的dict中取'pred'键得到;target从DataSet的'target'field中得到。\n",
- "。"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 修改键映射\n",
- "在初始化CrossEntropyLoss的时候,可以传入两个参数(pred=None, target=None), 这两个参数接受的类型是str,假设(pred='output', target='label'),那么CrossEntropyLoss会使用'output'这个key在forward的output与batch_y中寻找值;'label'也是在forward的output与d ataset的is target field中寻找值。注意这里pred或target的来源并不一定非要来自于model.forward与dataset的is target field,也可以只来自于forward的结果"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 创建一个自己的loss\n",
- "    (1)采用fastNLP.LossInForward,在model.forward()的结果中包含一个'loss'的key,具体见**自行搭载模型**部分 \n",
- "    (2)自己定义一个继承于fastNLP.core.loss.LossBase的class。重写get_loss方法。 \n",
- "      (2.1)在初始化自己的loss class的时候,需要初始化需要映射的值 \n",
- "      (2.2)在get_loss函数中,参数的名字需要与初始化时的映射是一致的 \n",
- "以L1Loss为例子:\n",
- "```Python\n",
- "class L1Loss(LossBase):\n",
- " def __init__(self, pred=None, target=None):\n",
- " super(L1Loss, self).__init__()\n",
- " \"\"\"\n",
- " 这里传入_init_param_map以使得pred和target被正确注册,但这一步不是必须的, 建议调用。传入_init_param_map的是用于\n",
- " \"键映射\"的键值对。假设初始化__init__(pred=None, target=None, threshold=0.1)中threshold是用于控制loss计算的,\n",
- " 则\\不要将threshold传入_init_param_map.\n",
- " \"\"\"\n",
- " self._init_param_map(pred=pred, target=target)\n",
- "\n",
- " def get_loss(self, pred, target):\n",
- " # 这里'pred', 'target'必须和初始化的映射是一致的。\n",
- " return F.l1_loss(input=pred, target=target)\n",
- "```\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Trainer\n",
- "trainer的作用是训练模型,是一个对训练过程的封装。trainer当中比较重要的函数是trainer.train(),train函数的主要步骤包括: \n",
- "(1)创建batch \n",
- "```Python\n",
- "batch = Batch(dataset, batch_size, sampler=sampler)\n",
- "```\n",
- "(2)for batch_x, batch_y in batch: (batch_x, batch_y的内容分别为dataset中is input和is target的部分,这两个dict的key就是DataSet中的key,value会根据情况做好padding及tensor) \n",
- "  (2.1)将batch_x, batch_y中的tensor移动到model所在的device \n",
- "  (2.2)根据model.forward的参数列表,从batch_x中取出需要传递给forward的数据 \n",
- "  (2.3)获取model.forward返回的dict,并与batch_y一起传递给loss函数,求得loss \n",
- "  (2.4)对loss进行反向梯度传播并更新参数 \n",
- "(3)如果有验证集,则进行验证 \n",
- "(4)如果验证集的结果是当前最佳结果,则保存模型"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "除了以上的内容, Trainer中还提供了\"预跑\"的功能。该功能通过check_code_level管理,如果check_code_level为-1,则不进行\"预跑\"。 check_code_level=0,1,2代表不同的提醒级别。目前不同提醒级别对应的是对DataSet中设置为input或target但又没有使用的field的提醒级别。 0是忽略(默认);1是会warning发生了未使用field的情况;2是出现了unused会直接报错并退出运行 \"预跑\"的主要目的有两个: \n",
- "(1)防止train完了之后进行evaluation的时候出现错误。之前的train就白费了 \n",
- "(2)由于存在\"键映射\",直接运行导致的报错可能不太容易debug,通过\"预跑\"过程的报错会有一些debug提示 \"预跑\"会进行以下的操作: \n",
- "  (i) 使用很小的batch_size, 检查batch_x中是否包含Model.forward所需要的参数。只会运行两个循环。 \n",
- "  (ii)将Model.foward的输出pred_dict与batch_y输入到loss中,并尝试backward。不会更新参数,而且grad会被清零。如果传入了dev_data,还将进行metric的测试 \n",
- "  (iii)创建Tester,并传入少量数据,检测是否可以正常运行 \n",
- "\"预跑\"操作是在Trainer初始化的时候执行的。正常情况下,应该不需要改动\"预跑\"的代码。但如果遇到bug或者有什么好的建议,欢迎在开发群或者github提交issue。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 29,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "training epochs started 2019-05-14-19-49-25\n"
- ]
- },
- {
- "ename": "AssertionError",
- "evalue": "seq_len can only have one dimension, got False.",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m<ipython-input-29-a3d2740dc8ef>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0muse_tqdm\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m )\n\u001b[0;32m---> 19\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
- "\u001b[0;32m/Users/fdujyn/anaconda3/lib/python3.6/site-packages/fastNLP/core/trainer.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, load_best_model)\u001b[0m\n\u001b[1;32m 522\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 523\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcallback_manager\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_train_begin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 524\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_train\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 525\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcallback_manager\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_train_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 526\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mCallbackException\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/Users/fdujyn/anaconda3/lib/python3.6/site-packages/fastNLP/core/trainer.py\u001b[0m in \u001b[0;36m_train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 573\u001b[0m \u001b[0;31m# negative sampling; replace unknown; re-weight batch_y\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 574\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcallback_manager\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_batch_begin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_x\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_y\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 575\u001b[0;31m \u001b[0mprediction\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_data_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_x\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 576\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 577\u001b[0m \u001b[0;31m# edit prediction\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/Users/fdujyn/anaconda3/lib/python3.6/site-packages/fastNLP/core/trainer.py\u001b[0m in \u001b[0;36m_data_forward\u001b[0;34m(self, network, x)\u001b[0m\n\u001b[1;32m 661\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_data_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnetwork\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 662\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_build_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnetwork\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 663\u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnetwork\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 664\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 665\u001b[0m raise TypeError(\n",
- "\u001b[0;32m/Users/fdujyn/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 489\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 490\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 491\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 492\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 493\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/Users/fdujyn/anaconda3/lib/python3.6/site-packages/fastNLP/models/snli.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, words1, words2, seq_len1, seq_len2, target)\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mseq_len1\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 78\u001b[0;31m \u001b[0mseq_len1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mseq_len_to_mask\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mseq_len1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 79\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 80\u001b[0m \u001b[0mseq_len1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mones\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpremise0\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpremise0\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/Users/fdujyn/anaconda3/lib/python3.6/site-packages/fastNLP/core/utils.py\u001b[0m in \u001b[0;36mseq_len_to_mask\u001b[0;34m(seq_len)\u001b[0m\n\u001b[1;32m 626\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 627\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mseq_len\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 628\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0mseq_len\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34mf\"seq_len can only have one dimension, got {seq_len.dim() == 1}.\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 629\u001b[0m \u001b[0mbatch_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mseq_len\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 630\u001b[0m \u001b[0mmax_len\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mseq_len\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlong\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;31mAssertionError\u001b[0m: seq_len can only have one dimension, got False."
- ]
- }
- ],
- "source": [
- "from fastNLP import CrossEntropyLoss\n",
- "from fastNLP import Adam\n",
- "from fastNLP import AccuracyMetric\n",
- "trainer = Trainer(\n",
- " train_data=train_data,\n",
- " model=model,\n",
- " loss=CrossEntropyLoss(pred='pred', target='label'), # 模型预测值通过'pred'来取得,目标值(ground truth)由'label'取得\n",
- " metrics=AccuracyMetric(target='label'), # 目标值(ground truth)由'label'取得\n",
- " n_epochs=5,\n",
- " batch_size=16,\n",
- " print_every=-1,\n",
- " validate_every=-1,\n",
- " dev_data=dev_data,\n",
- " optimizer=Adam(lr=1e-3, weight_decay=0),\n",
- " check_code_level=-1,\n",
- " metric_key='acc',\n",
- " use_tqdm=False,\n",
- ")\n",
- "trainer.train()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Tester\n",
- "tester的作用是训练模型,是一个对训练过程的封装。tester当中比较重要的函数是tester.test(),test函数的主要步骤包括:\n",
- "(1)创建batch\n",
- "```Python\n",
- "batch = Batch(dataset, batch_size, sampler=sampler)\n",
- "```\n",
- "(2)for batch_x, batch_y in batch: (batch_x, batch_y的内容分别为dataset中is input和is target的部分,这两个dict的key就是DataSet中的key,value会根据情况做好padding及tensor) \n",
- "  (2.1)同步数据与model将batch_x, batch_y中的tensor移动到model所在的device \n",
- "  (2.2)根据predict_func的参数列表,从batch_x中取出需要传递给predict_func的数据,得到结果pred_dict \n",
- "  (2.3)调用metric(pred_dict, batch_y) \n",
- "  (2.4)当所有batch都运行完毕,会调用metric的get_metric方法,并且以返回的值作为evaluation的结果 "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 30,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[tester] \n",
- "AccuracyMetric: acc=0.368421\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "{'AccuracyMetric': {'acc': 0.368421}}"
- ]
- },
- "execution_count": 30,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "tester = Tester(\n",
- " data=test_data,\n",
- " model=model,\n",
- " metrics=AccuracyMetric(target='label'),\n",
- " batch_size=args[\"batch_size\"],\n",
- ")\n",
- "tester.test()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.6.7"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
- }