{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "FOTs_vUOQC7_", "outputId": "a73236e6-09cf-4995-b0cc-2c37692ec4b7" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m719.8/719.8 kB\u001b[0m \u001b[31m5.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m542.0/542.0 kB\u001b[0m \u001b[31m8.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m6.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m9.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m7.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m401.2/401.2 kB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m19.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hMounted at /content/gdrive\n" ] } ], "source": [ "! 
[ -e /content ] && pip install -Uqq fastbook\n", "import fastbook\n", "fastbook.setup_book()" ] }, { "cell_type": "code", "source": [ "from fastbook import *" ], "metadata": { "id": "18LDcJi0QgBm" }, "execution_count": 2, "outputs": [] }, { "cell_type": "code", "source": [ "# 加载数据集\n", "from fastai.text.all import *\n", "path = untar_data(URLs.HUMAN_NUMBERS)" ], "metadata": { "id": "HCJz8WkZQ9Y1", "colab": { "base_uri": "https://localhost:8080/", "height": 37 }, "outputId": "8a973c5c-ee5c-4587-d948-17d75d1d019d" }, "execution_count": 3, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", "\n" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", "
\n", " \n", " 108.32% [32768/30252 00:00<00:00]\n", "
\n", " " ] }, "metadata": {} } ] }, { "cell_type": "code", "source": [ "#hide\n", "Path.BASE_PATH = path" ], "metadata": { "id": "87w-XbtbQ-UB" }, "execution_count": 4, "outputs": [] }, { "cell_type": "code", "source": [ "path.ls()" ], "metadata": { "id": "711tZIY8RlW2", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "5488f383-14a7-410a-e8df-0c7387260d06" }, "execution_count": 5, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(#2) [Path('valid.txt'),Path('train.txt')]" ] }, "metadata": {}, "execution_count": 5 } ] }, { "cell_type": "code", "source": [ "lines = L()\n", "with open(path/'train.txt') as f: lines += L(*f.readlines())\n", "with open(path/'valid.txt') as f: lines += L(*f.readlines())\n", "lines" ], "metadata": { "id": "4aHDGY0KRn-g", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "67b150fc-4935-40f5-93ba-8a79ee503ad9" }, "execution_count": 6, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(#9998) ['one \\n','two \\n','three \\n','four \\n','five \\n','six \\n','seven \\n','eight \\n','nine \\n','ten \\n'...]" ] }, "metadata": {}, "execution_count": 6 } ] }, { "cell_type": "code", "source": [ "text = ' . '.join([l.strip() for l in lines])\n", "text[:200]" ], "metadata": { "id": "oaRE2cTJSIGe", "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "outputId": "05d044a7-732a-4751-c633-ddc0b1c84d90" }, "execution_count": 7, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "'one . two . three . four . five . six . seven . eight . nine . ten . eleven . twelve . thirteen . fourteen . fifteen . sixteen . seventeen . eighteen . nineteen . twenty . twenty one . twenty two . 
tw'" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" } }, "metadata": {}, "execution_count": 7 } ] }, { "cell_type": "code", "source": [ "tokens = text.split(' ')\n", "tokens[:10]" ], "metadata": { "id": "6-HW-hvyS44Q", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "5414b112-1831-4a4c-fe8f-0604d4de99e9" }, "execution_count": 8, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "['one', '.', 'two', '.', 'three', '.', 'four', '.', 'five', '.']" ] }, "metadata": {}, "execution_count": 8 } ] }, { "cell_type": "code", "source": [ "vocab = L(*tokens).unique()\n", "vocab" ], "metadata": { "id": "h4bCLQdiTMsz", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "25f395f0-02e3-4f4a-d107-fc15c78e8b7b" }, "execution_count": 9, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(#30) ['one','.','two','three','four','five','six','seven','eight','nine'...]" ] }, "metadata": {}, "execution_count": 9 } ] }, { "cell_type": "code", "source": [ "word2idx = {w:i for i,w in enumerate(vocab)}\n", "nums = L(word2idx[i] for i in tokens)\n", "nums" ], "metadata": { "id": "efUuwK-FU39O", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "4d649f3c-f089-4a85-87a9-e2e0b2e65f2c" }, "execution_count": 10, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(#63095) [0,1,2,1,3,1,4,1,5,1...]" ] }, "metadata": {}, "execution_count": 10 } ] }, { "cell_type": "code", "source": [ "# 从零开始的语言模型\n", "# 三词序列的列表作为我们的自变量,每个序列后的下一个单词作为因变量\n", "L((tokens[i:i+3], tokens[i+3]) for i in range(0,len(tokens)-4,3))" ], "metadata": { "id": "WZNRkasUVQKw", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "996da6c3-e6ce-42c4-dd43-6d1589657486" }, "execution_count": 11, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(#21031) [(['one', '.', 'two'], '.'),(['.', 'three', '.'], 'four'),(['four', '.', 'five'], '.'),(['.', 'six', '.'], 
'seven'),(['seven', '.', 'eight'], '.'),(['.', 'nine', '.'], 'ten'),(['ten', '.', 'eleven'], '.'),(['.', 'twelve', '.'], 'thirteen'),(['thirteen', '.', 'fourteen'], '.'),(['.', 'fifteen', '.'], 'sixteen')...]" ] }, "metadata": {}, "execution_count": 11 } ] }, { "cell_type": "code", "source": [ "# Numericalize: each sample pairs a tensor of three token ids with the id of the token that follows.\n", "seqs = L((tensor(nums[i:i+3]), nums[i+3]) for i in range(0,len(nums)-4,3))\n", "seqs" ], "metadata": { "id": "MvMUn7g7BoNU", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "b3a160ac-3df8-4ce1-bfc4-5332bfaa6d77" }, "execution_count": 12, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(#21031) [(tensor([0, 1, 2]), 1),(tensor([1, 3, 1]), 4),(tensor([4, 1, 5]), 1),(tensor([1, 6, 1]), 7),(tensor([7, 1, 8]), 1),(tensor([1, 9, 1]), 10),(tensor([10, 1, 11]), 1),(tensor([ 1, 12, 1]), 13),(tensor([13, 1, 14]), 1),(tensor([ 1, 15, 1]), 16)...]" ] }, "metadata": {}, "execution_count": 12 } ] }, { "cell_type": "code", "source": [ "# Batching: the first 80% of the sequences train, the rest validate.\n", "bs = 64\n", "cut = int(len(seqs) * 0.8)\n", "# Pass `bs` itself (not a second literal 64) so the batch size is defined in one place.\n", "dls = DataLoaders.from_dsets(seqs[:cut], seqs[cut:], bs=bs, shuffle=False)" ], "metadata": { "id": "gjCoqBWyChtZ" }, "execution_count": 13, "outputs": [] }, { "cell_type": "code", "source": [ "class LMModel1(Module):\n", "    \"\"\"Next-token model for three-token inputs, with the three steps written out by hand.\n", "\n", "    One embedding (`i_h`) and one hidden-to-hidden layer (`h_h`) are reused for all\n", "    three input positions; `h_o` maps the last activation to `vocab_sz` scores.\n", "    \"\"\"\n", "    def __init__(self, vocab_sz, n_hidden):\n", "        self.i_h = nn.Embedding(vocab_sz, n_hidden)\n", "        self.h_h = nn.Linear(n_hidden, n_hidden)\n", "        self.h_o = nn.Linear(n_hidden,vocab_sz)\n", "\n", "    def forward(self, x):\n", "        # x[:,k] selects the k-th token id of every row in the batch.\n", "        h = F.relu(self.h_h(self.i_h(x[:,0])))\n", "        h = h + self.i_h(x[:,1])\n", "        h = F.relu(self.h_h(h))\n", "        h = h + self.i_h(x[:,2])\n", "        h = F.relu(self.h_h(h))\n", "        return self.h_o(h)" ], "metadata": { "id": "V5goxtruD3ms" }, "execution_count": 14, "outputs": [] }, { "cell_type": "code", "source": [ "learn = Learner(dls, LMModel1(len(vocab), 64), loss_func=F.cross_entropy,\n", "                metrics=accuracy)\n", "learn.fit_one_cycle(4, 1e-3)" ], "metadata": { "id": "fFFa6_czFJZ4", "colab": { "base_uri":
"https://localhost:8080/", "height": 210 }, "outputId": "cb5427b9-ed48-46a4-f659-478bf60c7954" }, "execution_count": 15, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", "\n" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
01.8242971.9709410.46755400:02
11.3869731.8232420.46755400:02
21.4175561.6544980.49441400:06
31.3764401.6508490.49441400:07
" ] }, "metadata": {} }, { "output_type": "stream", "name": "stderr", "text": [ "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", " self.pid = os.fork()\n" ] } ] }, { "cell_type": "code", "source": [ "# Baseline: how often the single most frequent validation target occurs.\n", "n,counts = 0,torch.zeros(len(vocab))\n", "for x,y in dls.valid:\n", "    n += y.shape[0]\n", "    # Tally, per vocabulary index, how many targets in this batch equal it.\n", "    for i in range_of(vocab): counts[i] += (y==i).long().sum()\n", "idx = torch.argmax(counts)\n", "idx, vocab[idx.item()], counts[idx].item()/n" ], "metadata": { "id": "7QT6h_9dFL85", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "957f10fa-cbe0-4875-ce9c-5862b913fabb" }, "execution_count": 16, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(tensor(29), 'thousand', 0.15165200855716662)" ] }, "metadata": {}, "execution_count": 16 } ] }, { "cell_type": "code", "source": [ "# Refactored with a loop: the for loop replaces the repeated layer calls of LMModel1.\n", "class LMModel2(Module):\n", "    \"\"\"Same layers as LMModel1, with the three per-token steps expressed as a loop.\"\"\"\n", "    def __init__(self, vocab_sz, n_hidden):\n", "        self.i_h = nn.Embedding(vocab_sz, n_hidden)\n", "        self.h_h = nn.Linear(n_hidden, n_hidden)\n", "        self.h_o = nn.Linear(n_hidden,vocab_sz)\n", "\n", "    def forward(self, x):\n", "        # Starting from the scalar 0 lets the first addition adopt the embedding's shape.\n", "        h = 0\n", "        for i in range(3):\n", "            h = h + self.i_h(x[:,i])\n", "            h = F.relu(self.h_h(h))\n", "        return self.h_o(h)" ], "metadata": { "id": "M0qKNmhzFSOG" }, "execution_count": 17, "outputs": [] }, { "cell_type": "code", "source": [ "learn = Learner(dls, LMModel2(len(vocab), 64), loss_func=F.cross_entropy,\n", "                metrics=accuracy)\n", "learn.fit_one_cycle(4, 1e-3)" ], "metadata": { "id": "zunoCejQG2A7", "colab": { "base_uri": "https://localhost:8080/", "height": 175 }, "outputId": "4c09daa3-a4a0-42d5-c962-143561de26f4" }, "execution_count": 18, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", "\n" ] }, "metadata": {} }, { "output_type": "display_data", "data": {
"text/plain": [ "" ], "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
01.8162741.9641430.46018500:04
11.4238051.7399640.47325900:03
21.4303271.6851720.48538200:03
31.3883901.6570330.47040600:02
" ] }, "metadata": {} } ] }, { "cell_type": "code", "source": [ "# Keep the RNN state between batches.\n", "class LMModel3(Module):\n", "    \"\"\"LMModel2 with the hidden activation stored on the module between calls.\n", "\n", "    `forward` detaches the stored activation after each call, so gradients do not\n", "    flow back into earlier batches; `reset` sets it back to 0.\n", "    \"\"\"\n", "    def __init__(self, vocab_sz, n_hidden):\n", "        self.i_h = nn.Embedding(vocab_sz, n_hidden)\n", "        self.h_h = nn.Linear(n_hidden, n_hidden)\n", "        self.h_o = nn.Linear(n_hidden,vocab_sz)\n", "        self.h = 0\n", "\n", "    def forward(self, x):\n", "        for i in range(3):\n", "            self.h = self.h + self.i_h(x[:,i])\n", "            self.h = F.relu(self.h_h(self.h))\n", "        out = self.h_o(self.h)\n", "        # Keep the values but drop the autograd history.\n", "        self.h = self.h.detach()\n", "        return out\n", "\n", "    def reset(self): self.h = 0" ], "metadata": { "id": "kOu2IqjLHZZT" }, "execution_count": 19, "outputs": [] }, { "cell_type": "code", "source": [ "def group_chunks(ds, bs):\n", "    \"\"\"Reorder `ds` so that unshuffled batches of size `bs` line up across batches.\n", "\n", "    With m = len(ds) // bs, output batch i holds ds[i + m*j] for j in range(bs),\n", "    so position j of consecutive batches walks through consecutive samples.\n", "    Items beyond m*bs are dropped.\n", "    \"\"\"\n", "    m = len(ds) // bs\n", "    new_ds = L()\n", "    for i in range(m): new_ds += L(ds[i + m*j] for j in range(bs))\n", "    return new_ds" ], "metadata": { "id": "Sgztje_pPNai" }, "execution_count": 20, "outputs": [] }, { "cell_type": "code", "source": [ "cut = int(len(seqs) * 0.8)\n", "dls = DataLoaders.from_dsets(\n", "    group_chunks(seqs[:cut], bs),\n", "    group_chunks(seqs[cut:], bs),\n", "    bs=bs, drop_last=True, shuffle=False)" ], "metadata": { "id": "S_h_UrgDPaaD" }, "execution_count": 21, "outputs": [] }, { "cell_type": "code", "source": [ "learn = Learner(dls, LMModel3(len(vocab), 64), loss_func=F.cross_entropy,\n", "                metrics=accuracy, cbs=ModelResetter)\n", "learn.fit_one_cycle(10, 3e-3)" ], "metadata": { "id": "UyeIHrkMPawu", "colab": { "base_uri": "https://localhost:8080/", "height": 363 }, "outputId": "638e8cb2-f08e-4fd6-b5e6-829e32795bbd" }, "execution_count": 22, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", "\n" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
01.6770741.8273670.46754800:02
11.2827221.8709130.38894200:02
21.0907051.6517940.46250000:02
31.0052151.6159900.51514400:03
40.9630201.6058940.55120200:03
50.9261711.7217250.54375000:02
60.9082321.6689490.55552900:02
70.8439801.7257720.57091300:02
80.8118981.7404540.58726000:02
90.7971761.7059230.58942300:03
" ] }, "metadata": {} } ] }, { "cell_type": "code", "source": [ "# Create more signal: predict the next token at every position of a length-sl sequence.\n", "sl = 16\n", "seqs = L((tensor(nums[i:i+sl]), tensor(nums[i+1:i+sl+1]))\n", "         for i in range(0,len(nums)-sl-1,sl))\n", "cut = int(len(seqs) * 0.8)\n", "dls = DataLoaders.from_dsets(group_chunks(seqs[:cut], bs),\n", "                             group_chunks(seqs[cut:], bs),\n", "                             bs=bs, drop_last=True, shuffle=False)" ], "metadata": { "id": "pl8qSubyPtbn" }, "execution_count": 23, "outputs": [] }, { "cell_type": "code", "source": [ "[L(vocab[o] for o in s) for s in seqs[0]]" ], "metadata": { "id": "0OvzvDI5RTHa", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "e255fd55-7770-427b-ed6a-abca23e31f40" }, "execution_count": 24, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "[(#16) ['one','.','two','.','three','.','four','.','five','.'...],\n", " (#16) ['.','two','.','three','.','four','.','five','.','six'...]]" ] }, "metadata": {}, "execution_count": 24 } ] }, { "cell_type": "code", "source": [ "x = dls.one_batch()\n", "L(x[0][0,:], x[1][0,:])" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "HvrszlkHMIq_", "outputId": "e16670dd-3f8c-47a0-b2d0-d5f29fe12a3b" }, "execution_count": 25, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(#2) [tensor([0, 1, 2, 1, 3, 1, 4, 1, 5, 1, 6, 1, 7, 1, 8, 1]),tensor([1, 2, 1, 3, 1, 4, 1, 5, 1, 6, 1, 7, 1, 8, 1, 9])]" ] }, "metadata": {}, "execution_count": 25 } ] }, { "cell_type": "code", "source": [ "print(x[0][:,1].shape)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "nHr4bdQFQ7q0", "outputId": "e0b47926-e3c1-4f20-c30d-06f59e2d056b" }, "execution_count": 26, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "torch.Size([64])\n" ] } ] }, { "cell_type": "code", "source": [ "class LMModel4(Module):\n", "    \"\"\"Stateful RNN that emits one prediction per input position.\n", "\n", "    `forward` stacks the per-position outputs along dim 1 and detaches the stored\n", "    activation afterwards; `reset` sets the stored activation back to 0.\n", "    \"\"\"\n", "    def __init__(self, vocab_sz, n_hidden):\n", "        self.i_h = nn.Embedding(vocab_sz, n_hidden)\n", "        self.h_h = nn.Linear(n_hidden, n_hidden)\n",
"        self.h_o = nn.Linear(n_hidden,vocab_sz)\n", "        self.h = 0\n", "\n", "    def forward(self, x):\n", "        outs = []\n", "        # Iterate over the input's own width instead of the notebook-level `sl`,\n", "        # so the model does not depend on a global that may change later.\n", "        for i in range(x.shape[1]):\n", "            self.h = self.h + self.i_h(x[:,i])\n", "            self.h = F.relu(self.h_h(self.h))\n", "            outs.append(self.h_o(self.h))\n", "        self.h = self.h.detach()\n", "        return torch.stack(outs, dim=1)\n", "\n", "    def reset(self): self.h = 0" ], "metadata": { "id": "Gl6lcy0GRThl" }, "execution_count": 27, "outputs": [] }, { "cell_type": "code", "source": [ "model = LMModel4(len(vocab), 128)\n", "# Call the module rather than .forward() so any registered hooks also run.\n", "y = model(x[0])\n", "print(y.shape)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "dYzy7IY2QB8U", "outputId": "b7e3316f-348d-4690-efc3-021752bc25d9" }, "execution_count": 28, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "torch.Size([64, 16, 30])\n" ] } ] }, { "cell_type": "code", "source": [ "x=dls.one_batch()\n", "L(x[0][0], x[1][0])" ], "metadata": { "id": "GTKm9Bs4Rzo8", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "5a40c28e-a695-4d87-971e-19e754b332d5" }, "execution_count": 29, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(#2) [tensor([0, 1, 2, 1, 3, 1, 4, 1, 5, 1, 6, 1, 7, 1, 8, 1]),tensor([1, 2, 1, 3, 1, 4, 1, 5, 1, 6, 1, 7, 1, 8, 1, 9])]" ] }, "metadata": {}, "execution_count": 29 } ] }, { "cell_type": "code", "source": [ "def loss_func(inp, targ):\n", "    \"\"\"Cross-entropy over every position: flatten predictions and targets first.\"\"\"\n", "    # The class count is the size of the last prediction axis; no global vocab lookup needed.\n", "    return F.cross_entropy(inp.view(-1, inp.size(-1)), targ.view(-1))" ], "metadata": { "id": "ii9J6syCRiIT" }, "execution_count": 30, "outputs": [] }, { "cell_type": "code", "source": [ "learn = Learner(dls, LMModel4(len(vocab), 64), loss_func=loss_func,\n", "                metrics=accuracy, cbs=ModelResetter)\n", "learn.fit_one_cycle(15, 3e-3)" ], "metadata": { "id": "7j9iqiCURif_", "colab": { "base_uri": "https://localhost:8080/", "height": 520 }, "outputId": "5455abf1-eca5-4461-8409-0146a491901e" }, "execution_count": 31, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ],
"text/html": [ "\n", "\n" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
03.2953393.2406630.12687200:01
12.3802382.0839450.46077500:01
21.7680411.8786190.46346000:01
31.4786321.7635350.49414100:01
41.2955321.8625800.48535200:00
51.1642311.8336990.53662100:01
61.0673471.9583520.53165700:00
70.9715342.2582210.54296900:01
80.8939032.1487980.54508500:01
90.8254952.2679680.57983400:00
100.7710202.3170990.57609100:01
110.7234322.4238530.61149100:01
120.6979102.4040980.59228500:01
130.6740402.4152670.60628300:01
140.6592272.4070860.60546900:01
" ] }, "metadata": {} } ] }, { "cell_type": "code", "source": [ "# 多层循环神经网络\n", "class LMModel5(Module):\n", " def __init__(self, vocab_sz, n_hidden, n_layers):\n", " self.i_h = nn.Embedding(vocab_sz, n_hidden)\n", " self.rnn = nn.RNN(n_hidden, n_hidden, n_layers, batch_first=True)\n", " self.h_o = nn.Linear(n_hidden, vocab_sz)\n", " self.h = torch.zeros(n_layers, bs, n_hidden)\n", "\n", " def forward(self, x):\n", " res,h = self.rnn(self.i_h(x), self.h)\n", " self.h = h.detach()\n", " return self.h_o(res)\n", "\n", " def reset(self): self.h.zero_()" ], "metadata": { "id": "CjZLCssDRkMs" }, "execution_count": 32, "outputs": [] }, { "cell_type": "code", "source": [ "learn = Learner(dls, LMModel5(len(vocab), 64, 2),\n", " loss_func=CrossEntropyLossFlat(),\n", " metrics=accuracy, cbs=ModelResetter)\n", "learn.fit_one_cycle(15, 3e-3)" ], "metadata": { "id": "q9VxUjYa8BBp", "colab": { "base_uri": "https://localhost:8080/", "height": 520 }, "outputId": "ba3cc9d5-0fc8-44a9-dc5e-8d5b9cddaad3" }, "execution_count": 33, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", "\n" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
03.0664202.6916520.45385700:01
12.1790081.8164060.47216800:01
21.7223501.9032720.32210300:01
31.5089821.7815230.45198600:01
41.3489091.6473420.49633800:01
51.2236091.6482800.50740600:01
61.1017101.6183880.53922500:01
70.9826601.6928690.56429000:01
80.8819911.7078520.56998700:01
90.8006311.8084960.55330400:01
100.7421271.8794270.55127000:01
110.7015781.9086330.55216500:01
120.6735221.8906790.55354800:01
130.6570151.8880500.55615200:01
140.6482901.8935090.55712900:02
" ] }, "metadata": {} } ] }, { "cell_type": "code", "source": [ "# 长短期记忆\n", "# sigmoid: 0~1\n", "# tanh : -1~1\n", "\"\"\"\n", "感性的理解\n", "forget_gate: 遗忘门, 确定保留哪些信息和丢弃哪些信息:接近0的值被丢弃, 接近1的值被保留(sigmoid).\n", "input_gate : 输入门, 也是决定更新单元状态的哪些元素, 但它和第三个门单元门一起工作\n", "cell_gate : 单元门, 在输入门决定更新哪些元素之后, 它决定更新后的值是什么(tanh)\n", "output_gate: 输出门, 决定从单元状态中提取哪些信息来生成输出\n", "\"\"\"\n", "class LSTMCell(Module):\n", " def __init__(self, ni, nh):\n", " self.forget_gate = nn.Linear(ni+nh, nh)\n", " self.input_gate = nn.Linear(ni+nh, nh)\n", " self.cell_gate = nn.Linear(ni+nh, nh)\n", " self.output_gate = nn.Linear(ni+nh, nh)\n", "\n", " def forward(self, input, state):\n", " h, c = state\n", " h = torch.cat([h, input], dim=1)\n", " forget = torch.sigmoid(self.forget_gate(h))\n", " c = c*forget\n", "\n", " inp = torch.sigmoid(self.input_gate(h))\n", " cell = torch.tanh(self.cell_gate(h))\n", " c = c + inp * cell\n", "\n", " out = torch.sigmoid(self.output_gate(h))\n", " h = out * torch.tanh(c)\n", " return h, (h,c)" ], "metadata": { "id": "jm8CSOgx8CPE" }, "execution_count": 34, "outputs": [] }, { "cell_type": "code", "source": [ "# 重构只是为了加速计算\n", "# 感觉这个更好理解\n", "class LSTMCell(Module):\n", " def __init__(self, ni, nh):\n", " self.ih = nn.Linear(ni, 4*nh)\n", " self.hh = nn.Linear(nh, 4*nh)\n", "\n", " def forward(self, input, state):\n", " h, c = state\n", " gates = (self.ih(input) + self.hh(h).chunk(4, 1))\n", " ingate, forgetgate, outgate=map(torch.sigmoid, gates[:3])\n", " cellgate = gates[3].tanh()\n", "\n", " c=(forgetgate*c) + ingate*cellgate\n", " h=outgate*c.tanh()\n", " return h, (h,c)" ], "metadata": { "id": "PoLGkby6_8_0" }, "execution_count": 34, "outputs": [] }, { "cell_type": "code", "source": [ "# 双层LSTMCell\n", "class LMModel6(Module):\n", " def __init__(self, vocab_size, n_hidden, n_layers):\n", " self.i_h = nn.Embedding(vocab_size, n_hidden)\n", " self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)\n", " self.h_o = nn.Linear(n_hidden, vocab_size)\n", 
" self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(2)]\n", "\n", " def forward(self, x):\n", " res, h = self.rnn(self.i_h(x), self.h)\n", " self.h = [h_.detach() for h_ in h]\n", " return self.h_o(res)\n", "\n", " def reset(self):\n", " for h in self.h: h.zero_()" ], "metadata": { "id": "akiEijYyRHzt" }, "execution_count": 35, "outputs": [] }, { "cell_type": "code", "source": [ "learn = Learner(dls, LMModel6(len(vocab), 64, 2),\n", " loss_func=CrossEntropyLossFlat(),\n", " metrics=accuracy, cbs=ModelResetter)\n", "learn.fit_one_cycle(15, 1e-2)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 520 }, "id": "it6Ioo93RLm2", "outputId": "4639badc-f060-4636-8349-d4afb27a9934" }, "execution_count": 36, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", "\n" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
03.0148462.6873590.34448200:02
12.1809512.3711350.27457700:02
21.6239121.9225790.46598300:02
31.3751931.8612920.50024400:01
41.1969902.1654520.52205400:01
51.0093021.9798210.54394500:01
60.7704451.8529260.63053400:01
70.4994691.7589140.68652300:02
80.3139691.6753740.68847700:02
90.2135591.6116570.74617500:02
100.1390521.5744600.73966500:01
110.0944631.5024960.76424200:01
120.0686291.4799710.76334600:01
130.0552851.4742560.77042600:01
140.0491521.4779550.76725300:01
" ] }, "metadata": {} } ] }, { "cell_type": "code", "source": [ "class Dropout(Module):\n", "    \"\"\"Hand-written dropout: in training, zero elements with probability p and rescale the rest.\"\"\"\n", "    def __init__(self, p):\n", "        self.p = p\n", "    def forward(self, x):\n", "        # Outside training the input passes through unchanged.\n", "        if not self.training: return x\n", "        # Bernoulli mask with keep-probability 1-p, divided by 1-p so kept values are scaled up.\n", "        mask = x.new(*x.shape).bernoulli_(1-self.p)\n", "        return x * mask.div_(1-self.p)" ], "metadata": { "id": "v7Cr1jzCRNcq" }, "execution_count": 37, "outputs": [] }, { "cell_type": "code", "source": [ "class LMModel7(Module):\n", "    \"\"\"LMModel6 plus dropout on the LSTM output and an output layer that shares the embedding weights.\n", "\n", "    `forward` returns three values: the vocabulary scores, the raw LSTM output and\n", "    the dropped-out LSTM output.\n", "    \"\"\"\n", "    def __init__(self, vocab_sz, n_hidden, n_layers, p):\n", "        self.i_h = nn.Embedding(vocab_sz, n_hidden)\n", "        self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)\n", "        self.drop = nn.Dropout(p)\n", "        self.h_o = nn.Linear(n_hidden, vocab_sz)\n", "        # Weight tying: the output projection reuses the embedding matrix.\n", "        self.h_o.weight = self.i_h.weight\n", "        self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(2)]\n", "\n", "    def forward(self, x):\n", "        raw,h = self.rnn(self.i_h(x), self.h)\n", "        out = self.drop(raw)\n", "        self.h = [h_.detach() for h_ in h]\n", "        return self.h_o(out),raw,out\n", "\n", "    def reset(self):\n", "        for h in self.h: h.zero_()" ], "metadata": { "id": "uSbku_oTROty" }, "execution_count": 38, "outputs": [] }, { "cell_type": "code", "source": [ "learn = TextLearner(dls, LMModel7(len(vocab), 64, 2, 0.4),\n", "                    loss_func=CrossEntropyLossFlat(), metrics=accuracy)" ], "metadata": { "id": "4EjKabHCRP8_" }, "execution_count": 39, "outputs": [] }, { "cell_type": "code", "source": [ "learn.fit_one_cycle(15, 1e-2, wd=0.1)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 520 }, "id": "STsi_rmDRRY8", "outputId": "5e4c15f6-8e76-4b93-be31-b049b4b5ad20" }, "execution_count": 40, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", "\n" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n",
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
02.5782171.9065360.49373400:02
11.6168991.4144320.60091100:02
20.8868640.7622590.78922500:01
30.4462650.6993370.83365900:01
40.2310480.8198010.84179700:01
50.1285030.6878210.85913100:01
60.0770750.6595610.86043300:01
70.0514840.6981400.86539700:02
80.0378670.6322030.87280300:02
90.0296220.5390220.87890600:01
100.0238730.6630200.86499000:01
110.0205010.6191150.86995400:02
120.0167580.6281350.86572300:02
130.0143480.6218900.87011700:02
140.0132880.6325720.86694300:02
" ] }, "metadata": {} } ] }, { "cell_type": "code", "source": [], "metadata": { "id": "hx8itXSiRShU" }, "execution_count": null, "outputs": [] } ] }