{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "L4" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "GWRrNbFPKQYn", "outputId": "6b634b88-6309-4688-b83b-d4f3a7a1a9c4" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m719.8/719.8 kB\u001b[0m \u001b[31m7.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m542.0/542.0 kB\u001b[0m \u001b[31m11.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m10.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m9.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m10.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m401.2/401.2 kB\u001b[0m \u001b[31m14.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m19.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hMounted at /content/gdrive\n" ] } ], "source": [ "! 
[ -e /content ] && pip install -Uqq fastbook\n", "import fastbook\n", "fastbook.setup_book()\n", "from fastbook import *" ] }, { "cell_type": "code", "source": [ "from fastai.vision.all import *\n", "path = untar_data(URLs.PETS)\n", "files = get_image_files(path/\"images\")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 37 }, "id": "f6-oTCT1K6G2", "outputId": "0543f378-053c-4a1e-a6e7-b65671e889b6" }, "execution_count": 2, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", "\n" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", "
\n", " \n", " 100.00% [811712512/811706944 00:17<00:00]\n", "
\n", " " ] }, "metadata": {} } ] },
{ "cell_type": "code", "source": [
"class SiameseImage(fastuple):\n",
"    \"A pair of images plus a flag telling whether they show the same breed.\"\n",
"    def show(self, ctx=None, **kwargs):\n",
"        img1,img2,same_breed = self\n",
"        if not isinstance(img1, Tensor):\n",
"            # PIL images: resize the second to match the first, then convert HWC -> CHW tensors\n",
"            if img2.size != img1.size: img2 = img2.resize(img1.size)\n",
"            t1,t2 = tensor(img1),tensor(img2)\n",
"            t1,t2 = t1.permute(2,0,1),t2.permute(2,0,1)\n",
"        else: t1,t2 = img1,img2\n",
"        # zero-filled, 10-pixel-wide separator placed between the two images\n",
"        line = t1.new_zeros(t1.shape[0], t1.shape[1], 10)\n",
"        return show_image(torch.cat([t1,line,t2], dim=2),\n",
"                          title=same_breed, ctx=ctx)\n",
"\n",
"def label_func(fname):\n",
"    \"Breed name from a Pets filename, e.g. `great_pyrenees_173.jpg` -> `great_pyrenees`.\"\n",
"    return re.match(r'^(.*)_\\d+\\.jpg$', fname.name).groups()[0]\n",
"\n",
"class SiameseTransform(Transform):\n",
"    \"Map an image file to a `SiameseImage` pairing it with a same- or different-breed image.\"\n",
"    def __init__(self, files, label_func, splits):\n",
"        self.labels = files.map(label_func).unique()\n",
"        self.lbl2files = {l: L(f for f in files if label_func(f) == l) for l in self.labels}\n",
"        self.label_func = label_func\n",
"        # validation pairs are drawn once and reused, so the validation metric is stable across epochs\n",
"        self.valid = {f: self._draw(f) for f in files[splits[1]]}\n",
"\n",
"    def encodes(self, f):\n",
"        # Only draw a new partner for training files. `self.valid.get(f, self._draw(f))` would\n",
"        # evaluate `_draw` eagerly and consume random numbers even for validation files.\n",
"        f2,t = self.valid[f] if f in self.valid else self._draw(f)\n",
"        img1,img2 = PILImage.create(f),PILImage.create(f2)\n",
"        return SiameseImage(img1, img2, t)\n",
"\n",
"    def _draw(self, f):\n",
"        \"Return `(partner_file, same)`, where `same` is True with probability 0.5.\"\n",
"        same = random.random() < 0.5\n",
"        cls = self.label_func(f)\n",
"        if not same: cls = random.choice(L(l for l in self.labels if l != cls))\n",
"        return random.choice(self.lbl2files[cls]),same\n",
"\n",
"set_seed(42)  # seed python/numpy/torch RNGs so the split and validation pairs are reproducible\n",
"splits = RandomSplitter()(files)\n",
"tfm = SiameseTransform(files, label_func, splits)\n",
"tls = TfmdLists(files, tfm, splits=splits)\n",
"dls = tls.dataloaders(after_item=[Resize(224), ToTensor],\n",
"                      after_batch=[IntToFloatTensor, Normalize.from_stats(*imagenet_stats)])"
], "metadata": { "id": "mJ8AaB7RKg1t" }, "execution_count": 3, "outputs": [] },
{ "cell_type": "code", "source": [
"class SiameseModel(Module):\n",
"    \"Apply one shared `encoder` to both images and classify the concatenated features with `head`.\"\n",
"    def __init__(self, encoder, head):\n",
"        self.encoder,self.head = encoder,head\n",
"\n",
"    def forward(self, x1, x2):\n",
"        ftrs = torch.cat([self.encoder(x1), self.encoder(x2)], dim=1)\n",
"        return self.head(ftrs)"
], "metadata": { "id": "M3JYas6-LkNg" }, "execution_count": 4, "outputs": [] },
{ "cell_type": "code", "source": [ "encoder = create_body(resnet34(pretrained=True), cut=-2)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "-Qz87zF6LwyU", "outputId": "30799eb8-1b39-451a-e6d8-c92c46af1deb" }, "execution_count": 5, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n", " warnings.warn(\n", "/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet34_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet34_Weights.DEFAULT` to get the most up-to-date weights.\n", " warnings.warn(msg)\n", "Downloading: \"https://download.pytorch.org/models/resnet34-b627a593.pth\" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth\n", "100%|██████████| 83.3M/83.3M [00:00<00:00, 173MB/s]\n" ] } ] },
{ "cell_type": "code", "source": [ "head = create_head(512*2, 2, ps=0.5)" ], "metadata": { "id": "2sHLhnpVL5kM" }, "execution_count": 6, "outputs": [] },
{ "cell_type": "code", "source": [ "model = SiameseModel(encoder, head)" ], "metadata": { "id": "H6CSklJ0NJXg" }, "execution_count": 7, "outputs": [] },
{ "cell_type": "code", "source": [
"def loss_func(out, targ):\n",
"    # targets are the boolean `same` flags, so cast them to long class indices\n",
"    return nn.CrossEntropyLoss()(out, targ.long())"
], "metadata": { "id": "ApQSbEFSNT_u" }, "execution_count": 8, "outputs": [] },
{ "cell_type": "code", "source": [
"def siamese_splitter(model):\n",
"    # two parameter groups (encoder, head) so `learn.freeze()` can train just the head first\n",
"    return [params(model.encoder), params(model.head)]"
], "metadata": { "id": "egXC6h7dNY3W" }, "execution_count": 9,
"outputs": [] }, { "cell_type": "code", "source": [ "learn = Learner(dls, model, loss_func=loss_func,\n", " splitter=siamese_splitter, metrics=accuracy)\n", "learn.freeze()" ], "metadata": { "id": "MZ1CztGSNjSV" }, "execution_count": 10, "outputs": [] }, { "cell_type": "code", "source": [ "learn.fit_one_cycle(4, 3e-3)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 230 }, "id": "YhzikHhbNqoj", "outputId": "1efa8c64-1eef-4f4e-aeb2-04c1d47d072b" }, "execution_count": 11, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", "\n" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
00.5451870.3579240.84912000:25
10.3678880.2483030.90054100:24
20.2939210.2007540.91677900:24
30.2407700.1883320.92489900:24
" ] }, "metadata": {} }, { "output_type": "stream", "name": "stderr", "text": [ "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", " self.pid = os.fork()\n" ] } ] }, { "cell_type": "code", "source": [ "learn.unfreeze()\n", "learn.fit_one_cycle(4, slice(1e-6,1e-4))" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 175 }, "id": "oindz2cyNs1k", "outputId": "9ae7f627-1151-42d4-b39b-5e427ffb0145" }, "execution_count": 12, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", "\n" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
00.2281370.1766930.92760500:32
10.2239520.1700710.93234100:32
20.2183940.1694220.93437100:33
30.2269750.1685640.93504700:33
" ] }, "metadata": {} } ] }, { "cell_type": "code", "source": [ "# 表格\n", "# emmm...这章没什么好看的了" ], "metadata": { "id": "UD_b48SPQNjK" }, "execution_count": 13, "outputs": [] } ] }