{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "code", "execution_count": 9, "metadata": { "id": "t-JBCxXOoisb" }, "outputs": [], "source": [ "# 初始化\n", "! [ -e /content ] && pip install -Uqq fastbook\n", "import fastbook\n", "fastbook.setup_book()\n", "from fastbook import *" ] }, { "cell_type": "code", "source": [ "# 加载数据\n", "from fastai.collab import *\n", "from fastai.tabular.all import *\n", "path = untar_data(URLs.ML_100k)\n", "ratings = pd.read_csv(path/'u.data', delimiter='\\t', header=None,\n", " names=['user','movie','rating','timestamp'])\n", "ratings.head()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 206 }, "id": "BGMxL0-itWOb", "outputId": "53a483cf-eed5-412e-da2f-5e5b7a5ab8c0" }, "execution_count": 10, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " user movie rating timestamp\n", "0 196 242 3 881250949\n", "1 186 302 3 891717742\n", "2 22 377 1 878887116\n", "3 244 51 2 880606923\n", "4 166 346 1 886397596" ], "text/html": [ "\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
usermovieratingtimestamp
01962423881250949
11863023891717742
2223771878887116
3244512880606923
41663461886397596
\n", "
\n", "
\n", "\n", "
\n", " \n", "\n", " \n", "\n", " \n", "
\n", "\n", "\n", "
\n", " \n", "\n", "\n", "\n", " \n", "
\n", "\n", "
\n", "
\n" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "dataframe", "variable_name": "ratings", "summary": "{\n \"name\": \"ratings\",\n \"rows\": 100000,\n \"fields\": [\n {\n \"column\": \"user\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 266,\n \"min\": 1,\n \"max\": 943,\n \"num_unique_values\": 943,\n \"samples\": [\n 262,\n 136,\n 821\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"movie\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 330,\n \"min\": 1,\n \"max\": 1682,\n \"num_unique_values\": 1682,\n \"samples\": [\n 1557,\n 808,\n 1618\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"rating\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 1,\n \"max\": 5,\n \"num_unique_values\": 5,\n \"samples\": [\n 1,\n 5,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"timestamp\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 5343856,\n \"min\": 874724710,\n \"max\": 893286638,\n \"num_unique_values\": 49282,\n \"samples\": [\n 889728713,\n 888443306,\n 880605158\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" } }, "metadata": {}, "execution_count": 10 } ] }, { "cell_type": "code", "source": [ "movies = pd.read_csv(path/'u.item', delimiter='|', encoding='latin-1',\n", " usecols=(0,1), names=('movie','title'), header=None)\n", "movies.head()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 206 }, "id": "OTZRTuBctrIt", "outputId": "d5dc063a-cc6a-4d4a-f1e2-aab3fa41448b" }, "execution_count": 11, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " movie title\n", "0 1 Toy Story (1995)\n", "1 2 GoldenEye (1995)\n", "2 3 Four Rooms (1995)\n", "3 4 Get Shorty (1995)\n", "4 5 Copycat (1995)" ], "text/html": [ "\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
movietitle
01Toy Story (1995)
12GoldenEye (1995)
23Four Rooms (1995)
34Get Shorty (1995)
45Copycat (1995)
\n", "
\n", "
\n", "\n", "
\n", " \n", "\n", " \n", "\n", " \n", "
\n", "\n", "\n", "
\n", " \n", "\n", "\n", "\n", " \n", "
\n", "\n", "
\n", "
\n" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "dataframe", "variable_name": "movies", "summary": "{\n \"name\": \"movies\",\n \"rows\": 1682,\n \"fields\": [\n {\n \"column\": \"movie\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 485,\n \"min\": 1,\n \"max\": 1682,\n \"num_unique_values\": 1682,\n \"samples\": [\n 1394,\n 744,\n 1606\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"title\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 1664,\n \"samples\": [\n \"Madame Butterfly (1995)\",\n \"Wrong Trousers, The (1993)\",\n \"Breaking the Waves (1996)\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" } }, "metadata": {}, "execution_count": 11 } ] }, { "cell_type": "code", "source": [ "ratings = ratings.merge(movies)\n", "ratings.head()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 206 }, "id": "FZNs_hA8wk0g", "outputId": "4aefc8a3-8173-41c7-d69a-30a240c44892" }, "execution_count": 12, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " user movie rating timestamp title\n", "0 196 242 3 881250949 Kolya (1996)\n", "1 63 242 3 875747190 Kolya (1996)\n", "2 226 242 5 883888671 Kolya (1996)\n", "3 154 242 3 879138235 Kolya (1996)\n", "4 306 242 5 876503793 Kolya (1996)" ], "text/html": [ "\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
usermovieratingtimestamptitle
01962423881250949Kolya (1996)
1632423875747190Kolya (1996)
22262425883888671Kolya (1996)
31542423879138235Kolya (1996)
43062425876503793Kolya (1996)
\n", "
\n", "
\n", "\n", "
\n", " \n", "\n", " \n", "\n", " \n", "
\n", "\n", "\n", "
\n", " \n", "\n", "\n", "\n", " \n", "
\n", "\n", "
\n", "
\n" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "dataframe", "variable_name": "ratings", "summary": "{\n \"name\": \"ratings\",\n \"rows\": 100000,\n \"fields\": [\n {\n \"column\": \"user\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 266,\n \"min\": 1,\n \"max\": 943,\n \"num_unique_values\": 943,\n \"samples\": [\n 574,\n 696,\n 434\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"movie\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 330,\n \"min\": 1,\n \"max\": 1682,\n \"num_unique_values\": 1682,\n \"samples\": [\n 1557,\n 808,\n 1618\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"rating\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 1,\n \"max\": 5,\n \"num_unique_values\": 5,\n \"samples\": [\n 5,\n 1,\n 4\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"timestamp\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 5343856,\n \"min\": 874724710,\n \"max\": 893286638,\n \"num_unique_values\": 49282,\n \"samples\": [\n 888540314,\n 887746686,\n 880888037\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"title\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 1664,\n \"samples\": [\n \"House Party 3 (1994)\",\n \"Three Colors: White (1994)\",\n \"Fish Called Wanda, A (1988)\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" } }, "metadata": {}, "execution_count": 12 } ] }, { "cell_type": "code", "source": [ "dls = CollabDataLoaders.from_df(ratings, item_name='title', bs=64)\n", "dls.show_batch()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 363 }, "id": "VBOZcf7kwqz7", "outputId": "7114f8d4-f47b-4812-ea71-29258754d2bc" }, "execution_count": 13, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
usertitlerating
0374Jungle Book, The (1994)5
1503Star Wars (1977)5
2343Emma (1996)4
3263Terminator, The (1984)5
44092001: A Space Odyssey (1968)5
5711That Thing You Do! (1996)4
6471Matilda (1996)5
7389North by Northwest (1959)5
8313Devil's Own, The (1997)3
9821Cry, the Beloved Country (1995)5
" ] }, "metadata": {} } ] }, { "cell_type": "code", "source": [ "n_users = len(dls.classes['user'])\n", "n_movies = len(dls.classes['title'])\n", "n_factors = 5\n", "\n", "user_factors = torch.randn(n_users, n_factors)\n", "movie_factors = torch.randn(n_movies, n_factors)" ], "metadata": { "id": "mbsfoWglwxOj" }, "execution_count": 14, "outputs": [] }, { "cell_type": "code", "source": [ "# 介绍用一位有效编码向量替换我们的索引\n", "one_hot_3 = one_hot(3, n_users).float()" ], "metadata": { "id": "Jdwg6fnZw-MS" }, "execution_count": 15, "outputs": [] }, { "cell_type": "code", "source": [ "user_factors.t() @ one_hot_3" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "XHxs4M5FyBtW", "outputId": "64e74a0a-6026-45c1-ace9-e35567458d3d" }, "execution_count": 16, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor([ 0.5017, 0.5266, -1.4344, -1.1046, -0.0247])" ] }, "metadata": {}, "execution_count": 16 } ] }, { "cell_type": "code", "source": [ "user_factors[3]" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "cVXLiJozyEII", "outputId": "6747db07-eb63-4fda-ccec-953397230e02" }, "execution_count": 17, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor([ 0.5017, 0.5266, -1.4344, -1.1046, -0.0247])" ] }, "metadata": {}, "execution_count": 17 } ] }, { "cell_type": "code", "source": [ "# 构建协同过滤的模型\n", "class DotProductBias(Module):\n", " def __init__(self, n_users, n_movies, n_factors, y_range=(0,5.5)):\n", " self.user_factors = Embedding(n_users, n_factors)\n", " self.user_bias = Embedding(n_users, 1)\n", " self.movie_factors = Embedding(n_movies, n_factors)\n", " self.movie_bias = Embedding(n_movies, 1)\n", " self.y_range = y_range\n", "\n", " def forward(self, x):\n", " users = self.user_factors(x[:,0])\n", " movies = self.movie_factors(x[:,1])\n", " res = (users * movies).sum(dim=1, keepdim=True)\n", " res += self.user_bias(x[:,0]) + self.movie_bias(x[:,1])\n", " return sigmoid_range(res, *self.y_range)" ], "metadata": { "id": "32zQqkir005A" }, "execution_count": 18, "outputs": [] }, { "cell_type": "code", "source": [ "# 一捆数据的形状\n", "x,y = dls.one_batch()\n", "x.shape" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3b7HbNPDy3Hr", "outputId": "ce15cfc1-1b4c-46cd-fae8-78ae1399fa37" }, "execution_count": 19, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "torch.Size([64, 2])" ] }, "metadata": {}, "execution_count": 19 } ] }, { "cell_type": "code", "source": [ "# 构建学习器\n", "model = DotProductBias(n_users, n_movies, 50)\n", "learn = Learner(dls, model, loss_func=MSELossFlat())" ], "metadata": { "id": "BCl2oOmD0pov" }, "execution_count": 20, "outputs": [] }, { "cell_type": "code", "source": [ "# 训练\n", "learn.fit_one_cycle(5, 5e-3)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 206 }, "id": "4hxh7gMmFFNH", "outputId": "968d148b-d94e-4c60-df75-2137e991ba57" }, "execution_count": 21, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", "\n" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
00.9647720.93565000:16
10.8352710.86303800:13
20.6020940.86333100:13
30.4067350.88915600:13
40.2828560.89620000:12
" ] }, "metadata": {} } ] }, { "cell_type": "code", "source": [ "# 引入了权值衰减\n", "learn = collab_learner(dls, n_factors=50, y_range=(0, 5.5))" ], "metadata": { "id": "CjresuQRFUSi" }, "execution_count": 22, "outputs": [] }, { "cell_type": "code", "source": [ "# 重新训练\n", "learn.fit_one_cycle(5, 5e-3, wd=0.1)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 206 }, "id": "HK5kFEKc_0Ri", "outputId": "7ab9e7a5-fcc2-46d3-f535-be8b20d9722a" }, "execution_count": 23, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", "\n" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
00.9507990.92872500:13
10.8515750.87413700:13
20.7202900.83411200:13
30.6073910.81743000:12
40.5055060.81831700:12
" ] }, "metadata": {} } ] }, { "cell_type": "code", "source": [ "# 模型的结构\n", "learn.model" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Ic4DsArD_5A3", "outputId": "1898562e-941b-440c-ac1c-deddc9ee1402" }, "execution_count": 24, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "EmbeddingDotBias(\n", " (u_weight): Embedding(944, 50)\n", " (i_weight): Embedding(1665, 50)\n", " (u_bias): Embedding(944, 1)\n", " (i_bias): Embedding(1665, 1)\n", ")" ] }, "metadata": {}, "execution_count": 24 } ] }, { "cell_type": "code", "source": [ "# 深度学习中的协同过滤\n", "class CollabNN(Module):\n", " def __init__(self, user_sz, item_sz, y_range=(0,5.5), n_act=100):\n", " self.user_factors = Embedding(*user_sz)\n", " self.item_factors = Embedding(*item_sz)\n", " self.layers = nn.Sequential(\n", " nn.Linear(user_sz[1]+item_sz[1], n_act),\n", " nn.ReLU(),\n", " nn.Linear(n_act, 1))\n", " self.y_range = y_range\n", "\n", " def forward(self, x):\n", " embs = self.user_factors(x[:,0]),self.item_factors(x[:,1])\n", " x = self.layers(torch.cat(embs, dim=1))\n", " return sigmoid_range(x, *self.y_range)" ], "metadata": { "id": "IV6cQyMZeQjh" }, "execution_count": 31, "outputs": [] }, { "cell_type": "code", "source": [ "embs = get_emb_sz(dls)\n", "model = CollabNN(*embs)" ], "metadata": { "id": "awL6OobOeTFW" }, "execution_count": 34, "outputs": [] }, { "cell_type": "code", "source": [ "learn = Learner(dls, model, loss_func=MSELossFlat())\n", "learn.fit_one_cycle(5, 5e-3, wd=0.01)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 206 }, "id": "bMPQPYbRcP4p", "outputId": "19265b83-afad-4721-e5d0-afa6015468ef" }, "execution_count": 35, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", "\n" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
00.9561420.95608600:17
10.8807690.90527700:17
20.8811230.87229000:15
30.8269010.86387200:15
40.7696030.86788400:15
" ] }, "metadata": {} } ] }, { "cell_type": "code", "source": [ "learn = collab_learner(dls, use_nn=True, y_range=(0, 5.5), layers=[100,50])\n", "learn.fit_one_cycle(5, 5e-3, wd=0.1)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 206 }, "id": "Aj1WDLZ0e4gE", "outputId": "07bb41f7-9789-4fc6-fea5-7a4cbcd189c3" }, "execution_count": 36, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", "\n" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
00.9849440.98233000:18
10.9571740.91339200:16
20.8834210.88146300:18
30.8297910.85310500:18
40.7693900.85780300:17
" ] }, "metadata": {} } ] } ] }