{ "cells": [ { "cell_type": "markdown", "id": "eb9ee8d4-afdd-400a-bae0-03ca972a2559", "metadata": { "slideshow": { "slide_type": "slide" }, "tags": [] }, "source": [ "# Natural Language Processing Tutorial 1 - Static Word Embeddings with Word2Vec\n", "" ] }, { "cell_type": "markdown", "id": "71725826-1cd3-4a89-9471-509ccb926af9", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "## What you'll learn in this tutorial\n", "\n", "- We'll use the [gensim](https://radimrehurek.com/gensim/index.html) library to:\n", " - explore pretrained word embeddings\n", " - pretrain our own embeddings\n", "- We will additionally:\n", " - visualize word embeddings\n", " - evaluate them intrinsically and extrinsically" ] }, { "cell_type": "markdown", "id": "ad4b05ea-4690-4368-a195-a82cbe8382d6", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "## Our schedule for today\n", "\n", "- Part 1: Using pretrained word embeddings with gensim\n", " - How to download pretrained embeddings\n", " - Nearest neighbour similarity search \n", " - Word embedding visualization via PCA\n", " - Intrinsic evaluation with word analogy and word similarity benchmarks\n", " - **Task 1**\n", "- Part 2: Pretraining your **own** embeddings\n", " - Training choices\n", " - Saving and loading your embeddings\n", "- Part 3: Extrinsic evaluation of word embeddings\n", " - Using word2vec embeddings for spam classification\n", " - **Task 2**" ] }, { "cell_type": "markdown", "id": "53576457-01be-4ec5-9404-9fdff84906cb", "metadata": { "slideshow": { "slide_type": "slide" }, "tags": [] }, "source": [ "## Part 1: Using pretrained embeddings with gensim\n", "\n", "" ] }, { "cell_type": "markdown", "id": "263da382-d2e4-41c3-b002-ee540da2fbab", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "### What is gensim?\n", "\n", "- Gensim is one of many core NLP libraries:\n", " - together with 
[NLTK](https://www.nltk.org), [spaCy](https://spacy.io) and [HuggingFace 🤗](https://huggingface.co)\n", " - you can find its documentation [here](https://radimrehurek.com/gensim/auto_examples/index.html#other-resources)\n", "- It can be used to deal with corpora and perform:\n", " - Retrieval\n", " - Topic Modelling\n", " - Representation Learning (**word2vec** and **doc2vec**)" ] }, { "cell_type": "code", "execution_count": 1, "id": "1ebdec4e", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "outputs": [], "source": [ "# Run this cell now!\n", "import gensim\n", "\n", "import numpy as np\n", "import pandas as pd\n", "\n", "import gensim.downloader as api\n", "from gensim import utils\n", "from gensim.models import KeyedVectors\n", "from gensim.test.utils import datapath\n", "from gensim.models import Word2Vec\n", "\n", "from sklearn.metrics.pairwise import cosine_similarity\n", "from sklearn.decomposition import PCA\n", "from sklearn.metrics import classification_report\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.linear_model import LogisticRegressionCV\n", "from sklearn.preprocessing import LabelEncoder\n", "\n", "from scipy.stats import pearsonr, spearmanr\n", "\n", "import nltk\n", "from nltk.corpus import stopwords\n", "\n", "import torch\n", "import torch.nn as nn\n", "\n", "import plotly.express as px" ] }, { "cell_type": "markdown", "id": "ee260e2b-1403-4be2-af68-86f63869a5b1", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "### Let's download some embeddings!" 
] }, { "cell_type": "code", "execution_count": 2, "id": "a2ce6ee4-c4c5-4782-b404-c15f3ba0b1ad", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Run this cell now!\n", "word_emb = api.load('word2vec-google-news-300')" ] }, { "cell_type": "markdown", "id": "941bbbf7-0af0-43b9-994f-43186b2c0bf5", "metadata": {}, "source": [ "- The object that we get is of type [KeyedVectors](https://radimrehurek.com/gensim/models/keyedvectors.html)\n", "- This is simply a map $w \\rightarrow \\mathbf{e}_w \\in \\mathbb{R}^{300}$\n", "- You can explore all the available models [here](https://github.com/RaRe-Technologies/gensim-data#models), or simply run ```api.info()```" ] }, { "cell_type": "markdown", "id": "bf5fe072-b0ee-4e3c-81ac-9220e3d17352", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "### What do these embeddings look like?" ] }, { "cell_type": "code", "execution_count": 3, "id": "a7261e0d-a31a-4011-92e9-1dfeeb39c034", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(300,)\n", "[-0.06445312 -0.16015625 -0.01208496 0.13476562 -0.22949219 0.16210938\n", " 0.3046875 -0.1796875 -0.12109375 0.25390625 -0.01428223 -0.06396484\n", " -0.08056641 -0.05688477 -0.19628906 0.2890625 -0.05151367 0.14257812\n", " -0.10498047 -0.04736328 -0.34765625 0.35742188 0.265625 0.00188446\n", " -0.01586914 0.00195312 -0.35546875 0.22167969 0.05761719 0.15917969\n", " 0.08691406 -0.0267334 -0.04785156 0.23925781 -0.05981445 0.0378418\n", " 0.17382812 -0.41796875 0.2890625 0.32617188 0.02429199 -0.01647949\n", " -0.06494141 -0.08886719 0.07666016 -0.15136719 0.05249023 -0.04199219\n", " -0.05419922 0.00108337 -0.20117188 0.12304688 0.09228516 0.10449219\n", " -0.00408936 -0.04199219 0.01409912 -0.02111816 -0.13476562 -0.24316406\n", " 0.16015625 -0.06689453 -0.08984375 -0.07177734 -0.00595093 -0.00482178\n", " -0.00089264 -0.30664062 -0.0625 0.07958984 -0.00909424 -0.04492188\n", " 0.09960938 -0.33398438 
-0.3984375 0.05541992 -0.06689453 -0.04467773\n", " 0.11767578 -0.13964844 -0.26367188 0.17480469 -0.17382812 -0.40625\n", " -0.06738281 -0.07617188 0.09423828 0.20996094 -0.16308594 -0.08691406\n", " -0.0534668 -0.10351562 -0.07617188 -0.11083984 -0.03515625 -0.14941406\n", " 0.0378418 0.38671875 0.14160156 -0.2890625 -0.16894531 -0.140625\n", " -0.04174805 0.22753906 0.24023438 -0.01599121 -0.06787109 0.21875\n", " -0.42382812 -0.5625 -0.49414062 -0.3359375 0.13378906 0.01141357\n", " 0.13671875 0.0324707 0.06835938 -0.27539062 -0.15917969 0.00121307\n", " 0.01208496 -0.0039978 0.00442505 -0.04541016 0.08642578 0.09960938\n", " -0.04296875 -0.11328125 0.13867188 0.41796875 -0.28320312 -0.07373047\n", " -0.11425781 0.08691406 -0.02148438 0.328125 -0.07373047 -0.01348877\n", " 0.17773438 -0.02624512 0.13378906 -0.11132812 -0.12792969 -0.12792969\n", " 0.18945312 -0.13867188 0.29882812 -0.07714844 -0.37695312 -0.10351562\n", " 0.16992188 -0.10742188 -0.29882812 0.00866699 -0.27734375 -0.20996094\n", " -0.1796875 -0.19628906 -0.22167969 0.08886719 -0.27734375 -0.13964844\n", " 0.15917969 0.03637695 0.03320312 -0.08105469 0.25390625 -0.08691406\n", " -0.21289062 -0.18945312 -0.22363281 0.06542969 -0.16601562 0.08837891\n", " -0.359375 -0.09863281 0.35546875 -0.00741577 0.19042969 0.16992188\n", " -0.06005859 -0.20605469 0.08105469 0.12988281 -0.01135254 0.33203125\n", " -0.08691406 0.27539062 -0.03271484 0.12011719 -0.0625 0.1953125\n", " -0.10986328 -0.11767578 0.20996094 0.19921875 0.02954102 -0.16015625\n", " 0.00276184 -0.01367188 0.03442383 -0.19335938 0.00352478 -0.06542969\n", " -0.05566406 0.09423828 0.29296875 0.04052734 -0.09326172 -0.10107422\n", " -0.27539062 0.04394531 -0.07275391 0.13867188 0.02380371 0.13085938\n", " 0.00236511 -0.2265625 0.34765625 0.13574219 0.05224609 0.18164062\n", " 0.0402832 0.23730469 -0.16992188 0.10058594 0.03833008 0.10839844\n", " -0.05615234 -0.00946045 0.14550781 -0.30078125 -0.32226562 0.18847656\n", " -0.40234375 -0.3125 
-0.08007812 -0.26757812 0.16699219 0.07324219\n", " 0.06347656 0.06591797 0.17285156 -0.17773438 0.00276184 -0.05761719\n", " -0.2265625 -0.19628906 0.09667969 0.13769531 -0.49414062 -0.27929688\n", " 0.12304688 -0.30078125 0.01293945 -0.1875 -0.20898438 -0.1796875\n", " -0.16015625 -0.03295898 0.00976562 0.25390625 -0.25195312 0.00210571\n", " 0.04296875 0.01184082 -0.20605469 0.24804688 -0.203125 -0.17773438\n", " 0.07275391 0.04541016 0.21679688 -0.2109375 0.14550781 -0.16210938\n", " 0.20410156 -0.19628906 -0.35742188 0.35742188 -0.11962891 0.35742188\n", " 0.10351562 0.07080078 -0.24707031 -0.10449219 -0.19238281 0.1484375\n", " 0.00057983 0.296875 -0.12695312 -0.03979492 0.13183594 -0.16601562\n", " 0.125 0.05126953 -0.14941406 0.13671875 -0.02075195 0.34375 ]\n" ] } ], "source": [ "# Access embeddings with word-lookup\n", "print(word_emb[\"apple\"].shape)\n", "print(word_emb[\"apple\"])" ] }, { "cell_type": "code", "execution_count": 4, "id": "52a63f19-ac08-40a4-99c5-86b76bab8393", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[ 2.60009766e-02 -1.89208984e-03 1.85546875e-01 -5.17578125e-02\n", " 5.12695312e-03 -1.09863281e-01 -8.17871094e-03 -8.83789062e-02\n", " 9.66796875e-02 4.83398438e-02 1.10473633e-02 -3.63281250e-01\n", " 8.20312500e-02 -2.12402344e-02 1.58203125e-01 4.41894531e-02\n", " -1.17797852e-02 2.12890625e-01 -5.73730469e-02 5.66406250e-02\n", " -1.07421875e-01 1.85546875e-01 7.71484375e-02 1.44958496e-04\n", " 1.52343750e-01 -6.54296875e-02 -1.52343750e-01 2.25585938e-01\n", " 8.10546875e-02 8.88671875e-02 7.32421875e-02 -1.03515625e-01\n", " -6.68945312e-02 1.76757812e-01 2.12890625e-01 1.40625000e-01\n", " -3.41796875e-02 1.78222656e-02 5.95703125e-02 2.86102295e-04\n", " 5.88378906e-02 9.27734375e-03 1.66992188e-01 -2.70080566e-03\n", " 1.15722656e-01 1.04492188e-01 5.37109375e-02 1.85546875e-02\n", " 1.06445312e-01 5.05371094e-02 -1.64794922e-02 
-1.27929688e-01\n", " 2.16796875e-01 5.15136719e-02 4.78515625e-02 1.52343750e-01\n", " 1.71875000e-01 7.86132812e-02 -5.88378906e-02 -4.29687500e-02\n", " -7.27539062e-02 1.81640625e-01 -8.05664062e-02 -1.54296875e-01\n", " -1.16699219e-01 8.44726562e-02 -6.17675781e-02 -4.51660156e-02\n", " 9.21630859e-03 1.33789062e-01 1.92871094e-02 6.44531250e-02\n", " 1.08886719e-01 1.58203125e-01 -2.35595703e-02 1.23535156e-01\n", " 1.69921875e-01 3.49121094e-02 1.29882812e-01 2.65625000e-01\n", " 1.93359375e-01 -8.83789062e-02 8.49609375e-02 -2.96630859e-02\n", " 5.76171875e-02 2.51464844e-02 -1.01562500e-01 1.99218750e-01\n", " 1.04492188e-01 -2.42919922e-02 2.01416016e-02 -3.51562500e-02\n", " 6.64062500e-02 -6.20117188e-02 2.90527344e-02 -9.81445312e-02\n", " -1.81640625e-01 2.14843750e-01 -5.76171875e-02 -4.51660156e-02\n", " 4.49218750e-02 -1.95312500e-02 -2.08984375e-01 1.19628906e-01\n", " -9.03320312e-02 5.07812500e-02 9.03320312e-03 -9.76562500e-02\n", " -7.86132812e-02 -1.36718750e-01 -1.13769531e-01 -5.64575195e-03\n", " -4.07714844e-02 -2.05993652e-03 -5.66406250e-02 3.64685059e-03\n", " 8.30078125e-02 -7.08007812e-02 2.63671875e-01 1.24511719e-01\n", " -1.61132812e-02 9.13085938e-02 -2.39257812e-01 -1.04980469e-02\n", " -6.78710938e-02 1.40625000e-01 2.34375000e-01 -6.39648438e-02\n", " 1.95312500e-01 5.02929688e-02 -1.25000000e-01 2.06298828e-02\n", " -1.19140625e-01 -1.17187500e-01 -9.01222229e-05 3.68652344e-02\n", " 1.46484375e-01 2.47802734e-02 -1.49414062e-01 3.03649902e-03\n", " -3.10058594e-02 1.06933594e-01 2.55859375e-01 -6.00585938e-02\n", " -2.07031250e-01 1.58203125e-01 -2.15820312e-01 -1.84570312e-01\n", " -1.72851562e-01 7.99560547e-03 -3.03955078e-02 9.81445312e-02\n", " 4.66918945e-03 2.57812500e-01 1.06933594e-01 1.26953125e-01\n", " 6.34765625e-02 -1.30859375e-01 6.54296875e-02 -9.91210938e-02\n", " 5.90820312e-02 -3.71093750e-02 1.01074219e-01 1.53320312e-01\n", " -1.53320312e-01 -7.56835938e-02 5.85937500e-02 -5.05371094e-02\n", " 
2.08007812e-01 4.85839844e-02 -9.42382812e-02 -9.71679688e-02\n", " -1.23046875e-01 -1.97265625e-01 -1.76757812e-01 -1.11328125e-01\n", " 1.11328125e-01 -5.88378906e-02 2.27539062e-01 4.00390625e-02\n", " 1.24511719e-01 1.47460938e-01 1.81884766e-02 4.05273438e-02\n", " 1.69921875e-01 1.13769531e-01 -2.24609375e-02 6.73828125e-02\n", " 8.59375000e-02 6.73828125e-02 2.06298828e-02 4.78515625e-02\n", " 1.84326172e-02 2.05078125e-01 -4.68750000e-02 2.00195312e-01\n", " -1.56250000e-02 -1.40625000e-01 1.09863281e-02 -1.73828125e-01\n", " 4.85839844e-02 -1.58203125e-01 -1.04492188e-01 3.63769531e-02\n", " 3.01513672e-02 1.27929688e-01 -1.14257812e-01 1.41601562e-01\n", " 2.34375000e-01 -8.98437500e-02 -1.02996826e-03 -1.50390625e-01\n", " 1.79687500e-01 1.35742188e-01 -2.08007812e-01 -1.27563477e-02\n", " 1.75781250e-01 -1.39648438e-01 -2.03125000e-01 -3.00292969e-02\n", " -2.78320312e-02 -6.50024414e-03 1.26953125e-01 -1.49414062e-01\n", " 1.46484375e-01 -8.42285156e-03 1.12304688e-01 1.66015625e-01\n", " -1.57470703e-02 1.23046875e-01 7.22656250e-02 -4.37011719e-02\n", " -7.56835938e-02 -9.03320312e-02 1.01562500e-01 -1.44531250e-01\n", " -4.00390625e-02 -1.26953125e-02 2.66113281e-02 -7.81250000e-02\n", " 3.56445312e-02 3.49121094e-02 1.79687500e-01 -1.38671875e-01\n", " 2.80761719e-02 -2.86865234e-02 6.78710938e-02 7.03125000e-02\n", " 9.57031250e-02 5.00488281e-02 -2.20947266e-02 -3.00781250e-01\n", " 1.14257812e-01 7.51953125e-02 1.26342773e-02 1.32812500e-01\n", " 2.52685547e-02 3.63769531e-02 -2.81982422e-02 -1.36718750e-01\n", " 1.79687500e-01 -9.27734375e-02 8.49609375e-02 1.32812500e-01\n", " 3.97949219e-02 4.29687500e-01 -1.87988281e-02 -1.47460938e-01\n", " 6.10351562e-02 9.03320312e-02 8.69140625e-02 -6.88476562e-02\n", " 1.10839844e-01 9.81445312e-02 1.50390625e-01 1.61132812e-01\n", " -8.05664062e-02 -1.74804688e-01 -3.32031250e-02 -1.28906250e-01\n", " 1.22558594e-01 -1.44653320e-02 -1.63085938e-01 -3.58886719e-02\n", " 2.78320312e-02 -6.34765625e-02 
-7.91015625e-02 -1.14746094e-01\n", " 1.84326172e-02 2.91748047e-02 -3.00781250e-01 -4.58984375e-02\n", " -1.74804688e-01 2.33398438e-01 2.25830078e-02 1.10351562e-01\n", " -1.03515625e-01 -1.21582031e-01 2.21679688e-01 -2.19726562e-02]\n" ] } ], "source": [ "# Access embeddings with index-lookup\n", "print(word_emb[10])" ] }, { "cell_type": "markdown", "id": "4c980361-d38a-4ef6-bb66-06e0c62c0a8d", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "### Let's check the vocabulary" ] }, { "cell_type": "markdown", "id": "3a19e89d-d6c2-49c5-8a8d-d71bd3561a67", "metadata": {}, "source": [ "- Two important attributes:\n", " - ```key_to_index``` : maps a word to its vocabulary index\n", " - ```index_to_key``` : maps a vocabulary index to its corresponding word" ] }, { "cell_type": "code", "execution_count": 5, "id": "fc7eaeac-f570-453c-b9a0-9be2cbf4f05e", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Vocabulary length 3000000\n", "Index of cat 5947\n", "Word at position 5947 cat\n" ] } ], "source": [ "print(f\"Vocabulary length {len(word_emb.key_to_index)}\")\n", "print(f\"Index of cat {word_emb.key_to_index['cat']}\") # from word to index\n", "print(f\"Word at position 5947 {word_emb.index_to_key[5947]}\") # from index to word" ] }, { "cell_type": "markdown", "id": "335ab44e-e5b9-4060-bc1b-8cc12ab8878c", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "### Compute similarity and distance" ] }, { "cell_type": "code", "execution_count": 6, "id": "18c9ff38-21b9-469b-9942-ef71fc1236d7", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "w1 w2 cos_sim cos_dist\n", "car minivan 0.691 0.309\n", "car bicycle 0.536 0.464\n", "car airplane 0.424 0.576\n", "car cereal 0.139 0.861\n", "car communism 0.058 0.942\n" ] } ], "source": [ "pairs = [\n", " ('car', 'minivan'), \n", " ('car', 'bicycle'), \n", " ('car', 'airplane'), 
\n", " ('car', 'cereal'), \n", " ('car', 'communism'),\n", "]\n", "print(\"w1 w2 cos_sim cos_dist\")\n", "for w1, w2 in pairs:\n", " print(f\"{w1} {w2} {word_emb.similarity(w1, w2):.3f} {word_emb.distance(w1, w2):.3f}\")\n", " " ] }, { "cell_type": "markdown", "id": "4968ae71-8c50-4f9e-9c09-24147eba5120", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "### Nearest Neighbour (NN) Retrieval // Similarity Search" ] }, { "cell_type": "markdown", "id": "7252e54f-35d1-424e-8bc2-a1e5d1b1452d", "metadata": {}, "source": [ "- gensim has a ```most_similar``` function:\n", " - however, it does not perform exhaustive nearest-neighbour search\n", " - given a query word $w_q$ we want to find a ranked list $L_q$ of words in vocabulary $V$\n", " in decreasing order of cosine similarity\n", " - e.g. $w_q$ = \"joy\" then $L_q$ = [\"joy\", \"happiness\",... ]\n", "- We can write our own function!" ] }, { "cell_type": "code", "execution_count": 7, "id": "44ef69c2-0055-486b-8137-7a9cdfc1b38d", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "outputs": [], "source": [ "def retrieve_most_similar(query_words, all_word_emb, restrict_vocab=10000):\n", " \n", " # Step 1: Get full or restricted vocabulary embeddings\n", " # If restrict_vocab=None then we have exhaustive search, otherwise we restrict the vocab to the most frequent words\n", " vocab_emb = all_word_emb.vectors[:restrict_vocab+1,:] if restrict_vocab is not None else all_word_emb.vectors # shape: |V_r| x word_emb_size\n", " \n", " # Step 2: get the word embeddings for the query words\n", " query_emb = all_word_emb[query_words] # shape: |Q| x word_emb_size\n", " \n", " # Step 3: get cosine similarity between queries and embeddings\n", " cos_sim = cosine_similarity(query_emb, vocab_emb) # shape: |Q| x |V_r|\n", " \n", " # Step 4: sort similarities in descending order and get indices of nearest neighbours\n", " nn = np.argsort(-cos_sim) # shape: |Q| x |V_r|\n", " \n", " # 
Step 5: drop the first hit, assumed to be the query itself (cos_sim(w,w)=1.0); this holds only if the query lies in the restricted vocabulary\n", " nn_filtered = nn[:, 1:] # remove self_similarity\n", " \n", " # Step 6: use the indices to get the words (from the embeddings passed in, not the global word_emb)\n", " nn_words = np.array(all_word_emb.index_to_key)[nn_filtered]\n", " \n", " return nn_words" ] }, { "cell_type": "code", "execution_count": 8, "id": "85f691a6-1c91-4546-84be-4ff2439e9c77", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[['kings' 'queen' 'monarch' 'crown_prince' 'prince' 'sultan' 'ruler'\n", " 'princes' 'throne' 'royal']\n", " ['queens' 'princess' 'king' 'monarch' 'Queen' 'princesses' 'royal'\n", " 'prince' 'duchess' 'Queen_Elizabeth_II']\n", " ['french' 'Italy' 'i' 'haha' 'Cagliari' 'india' 'dont' 'thats' 'mr'\n", " 'lol']\n", " ['Italian' 'Sicily' 'Italians' 'ITALY' 'Spain' 'Bologna' 'Italia'\n", " 'France' 'Milan' 'Romania']\n", " ['registered_nurse' 'nurses' 'nurse_practitioner' 'midwife' 'Nurse'\n", " 'nursing' 'doctor' 'medic' 'pharmacist' 'paramedic']]\n" ] } ], "source": [ "queries = [\"king\", \"queen\", \"italy\", \"Italy\", \"nurse\"]\n", "res = retrieve_most_similar(queries, word_emb, restrict_vocab=100000)\n", "top_k = 10\n", "res_k = res[:, :top_k]\n", "del res\n", "print(res_k)" ] }, { "cell_type": "markdown", "id": "bf58b008-445d-4950-a78d-4ab87fdba266", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "### Dimensionality Reduction and Plotting\n", "\n", "- We want to plot our word embeddings\n", "- But they ''live'' in $\\mathbb{R}^{300}$\n", "- Let's use dimensionality reduction techniques, like PCA" ] }, { "cell_type": "code", "execution_count": 9, "id": "24554147-0a3c-4bb5-b2d4-ab6cd75f0d07", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(|Q| x k) x word_emb_size\n", "(50, 300)\n" ] } ], "source": [ "all_res_words = res_k.flatten()\n", "res_word_emb = word_emb[all_res_words]\n", "print(\"(|Q| 
x k) x word_emb_size\")\n", "print(res_word_emb.shape)" ] }, { "cell_type": "code", "execution_count": 10, "id": "0add5d0d-ed1f-419c-a66d-b906ac96c744", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "outputs": [], "source": [ "pca = PCA(n_components=3) #Perform 3d-PCA\n", "word_emb_pca = pca.fit_transform(res_word_emb)" ] }, { "cell_type": "code", "execution_count": 11, "id": "806c5683-92a4-413f-b616-cbd85826852c", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " pca_x pca_y pca_z word query\n", "0 -0.951780 -0.588461 0.546893 kings king\n", "1 -1.366599 -0.059902 0.103550 queen king\n", "2 -2.038808 -0.398816 -0.404128 monarch king\n", "3 -1.730922 -0.289503 -0.157777 crown_prince king\n", "4 -1.596841 -0.419770 -0.252166 prince king\n" ] } ], "source": [ "pca_df = pd.DataFrame(word_emb_pca, columns=[\"pca_x\", \"pca_y\", \"pca_z\"])\n", "\n", "pca_df[\"word\"] = res_k.flatten()\n", "\n", "labels = np.array([queries]).repeat(top_k)\n", "pca_df[\"query\"] = labels\n", "\n", "print(pca_df.head())" ] }, { "cell_type": "code", "execution_count": 12, "id": "d05b1a05-8dea-4e25-a469-96499d4ffeda", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "outputs": [ { "data": { "text/html": [ " \n", " " ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.plotly.v1+json": { "config": { "plotlyServerURL": "https://plot.ly" }, "data": [ { "hovertemplate": "query=king
pca_x=%{x}
pca_y=%{y}
pca_z=%{z}
word=%{text}", "legendgroup": "king", "marker": { "color": "#636efa", "opacity": 0.7, "symbol": "circle" }, "mode": "markers+text", "name": "king", "scene": "scene", "showlegend": true, "text": [ "kings", "queen", "monarch", "crown_prince", "prince", "sultan", "ruler", "princes", "throne", "royal" ], "type": "scatter3d", "x": [ -0.9517804384231567, -1.366599440574646, -2.0388078689575195, -1.7309224605560303, -1.5968408584594727, -1.369072437286377, -1.1658663749694824, -1.6947394609451294, -1.8148521184921265, -1.7376943826675415 ], "y": [ -0.5884613990783691, -0.0599016398191452, -0.39881592988967896, -0.2895025908946991, -0.41976988315582275, -0.6122424006462097, -0.153344064950943, -0.7690215706825256, -0.6006054878234863, -0.3387441635131836 ], "z": [ 0.5468930006027222, 0.1035504937171936, -0.4041284918785095, -0.15777726471424103, -0.25216609239578247, 0.08896508067846298, 0.0244793388992548, -0.09496109187602997, -0.048294227570295334, -0.12886901199817657 ] }, { "hovertemplate": "query=queen
pca_x=%{x}
pca_y=%{y}
pca_z=%{z}
word=%{text}", "legendgroup": "queen", "marker": { "color": "#EF553B", "opacity": 0.7, "symbol": "circle" }, "mode": "markers+text", "name": "queen", "scene": "scene", "showlegend": true, "text": [ "queens", "princess", "king", "monarch", "Queen", "princesses", "royal", "prince", "duchess", "Queen_Elizabeth_II" ], "type": "scatter3d", "x": [ -0.9125000834465027, -1.4164530038833618, -1.3207920789718628, -2.0388078689575195, -0.8670774698257446, -1.2828762531280518, -1.7376943826675415, -1.5968408584594727, -1.2784538269042969, -1.5294901132583618 ], "y": [ -0.14129917323589325, 0.03801686316728592, -0.5229200720787048, -0.39881622791290283, 0.03786739706993103, -0.008012920618057251, -0.3387441635131836, -0.4197693467140198, -0.17408403754234314, -0.32198742032051086 ], "z": [ 0.5540258884429932, -0.0528465136885643, 0.48819461464881897, -0.4041287899017334, -0.2179393619298935, 0.03382590413093567, -0.12886901199817657, -0.2521649897098541, -0.6913980841636658, -0.7494331002235413 ] }, { "hovertemplate": "query=italy
pca_x=%{x}
pca_y=%{y}
pca_z=%{z}
word=%{text}", "legendgroup": "italy", "marker": { "color": "#00cc96", "opacity": 0.7, "symbol": "circle" }, "mode": "markers+text", "name": "italy", "scene": "scene", "showlegend": true, "text": [ "french", "Italy", "i", "haha", "Cagliari", "india", "dont", "thats", "mr", "lol" ], "type": "scatter3d", "x": [ 0.9309588670730591, 1.462249994277954, 0.9464520215988159, 0.9751124978065491, 2.0935122966766357, 0.9193646311759949, 0.9844239354133606, 0.9996874332427979, 0.5211668610572815, 0.9143239259719849 ], "y": [ -0.8589074611663818, -1.0839223861694336, -0.14867663383483887, -0.36846986413002014, -1.414186716079712, -0.6561701893806458, -0.1323813498020172, -0.11387570947408676, -0.2994327247142792, -0.16725479066371918 ], "z": [ 1.1040490865707397, -1.147091269493103, 1.676418423652649, 2.1434600353240967, -1.0863116979599, 1.7296453714370728, 2.1569983959198, 2.1118345260620117, 2.0187184810638428, 2.325324773788452 ] }, { "hovertemplate": "query=Italy
pca_x=%{x}
pca_y=%{y}
pca_z=%{z}
word=%{text}", "legendgroup": "Italy", "marker": { "color": "#ab63fa", "opacity": 0.7, "symbol": "circle" }, "mode": "markers+text", "name": "Italy", "scene": "scene", "showlegend": true, "text": [ "Italian", "Sicily", "Italians", "ITALY", "Spain", "Bologna", "Italia", "France", "Milan", "Romania" ], "type": "scatter3d", "x": [ 1.4158889055252075, 1.3687331676483154, 1.2609578371047974, 1.971722960472107, 0.8236953616142273, 1.8961466550827026, 1.5968090295791626, 0.8499287366867065, 1.3787188529968262, 1.2337576150894165 ], "y": [ -1.0573140382766724, -1.0023144483566284, -1.0882835388183594, -1.4329508543014526, -0.7145146727561951, -1.0877175331115723, -0.8948194980621338, -0.5499246120452881, -0.9163414239883423, -0.5847146511077881 ], "z": [ -1.1201683282852173, -1.2993834018707275, -0.9582164287567139, -0.748938262462616, -0.5381762385368347, -1.2336772680282593, -0.6072587370872498, -0.48449867963790894, -0.8369559645652771, -0.4665997624397278 ] }, { "hovertemplate": "query=nurse
pca_x=%{x}
pca_y=%{y}
pca_z=%{z}
word=%{text}", "legendgroup": "nurse", "marker": { "color": "#FFA15A", "opacity": 0.7, "symbol": "circle" }, "mode": "markers+text", "name": "nurse", "scene": "scene", "showlegend": true, "text": [ "registered_nurse", "nurses", "nurse_practitioner", "midwife", "Nurse", "nursing", "doctor", "medic", "pharmacist", "paramedic" ], "type": "scatter3d", "x": [ 0.5844143629074097, 0.41321346163749695, 0.5576409697532654, 0.2616342604160309, 0.5629440546035767, 0.6022177338600159, 0.23832714557647705, 0.4809919595718384, 0.6038249135017395, 0.5993444919586182 ], "y": [ 2.6503474712371826, 1.9749938249588013, 2.792224407196045, 2.257770299911499, 1.9113136529922485, 1.7936431169509888, 1.7009142637252808, 1.8391629457473755, 2.057131767272949, 2.074826717376709 ], "z": [ -0.5991157293319702, -0.24604666233062744, -0.46162161231040955, -0.33090874552726746, -0.10178092867136002, -0.2647252678871155, -0.15270692110061646, -0.2492000162601471, -0.249731183052063, -0.3402983546257019 ] } ], "layout": { "autosize": true, "legend": { "title": { "text": "query" }, "tracegroupgap": 0 }, "scene": { "aspectmode": "auto", "aspectratio": { "x": 1.036947374917748, "y": 1.0602480496098923, "z": 0.9095693162525992 }, "camera": { "center": { "x": 0, "y": 0, "z": 0 }, "eye": { "x": 0.9545193387358788, "y": 0.9545193387358786, "z": 0.9545193387358786 }, "projection": { "type": "perspective" }, "up": { "x": 0, "y": 0, "z": 1 } }, "domain": { "x": [ 0, 1 ], "y": [ 0, 1 ] }, "xaxis": { "title": { "text": "pca_x" }, "type": "linear" }, "yaxis": { "title": { "text": "pca_y" }, "type": "linear" }, "zaxis": { "title": { "text": "pca_z" }, "type": "linear" } }, "template": { "data": { "bar": [ { "error_x": { "color": "#2a3f5f" }, "error_y": { "color": "#2a3f5f" }, "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "bar" } ], "barpolar": [ { "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { 
"fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "barpolar" } ], "carpet": [ { "aaxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "baxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "type": "carpet" } ], "choropleth": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "choropleth" } ], "contour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "contour" } ], "contourcarpet": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "contourcarpet" } ], "heatmap": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmap" } ], "heatmapgl": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmapgl" } ], "histogram": [ { "marker": { "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "histogram" } ], "histogram2d": [ { 
"colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2d" } ], "histogram2dcontour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2dcontour" } ], "mesh3d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "mesh3d" } ], "parcoords": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "parcoords" } ], "pie": [ { "automargin": true, "type": "pie" } ], "scatter": [ { "fillpattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 }, "type": "scatter" } ], "scatter3d": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter3d" } ], "scattercarpet": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattercarpet" } ], "scattergeo": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergeo" } ], "scattergl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergl" } ], "scattermapbox": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattermapbox" } ], "scatterpolar": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolar" } ], "scatterpolargl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": 
"scatterpolargl" } ], "scatterternary": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterternary" } ], "surface": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "surface" } ], "table": [ { "cells": { "fill": { "color": "#EBF0F8" }, "line": { "color": "white" } }, "header": { "fill": { "color": "#C8D4E3" }, "line": { "color": "white" } }, "type": "table" } ] }, "layout": { "annotationdefaults": { "arrowcolor": "#2a3f5f", "arrowhead": 0, "arrowwidth": 1 }, "autotypenumbers": "strict", "coloraxis": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "colorscale": { "diverging": [ [ 0, "#8e0152" ], [ 0.1, "#c51b7d" ], [ 0.2, "#de77ae" ], [ 0.3, "#f1b6da" ], [ 0.4, "#fde0ef" ], [ 0.5, "#f7f7f7" ], [ 0.6, "#e6f5d0" ], [ 0.7, "#b8e186" ], [ 0.8, "#7fbc41" ], [ 0.9, "#4d9221" ], [ 1, "#276419" ] ], "sequential": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "sequentialminus": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ] }, "colorway": [ "#636efa", "#EF553B", "#00cc96", "#ab63fa", "#FFA15A", "#19d3f3", "#FF6692", "#B6E880", "#FF97FF", "#FECB52" ], "font": { 
"color": "#2a3f5f" }, "geo": { "bgcolor": "white", "lakecolor": "white", "landcolor": "#E5ECF6", "showlakes": true, "showland": true, "subunitcolor": "white" }, "hoverlabel": { "align": "left" }, "hovermode": "closest", "mapbox": { "style": "light" }, "paper_bgcolor": "white", "plot_bgcolor": "#E5ECF6", "polar": { "angularaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "radialaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "scene": { "xaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "yaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "zaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" } }, "shapedefaults": { "line": { "color": "#2a3f5f" } }, "ternary": { "aaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "baxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "caxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "title": { "x": 0.05 }, "xaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 }, "yaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 } } }, "title": { "text": "3d-PCA representation of word embeddings" } } }, "image/png": "", "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "px.scatter_3d(pca_df, x='pca_x', y='pca_y', z='pca_z', color=\"query\", text=\"word\", opacity=0.7, title=\"3d-PCA representation of word embeddings\")" ] }, { "cell_type": "markdown", "id": "7c76ab53-70dc-4324-97ca-358b4d962f5e", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "### Word embedding evaluation\n", "\n", "- There are two main types of evaluation:\n", " - intrinsic evaluation: evaluate word embeddings without a downstream task\n", " - word similarity benchmarks\n", " - word analogy benchmarks\n", " - extrinsic evaluation: evaluate word embeddings on a downstream task" ] }, { "cell_type": "markdown", "id": "bc5e5ef0-9a9b-4952-9c74-26130634ae7b", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "#### Word Similarity Benchmarks\n", "- Word similarity benchmarks, such as WS353, contain word pairs and a human-given similarity score\n" ] }, { "cell_type": "code", "execution_count": 13, "id": "a006fbad-3740-40ea-b2e9-9a85c92ef32e", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Word 1 Word 2 Human (mean)\n", "5 computer internet 7.58\n", "286 seafood food 8.34\n", "127 planet star 8.45\n", "263 profit warning 3.88\n", "304 environment ecology 8.81" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ws353_df = pd.read_csv(datapath('wordsim353.tsv'), sep=\"\\t\", skiprows=1).rename(columns={\"# Word 1\": \"Word 1\"})\n", "ws353_df.sample(5)" ] }, { "cell_type": "markdown", "id": "8a2d8328-745e-4afe-b192-3b71846416ca", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "- To evaluate word embeddings on such a benchmark, we follow three steps:\n", "1. For each pair $(w_{i_{1}}, w_{i_{2}})$ we get the embeddings $(\\mathbf{e}_{w_{i_{1}}}, \\mathbf{e}_{w_{i_{2}}})$\n", "2. For each pair we compute the cosine similarity between its word embeddings $s_i = \\cos(\\mathbf{e}_{w_{i_{1}}}, \\mathbf{e}_{w_{i_{2}}})$\n", "3. We compute a correlation score ([Pearson's $r$](https://en.wikipedia.org/wiki/Pearson_correlation_coefficient) or [Spearman's $\\rho$](https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient)) between the human-given scores $h_i$ and the cosine similarities $s_i$\n", " - the higher the correlation, the better the embeddings agree with human judgments!" 
] }, { "cell_type": "markdown", "id": "2c939ac8-8388-4419-bf62-18a83602b318", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "##### Evaluating word similarity with gensim\n", "\n", "- gensim performs all of these steps for us with the [```evaluate_word_pairs```](https://radimrehurek.com/gensim/models/keyedvectors.html#gensim.models.keyedvectors.KeyedVectors.evaluate_word_pairs) function" ] }, { "cell_type": "code", "execution_count": 14, "id": "c827290d-3880-4919-b3f7-a87734155f92", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "(PearsonRResult(statistic=0.6525349640723466, pvalue=3.3734155032900286e-44),\n", " SpearmanrResult(correlation=0.7000166486272194, pvalue=2.86866666051422e-53),\n", " 0.0)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "word_emb.evaluate_word_pairs(datapath('wordsim353.tsv'), case_insensitive=False)" ] }, { "cell_type": "markdown", "id": "0cdce175-12e1-4eaf-acac-1dfadc83ac1c", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "#### Word analogy benchmarks" ] }, { "cell_type": "markdown", "id": "36aa6616-4a20-4556-9eef-39f3726e435a", "metadata": {}, "source": [ "- When doing word analogy resolution with word embeddings, we want to solve equations such as\n", "\n", " *man : king = woman : x*\n", "\n", "- The word2vec paper shows that word2vec embeddings can solve some of these equations with algebraic operations:\n", "1. Get $\\mathbf{e}_x = \\mathbf{e}_{king} - \\mathbf{e}_{man} + \\mathbf{e}_{woman}$\n", "2. 
Check if $NN_{V}(\\mathbf{e}_x) = \\text{queen}$\n", " " ] }, { "cell_type": "markdown", "id": "fad14c49-be5f-43cb-bf7e-6b369e84158e", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "##### Evaluating word analogies with gensim\n", "- gensim provides us with a [```most_similar```]() function\n", "- It has several arguments, the most important are:\n", " - ```positive``` : list of words that should be summed together\n", " - ```negative``` : list of words that should be subtracted\n", "- In formulas, this function computes:\n", "$$ \\mathbf{e}_x = \\sum_{i \\in \\text{pos}} \\mathbf{e}_i - \\sum_{i \\in \\text{neg}} \\mathbf{e}_i $$\n", "- And then returns nearest neighbours of $\\mathbf{e}_x$" ] }, { "cell_type": "code", "execution_count": 15, "id": "18d28a9c-7e6d-460d-82ea-0e3c9d3df32b", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[('queen', 0.7118192911148071), ('monarch', 0.6189674735069275), ('princess', 0.5902431011199951), ('crown_prince', 0.549946129322052), ('prince', 0.5377321243286133), ('kings', 0.5236843824386597), ('queens', 0.5181134343147278), ('sultan', 0.5098593235015869), ('monarchy', 0.5087411403656006), ('royal_palace', 0.5087166428565979)]\n", "[('Walkman', 0.581480860710144), ('MP3_player', 0.5763883590698242), ('MP3', 0.5520824193954468), ('Panasonic', 0.5468560457229614), ('Blu_ray_disc', 0.5435828566551208), ('JVC', 0.525976836681366), ('camcorder', 0.5257487297058105), ('Sony_PSP', 0.5226278305053711), ('PlayStation_Portable', 0.5171500444412231), ('Blu_ray', 0.5171388983726501)]\n" ] } ], "source": [ "print(word_emb.most_similar(positive=[\"king\", \"woman\"], negative=[\"man\"], restrict_vocab=100000))\n", "print(word_emb.most_similar(positive=[\"iPod\", \"Sony\"], negative=[\"Apple\"], restrict_vocab=100000))" ] }, { "cell_type": "code", "execution_count": 16, "id": "27d1d8b9-e943-4528-bf3d-ccf58da4867c", 
"metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ ": capital-common-countries\n", "Athens Greece Baghdad Iraq\n", "Athens Greece Bangkok Thailand\n", "Athens Greece Beijing China\n", "Athens Greece Berlin Germany\n", "Athens Greece Bern Switzerland\n", "Athens Greece Cairo Egypt\n", "Athens Greece Canberra Australia\n", "Athens Greece Hanoi Vietnam\n", "Athens Greece Havana Cuba\n", "Athens Greece Helsinki Finland\n", "Athens Greece Islamabad Pakistan\n", "Athens Greece Kabul Afghanistan\n", "Athens Greece London England\n", "Athens Greece Madrid Spain\n", "\n" ] } ], "source": [ "# Peek at the first lines of the analogy benchmark; the context manager closes the file for us\n", "with open(datapath('questions-words.txt')) as f:\n", "    print(\"\".join(f.readlines()[:15]))" ] }, { "cell_type": "code", "execution_count": 17, "id": "faef7b21-d033-41f4-9f00-884d31ce68ca", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy 0.7401448525607863\n", "dict_keys(['section', 'correct', 'incorrect'])\n", "Correct [('ATHENS', 'GREECE', 'BANGKOK', 'THAILAND'), ('ATHENS', 'GREECE', 'BEIJING', 'CHINA'), ('ATHENS', 'GREECE', 'BERLIN', 'GERMANY'), ('ATHENS', 'GREECE', 'BERN', 'SWITZERLAND'), ('ATHENS', 'GREECE', 'CAIRO', 'EGYPT')]\n", "Incorrect [('ATHENS', 'GREECE', 'BAGHDAD', 'IRAQ'), ('ATHENS', 'GREECE', 'HANOI', 'VIETNAM'), ('ATHENS', 'GREECE', 'KABUL', 'AFGHANISTAN'), ('ATHENS', 'GREECE', 'LONDON', 'ENGLAND'), ('BAGHDAD', 'IRAQ', 'BERN', 'SWITZERLAND')]\n" ] } ], "source": [ "accuracy, results = word_emb.evaluate_word_analogies(datapath('questions-words.txt'))\n", "print(f\"Accuracy {accuracy}\")\n", "print(results[0].keys())\n", "print(f\"Correct {results[0]['correct'][:5]}\")\n", "print(f\"Incorrect {results[0]['incorrect'][:5]}\")" ] }, { "cell_type": "markdown", "id": "a157c61e-9753-43b2-b728-13bdeea08b90", "metadata": { "slideshow": { "slide_type": "slide" }, "tags": [] }, "source": [ 
"### It's your turn! Go ahead with *Task 1.*\n", "\n", "![]() " ] }, { "cell_type": "markdown", "id": "33dcab8d-23ba-4da0-b870-cd25b24db593", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "#### Task 1 \n", "\n", "Implement intrinsic evaluation using the wordsim353 benchmark. To compute the correlations (step 3), use [```spearmanr```](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.spearmanr.html) and [```pearsonr```](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.pearsonr.html) from ```scipy.stats```" ] }, { "cell_type": "code", "execution_count": null, "id": "0bf06b41-aa29-4943-90de-b3d893180d19", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "outputs": [], "source": [ "# TODO: implement the WS353 evaluation benchmark in the three steps below.\n", "\n", "# Step 0: (re)load the data\n", "\n", "ws353_df = pd.read_csv(datapath('wordsim353.tsv'), sep=\"\\t\", skiprows=1).rename(columns={\"# Word 1\": \"Word 1\"})\n", "\n", "# Step 1: Get embeddings (use ws353_df defined above)\n", "\n", "# Step 2: Compute cosine similarities\n", "\n", "\n", "# Step 3: Compute correlations\n" ] }, { "cell_type": "markdown", "id": "25134a35-3180-457b-bcf9-93268a6032ee", "metadata": { "slideshow": { "slide_type": "slide" }, "tags": [] }, "source": [ "## Part 2 : Pretraining your own embeddings\n", "\n", "- Up to now we have used embeddings that someone else trained for us\n", "- What if you want to pretrain your own embeddings for your domain or task of interest?\n", "- The first thing we need is data!" ] }, { "cell_type": "code", "execution_count": 19, "id": "53f7d388-2c8e-48d5-a66d-fa001fa48cd3", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Hundreds of people have been forced to vacate their homes in the Southern Highlands of New South Wales as strong winds today pushed a huge bushfire towards the town of Hill Top. 
A new blaze near Goulburn, south-west of Sydney, has forced the closure of the Hume Highway. At about 4:00pm AEDT, a marked deterioration in the weather as a storm cell moved east across the Blue Mountains forced authorities to make a decision to evacuate people from homes in outlying streets at Hill Top in the New South Wales southern highlands. An estimated 500 residents have left their homes for nearby Mittagong. The New South Wales Rural Fire Service says the weather conditions which caused the fire to burn in a finger formation have now eased and about 60 fire units in and around Hill Top are optimistic of defending all properties. As more than 100 blazes burn on New Year's Eve in New South Wales, fire crews have been called to new fire at Gunning, south of Goulburn. While few details are available at this stage, fire authorities says it has closed the Hume Highway in both directions. Meanwhile, a new fire in Sydney's west is no longer threatening properties in the Cranebrook area. Rain has fallen in some parts of the Illawarra, Sydney, the Hunter Valley and the north coast. But the Bureau of Meteorology's Claire Richards says the rain has done little to ease any of the hundred fires still burning across the state. \"The falls have been quite isolated in those areas and generally the falls have been less than about five millimetres,\" she said. \"In some places really not significant at all, less than a millimetre, so there hasn't been much relief as far as rain is concerned. 
\"In fact, they've probably hampered the efforts of the firefighters more because of the wind gusts that are associated with those thunderstorms.\" \n", " ['hundreds', 'of', 'people', 'have', 'been', 'forced', 'to', 'vacate', 'their', 'homes', 'in', 'the', 'southern', 'highlands', 'of', 'new', 'south', 'wales', 'as', 'strong', 'winds', 'today', 'pushed', 'huge', 'bushfire', 'towards', 'the', 'town', 'of', 'hill', 'top', 'new', 'blaze', 'near', 'goulburn', 'south', 'west', 'of', 'sydney', 'has', 'forced', 'the', 'closure', 'of', 'the', 'hume', 'highway', 'at', 'about', 'pm', 'aedt', 'marked', 'deterioration', 'in', 'the', 'weather', 'as', 'storm', 'cell', 'moved', 'east', 'across', 'the', 'blue', 'mountains', 'forced', 'authorities', 'to', 'make', 'decision', 'to', 'evacuate', 'people', 'from', 'homes', 'in', 'outlying', 'streets', 'at', 'hill', 'top', 'in', 'the', 'new', 'south', 'wales', 'southern', 'highlands', 'an', 'estimated', 'residents', 'have', 'left', 'their', 'homes', 'for', 'nearby', 'mittagong', 'the', 'new', 'south', 'wales', 'rural', 'fire', 'service', 'says', 'the', 'weather', 'conditions', 'which', 'caused', 'the', 'fire', 'to', 'burn', 'in', 'finger', 'formation', 'have', 'now', 'eased', 'and', 'about', 'fire', 'units', 'in', 'and', 'around', 'hill', 'top', 'are', 'optimistic', 'of', 'defending', 'all', 'properties', 'as', 'more', 'than', 'blazes', 'burn', 'on', 'new', 'year', 'eve', 'in', 'new', 'south', 'wales', 'fire', 'crews', 'have', 'been', 'called', 'to', 'new', 'fire', 'at', 'gunning', 'south', 'of', 'goulburn', 'while', 'few', 'details', 'are', 'available', 'at', 'this', 'stage', 'fire', 'authorities', 'says', 'it', 'has', 'closed', 'the', 'hume', 'highway', 'in', 'both', 'directions', 'meanwhile', 'new', 'fire', 'in', 'sydney', 'west', 'is', 'no', 'longer', 'threatening', 'properties', 'in', 'the', 'cranebrook', 'area', 'rain', 'has', 'fallen', 'in', 'some', 'parts', 'of', 'the', 'illawarra', 'sydney', 'the', 'hunter', 'valley', 'and', 
'the', 'north', 'coast', 'but', 'the', 'bureau', 'of', 'meteorology', 'claire', 'richards', 'says', 'the', 'rain', 'has', 'done', 'little', 'to', 'ease', 'any', 'of', 'the', 'hundred', 'fires', 'still', 'burning', 'across', 'the', 'state', 'the', 'falls', 'have', 'been', 'quite', 'isolated', 'in', 'those', 'areas', 'and', 'generally', 'the', 'falls', 'have', 'been', 'less', 'than', 'about', 'five', 'millimetres', 'she', 'said', 'in', 'some', 'places', 'really', 'not', 'significant', 'at', 'all', 'less', 'than', 'millimetre', 'so', 'there', 'hasn', 'been', 'much', 'relief', 'as', 'far', 'as', 'rain', 'is', 'concerned', 'in', 'fact', 'they', 've', 'probably', 'hampered', 'the', 'efforts', 'of', 'the', 'firefighters', 'more', 'because', 'of', 'the', 'wind', 'gusts', 'that', 'are', 'associated', 'with', 'those', 'thunderstorms']\n" ] } ], "source": [ "corpus = open(datapath('lee_background.cor'))\n", "sample = corpus.readline()\n", "print(sample, utils.simple_preprocess(sample))" ] }, { "cell_type": "code", "execution_count": 20, "id": "7d120bd6-c8b8-4297-86f5-479403eee326", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "outputs": [], "source": [ "class MyCorpus:\n", " \"\"\"An iterator that yields sentences (lists of str).\"\"\"\n", "\n", " def __iter__(self):\n", " corpus_path = datapath('lee_background.cor')\n", " for line in open(corpus_path):\n", " # assume there's one document per line, tokens separated by whitespace\n", " yield utils.simple_preprocess(line)" ] }, { "cell_type": "markdown", "id": "e4fb828c-ce09-4cc2-af29-247ea5564ae9", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "### Let's pretrain our own embeddings\n", "\n", "- We will use the [```Word2Vec```](https://radimrehurek.com/gensim/models/word2vec.html#gensim.models.word2vec.Word2Vec) class from ```gensim.models```\n", "- Let's look at the most important parameters" ] }, { "cell_type": "code", "execution_count": 21, "id": 
"61bd239c-d7cb-4720-a068-65b13a5a154d", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Word2Vec\n" ] } ], "source": [ "model = Word2Vec(sentences=MyCorpus(), \n", " min_count=3, # ignore all words with freq < min_count\n", " vector_size=200, # dimensionality of the vectors\n", " sg=1, # 1 for skip-gram, 0 for CBOW\n", " epochs=10, # num_epochs\n", " alpha=0.025, # initial learning rate\n", " batch_words=10000, # batch size\n", " window=5, # window size for context words\n", " negative=10, # number of negatives for negative sampling\n", " ns_exponent=0.75 # exponent of the sampling distribution\n", " )\n", "print(model)\n", "word_emb_lee = model.wv # wv attribute contains word embeddings" ] }, { "cell_type": "markdown", "id": "e912f2dd-6f65-4b0b-9637-1b1dc9aa5181", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "### Saving and loading your embeddings\n", "\n", "- Saving or loading the full model (i.e. 
embeddings + hyperparameters)\n", " - This allows you to resume training later" ] }, { "cell_type": "code", "execution_count": 22, "id": "98ff7d27-7a4c-4ead-b3a7-8fc026b798b2", "metadata": { "tags": [] }, "outputs": [], "source": [ "save_path = \"word2vee_lee.model\"\n", "model.save(save_path)\n", "model_reloaded = Word2Vec.load(save_path)" ] }, { "cell_type": "markdown", "id": "b9c71077-31ea-4ef4-9543-ee694c663ce3", "metadata": {}, "source": [ "- Saving or loading **only** word embeddings\n", " - This does **NOT** allow you to resume training" ] }, { "cell_type": "code", "execution_count": 23, "id": "31ece60e-ea65-4198-9334-e5e5e663fbbb", "metadata": { "tags": [] }, "outputs": [], "source": [ "save_path = \"word2vee_lee.emb\"\n", "model.wv.save(save_path)\n", "emb_reloaded = KeyedVectors.load(save_path)" ] }, { "cell_type": "markdown", "id": "c5e9f540-9f77-4f62-b89e-ec4356251833", "metadata": { "slideshow": { "slide_type": "slide" }, "tags": [] }, "source": [ "## Part 3 : Extrinsic evaluation of word embeddings\n", "\n", "- Up to now we have evaluated word embeddings intrinsically\n", "- Let's see how they fare in a real-world task\n", "- We will use them to solve a spam classification task" ] }, { "cell_type": "markdown", "id": "3d94b00d", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "- Please now run all cells below!\n", " - You will need them for **Task 2**\n", "- Remember to put ```SMSSpamCollection.tsv``` in the same folder as this notebook\n", " - Or upload it if you're using Colab" ] }, { "cell_type": "code", "execution_count": 24, "id": "6f649677-9e95-4de2-b75f-84bb3b28eb55", "metadata": { "tags": [] }, "outputs": [], "source": [ "spam_df = pd.read_csv(\"SMSSpamCollection.tsv\", sep=\"\\t\", header=None, names=[\"label\", \"text\"])" ] }, { "cell_type": "code", "execution_count": 25, "id": "e787e4f9-5f4c-4943-8474-557521805313", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "outputs": [ { "data": { "text/html": 
[ "
" ], "text/plain": [ " label text\n", "0 0 Go until jurong point, crazy.. Available only ...\n", "1 0 Ok lar... Joking wif u oni...\n", "2 1 Free entry in 2 a wkly comp to win FA Cup fina...\n", "3 0 U dun say so early hor... U c already then say...\n", "4 0 Nah I don't think he goes to usf, he lives aro...\n", "... ... ...\n", "5567 1 This is the 2nd time we have tried 2 contact u...\n", "5568 0 Will ü b going to esplanade fr home?\n", "5569 0 Pity, * was in mood for that. So...any other s...\n", "5570 0 The guy did some bitching but I acted like i'd...\n", "5571 0 Rofl. Its true to its name\n", "\n", "[5572 rows x 2 columns]" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Encode the string labels as integers (ham -> 0, spam -> 1)\n", "label_encoder = LabelEncoder()\n", "spam_df[\"label\"] = label_encoder.fit_transform(spam_df[\"label\"])\n", "spam_df" ] }, { "cell_type": "markdown", "id": "c9013ea4-1faf-4cc4-a194-b67ed8c1bb1c", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "### Building a classification model \n", "- We want to use a standard ML approach\n", "- We will first preprocess the text:\n", " - lowercasing\n", " - tokenization\n", " - stopword removal\n", "- After this, we will create a sentence embedding for each SMS as the average of the word embeddings of its tokens" ] }, { "cell_type": "code", "execution_count": 26, "id": "2a89d9bf-92e6-4d0d-88d9-6640bade38f7", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[nltk_data] Downloading package stopwords to\n", "[nltk_data] C:\\Users\\Tommaso\\AppData\\Roaming\\nltk_data...\n", "[nltk_data] Package stopwords is already up-to-date!\n" ] }, { "data": { "text/plain": [ "True" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "nltk.download('stopwords')" ] }, { "cell_type": "code", "execution_count": 27, "id": 
"b564425d-6944-4591-ae9c-6038040f62b2", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " label text \\\n", "0 0 Go until jurong point, crazy.. Available only ... \n", "1 0 Ok lar... Joking wif u oni... \n", "2 1 Free entry in 2 a wkly comp to win FA Cup fina... \n", "3 0 U dun say so early hor... U c already then say... \n", "4 0 Nah I don't think he goes to usf, he lives aro... \n", "... ... ... \n", "5567 1 This is the 2nd time we have tried 2 contact u... \n", "5568 0 Will ü b going to esplanade fr home? \n", "5569 0 Pity, * was in mood for that. So...any other s... \n", "5570 0 The guy did some bitching but I acted like i'd... \n", "5571 0 Rofl. Its true to its name \n", "\n", " preprocessed_text \n", "0 [go, jurong, point, crazy, available, bugis, g... \n", "1 [ok, lar, joking, wif, oni] \n", "2 [free, entry, wkly, comp, win, fa, cup, final,... \n", "3 [dun, say, early, hor, already, say] \n", "4 [nah, think, goes, usf, lives, around, though] \n", "... ... \n", "5567 [nd, time, tried, contact, pound, prize, claim... \n", "5568 [going, esplanade, fr, home] \n", "5569 [pity, mood, suggestions] \n", "5570 [guy, bitching, acted, like, interested, buyin... 
\n", "5571 [rofl, true, name] \n", "\n", "[5572 rows x 3 columns]" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# lowercase, tokenize and stopword removal\n", "stop_words = set(stopwords.words('english'))\n", "spam_df[\"preprocessed_text\"] = spam_df[\"text\"].apply(lambda sentence: [word for word in utils.simple_preprocess(sentence) if word not in stop_words])\n", "spam_df" ] }, { "cell_type": "code", "execution_count": 28, "id": "eeb313e3-fbe0-4026-b423-92bd1a3a5934", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\Tommaso\\anaconda3\\envs\\pd_nlp\\lib\\site-packages\\numpy\\core\\fromnumeric.py:3432: RuntimeWarning:\n", "\n", "Mean of empty slice.\n", "\n" ] } ], "source": [ "# Create sentence embeddings\n", "spam_df[\"sent_emb\"] = spam_df[\"preprocessed_text\"].apply(lambda tok_sentence: np.mean([word_emb[word] for word in tok_sentence if word in word_emb.key_to_index], axis=0))" ] }, { "cell_type": "code", "execution_count": 29, "id": "0830e39a-89bc-4069-a9e8-8ff3e5ec18c1", "metadata": { "tags": [] }, "outputs": [], "source": [ "spam_df = spam_df.dropna()" ] }, { "cell_type": "code", "execution_count": 30, "id": "a0db16f0-5cb1-4394-becc-e5ae9904c02f", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "outputs": [], "source": [ "all_features = spam_df.drop(columns=\"label\")\n", "features_train, features_test, y_train, y_test = train_test_split(all_features , spam_df[\"label\"], test_size=0.2, random_state=2023, stratify=spam_df[\"label\"])" ] }, { "cell_type": "code", "execution_count": 31, "id": "fecd00c5-87c6-4538-9b1b-63bbfb15c151", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(4444, 3) (1111, 3)\n" ] } ], "source": [ "print(features_train.shape, features_test.shape)" ] }, { "cell_type": "markdown", "id": 
"2cc6c573-7cdb-4458-946d-82c9afcde48e", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "### Logistic regression classifier on top of these sentence embeddings" ] }, { "cell_type": "code", "execution_count": 32, "id": "0f077d0c-f175-4999-b1c8-34f0766b5b7e", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "LogisticRegressionCV(cv=5, max_iter=1000)" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "logreg_model = LogisticRegressionCV(Cs=10, cv=5, penalty='l2', max_iter=1000)\n", "sent_emb_train = np.stack(features_train[\"sent_emb\"]) # shape: train_size x 300\n", "logreg_model.fit(sent_emb_train, y_train) # 5-fold cross-validation over 10 values of C, then refit on the full training set" ] }, { "cell_type": "code", "execution_count": 33, "id": "4afb2d49-3d55-48fc-a955-d7f3b22d6394", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy of the model 0.9567956795679567\n" ] } ], "source": [ "sent_emb_test = np.stack(features_test[\"sent_emb\"])\n", "print(f\"Accuracy of the model {logreg_model.score(sent_emb_test, y_test)}\")" ] }, { "cell_type": "code", "execution_count": 34, "id": "4f12baaa-a080-4cab-a6ce-d1bc0e92681c", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " ham 0.97 0.98 0.98 962\n", " spam 0.87 0.79 0.83 149\n", "\n", " accuracy 0.96 1111\n", " macro avg 0.92 0.89 0.90 1111\n", "weighted avg 0.96 0.96 0.96 1111\n", "\n" ] } ], "source": [ "print(classification_report(y_test, logreg_model.predict(sent_emb_test), target_names=label_encoder.classes_))" ] }, { "cell_type": "markdown", "id": "e7f76cc0-7d97-44cc-9101-0f36191f4e49", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "### Sneak-peek: using gensim embeddings in PyTorch" ] }, { "cell_type": "code", "execution_count": 35, "id": "e4d5b376-17dd-49a8-b598-68dbe320aacd", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "tensor([[-7.6660e-02, 1.1035e-01, 3.5352e-01, -7.9102e-02, -5.0049e-02,\n", " -2.9688e-01, 1.0938e-01, -3.5938e-01, -8.7402e-02, -7.0312e-02,\n", " 2.0801e-01, -2.4512e-01, -5.5664e-02, 2.4219e-01, 
2.3560e-02,\n", " -8.6670e-03, 2.6855e-02, 4.0234e-01, 1.7480e-01, -1.6602e-02,\n", " -2.0410e-01, 5.0000e-01, -8.9844e-02, -1.4355e-01, 5.5420e-02,\n", " 9.0820e-02, 1.1426e-01, 1.5430e-01, 1.3477e-01, 2.2656e-01,\n", " 4.4189e-02, 3.7109e-02, -1.1621e-01, -1.1328e-01, -1.6479e-02,\n", " -1.2695e-01, 2.9883e-01, -1.2598e-01, 1.0303e-01, 3.1641e-01,\n", " 1.2665e-03, -8.8379e-02, 1.2695e-01, 1.5820e-01, -7.1777e-02,\n", " -2.1094e-01, 3.1641e-01, -2.0801e-01, 9.6893e-04, 3.2422e-01,\n", " 7.1289e-02, 7.1289e-02, -9.2285e-02, -2.2705e-02, 9.5703e-02,\n", " -2.9883e-01, -6.5918e-02, -7.0801e-02, -7.4219e-02, -2.3535e-01,\n", " -2.8320e-01, -2.0215e-01, -1.6211e-01, 2.2070e-01, -2.3682e-02,\n", " -1.0645e-01, 1.9653e-02, -5.9082e-02, 2.3730e-01, 2.9785e-02,\n", " 2.8711e-01, 1.0889e-01, -1.9141e-01, -5.2246e-02, 1.4746e-01,\n", " 2.3730e-01, 4.1016e-02, -5.9326e-02, 2.0508e-01, -1.5430e-01,\n", " -2.2461e-01, 3.8330e-02, -8.1055e-02, 2.1387e-01, -1.7944e-02,\n", " -1.9043e-01, -1.2402e-01, 2.7539e-01, 1.0681e-02, -6.2500e-02,\n", " 8.5449e-02, -3.0859e-01, 3.6621e-02, 6.9336e-02, 1.3184e-01,\n", " 7.4219e-02, 1.6016e-01, -5.9326e-02, 1.2634e-02, 2.6758e-01,\n", " -3.5706e-03, 1.0303e-01, 2.9297e-01, 1.1816e-01, 2.0117e-01,\n", " 4.8047e-01, 2.1680e-01, -3.2422e-01, -1.7188e-01, -1.6016e-01,\n", " 1.4746e-01, -6.4392e-03, 2.0264e-02, -1.0791e-01, 9.9121e-02,\n", " 2.7734e-01, -4.3945e-02, 2.3438e-01, -2.9175e-02, -1.3574e-01,\n", " -5.1514e-02, -3.0664e-01, 2.4048e-02, -5.4688e-02, -1.9727e-01,\n", " -1.2695e-01, -3.4375e-01, -1.0742e-02, -1.7285e-01, 4.0771e-02,\n", " 2.7588e-02, -3.2227e-01, -1.4258e-01, -1.6504e-01, -3.8452e-03,\n", " -1.0889e-01, -2.4707e-01, 1.5430e-01, -1.1475e-02, -3.2501e-03,\n", " -2.2339e-02, -1.1035e-01, 3.4961e-01, -5.8105e-02, 1.7773e-01,\n", " 3.1055e-01, 1.8066e-01, -6.1340e-03, 9.1553e-03, -3.2812e-01,\n", " 1.7480e-01, 2.2266e-01, -2.7344e-01, -1.6602e-01, -2.8711e-01,\n", " 2.5977e-01, -5.8350e-02, -3.7354e-02, -9.0332e-02, 
-8.3008e-02,\n", " -6.2500e-02, 4.1260e-02, 8.4839e-03, -5.8105e-02, 1.0681e-02,\n", " -4.8218e-03, 1.1182e-01, -3.0078e-01, -9.2773e-02, -6.4453e-02,\n", " -9.5215e-02, 2.8320e-01, -1.4453e-01, -1.6406e-01, 1.3574e-01,\n", " -4.3945e-02, 1.7188e-01, -2.1606e-02, -8.1055e-02, 8.8867e-02,\n", " 1.3281e-01, -1.5137e-02, -2.3242e-01, 1.8555e-02, 4.3945e-02,\n", " 3.7500e-01, 7.6172e-02, -9.7656e-02, 1.0681e-03, 1.3672e-01,\n", " 2.5000e-01, -6.1768e-02, 6.2012e-02, 1.6309e-01, -1.9434e-01,\n", " 7.9590e-02, 2.2461e-01, 1.8555e-01, -2.5195e-01, -9.5215e-02,\n", " -2.0508e-01, -1.3184e-01, -5.0293e-02, -3.0664e-01, 5.5176e-02,\n", " -5.7812e-01, 4.9561e-02, -6.7383e-02, -2.1777e-01, -2.1851e-02,\n", " -1.3574e-01, 1.1182e-01, 6.7383e-02, 1.7871e-01, -2.3828e-01,\n", " 1.3184e-01, 9.0942e-03, -2.5024e-03, -3.2812e-01, -7.0312e-02,\n", " -1.1719e-01, -2.7734e-01, -4.1809e-03, -1.0352e-01, -1.6968e-02,\n", " 1.4648e-01, -2.2705e-02, 2.4292e-02, 1.2500e-01, -2.3438e-02,\n", " 9.7656e-02, 2.2827e-02, -1.1572e-01, 5.8350e-02, 1.7969e-01,\n", " -7.9346e-03, 1.2817e-02, -2.2363e-01, 1.0254e-01, -2.2363e-01,\n", " 5.2979e-02, -2.3926e-01, 1.5137e-01, 7.9102e-02, -4.7363e-02,\n", " 7.1411e-03, -2.7734e-01, -1.0938e-01, -4.1211e-01, 5.7373e-02,\n", " 2.0117e-01, 8.4961e-02, 1.5918e-01, 2.2949e-01, 2.5391e-01,\n", " 1.3867e-01, 2.7100e-02, -2.1875e-01, -2.3828e-01, 2.3535e-01,\n", " 5.8594e-02, -1.3770e-01, 1.1670e-01, 1.7188e-01, -5.0049e-03,\n", " 2.7344e-01, -2.9492e-01, 1.6406e-01, 1.0071e-02, 1.5039e-01,\n", " -3.3398e-01, 1.3965e-01, -1.5234e-01, -1.3672e-01, 1.8652e-01,\n", " -3.1836e-01, -4.3213e-02, -1.2207e-01, 1.9824e-01, -1.1328e-01,\n", " -9.2285e-02, 5.9814e-02, 9.0332e-02, 9.8267e-03, -1.2793e-01,\n", " -2.3828e-01, -1.7188e-01, 1.3281e-01, 1.2158e-01, 1.4160e-01,\n", " -1.3281e-01, -1.1621e-01, -2.2949e-02, -1.1670e-01, 2.1289e-01,\n", " 3.7891e-01, 2.5391e-01, 6.2500e-02, -1.8359e-01, 2.6562e-01],\n", " [-2.8906e-01, -8.6426e-02, 2.3145e-01, 3.6377e-02, 
4.7363e-02,\n", " -5.6885e-02, 9.4727e-02, -9.2773e-02, -1.7944e-02, 1.8848e-01,\n", " 4.0039e-02, -2.6245e-02, 2.4902e-01, 7.2754e-02, 2.5513e-02,\n", " 3.8818e-02, 1.3770e-01, 4.2383e-01, 2.6953e-01, 1.0437e-02,\n", " -1.3379e-01, 4.7461e-01, -2.5000e-01, -1.0547e-01, -1.0156e-01,\n", " -1.9238e-01, -8.6914e-02, 2.5586e-01, 1.0840e-01, 4.8096e-02,\n", " -6.1035e-03, 2.8801e-04, -1.6406e-01, -2.9907e-02, -2.3071e-02,\n", " -2.3926e-01, 1.5918e-01, -1.0986e-01, 1.3184e-01, 9.9609e-02,\n", " 1.4551e-01, -9.4727e-02, 1.1035e-01, 1.2305e-01, 5.1025e-02,\n", " -3.4766e-01, 2.8320e-01, -1.8750e-01, 7.5195e-02, 4.5703e-01,\n", " 2.2754e-01, 2.1387e-01, -2.8125e-01, -1.3965e-01, 8.4473e-02,\n", " -2.5391e-01, -1.6602e-01, -2.5781e-01, -2.4414e-02, -2.2070e-01,\n", " -1.6016e-01, 5.3955e-02, -2.4414e-01, -1.8188e-02, 1.6724e-02,\n", " 2.9297e-01, 3.4570e-01, 2.0996e-01, 2.2266e-01, -2.6367e-02,\n", " 3.5547e-01, 1.8677e-02, 7.7148e-02, 2.7710e-02, 1.2878e-02,\n", " 1.7090e-01, -1.3086e-01, -2.5391e-01, -3.0078e-01, -1.0693e-01,\n", " -3.3984e-01, -3.5400e-02, 1.6309e-01, 5.0049e-02, -1.7578e-01,\n", " 1.3867e-01, -2.9297e-01, 2.1289e-01, 9.1309e-02, -1.7578e-01,\n", " -1.1572e-01, -6.1035e-02, -4.5654e-02, 1.6699e-01, -3.5400e-02,\n", " 7.8125e-02, 3.4570e-01, -1.2109e-01, -3.4766e-01, -5.2490e-02,\n", " -2.7148e-01, 3.8281e-01, 2.3828e-01, 2.3438e-02, 1.4551e-01,\n", " 4.5508e-01, 2.9297e-01, -4.0771e-02, -2.3535e-01, 2.1240e-02,\n", " -3.6133e-02, -1.0645e-01, -1.6724e-02, 1.3574e-01, 9.7656e-02,\n", " 3.8867e-01, 1.2793e-01, -4.9805e-02, -5.3955e-02, -2.1094e-01,\n", " -2.2949e-01, -2.1118e-02, 1.9141e-01, -2.5391e-02, 4.6875e-02,\n", " -1.6211e-01, -1.1865e-01, 7.4158e-03, 1.9287e-02, 2.2095e-02,\n", " 9.3262e-02, -1.7969e-01, -3.0664e-01, 2.0312e-01, -1.8555e-02,\n", " -2.1289e-01, 5.3406e-03, 1.7969e-01, -3.7109e-01, 7.1289e-02,\n", " -4.2480e-02, -2.7148e-01, 2.0605e-01, -2.9492e-01, -1.1230e-02,\n", " 5.3906e-01, 2.5195e-01, -1.7773e-01, -9.4727e-02, 
-3.5352e-01,\n", " 1.3477e-01, 2.1484e-01, -4.5117e-01, -1.1572e-01, -2.2168e-01,\n", " 9.1309e-02, 6.2500e-02, -7.6172e-02, -6.4453e-02, -3.3398e-01,\n", " -2.1777e-01, 2.2949e-02, 8.9844e-02, -2.6245e-02, 1.9653e-02,\n", " -2.5586e-01, 1.9727e-01, -1.3281e-01, -4.3457e-02, -8.4961e-02,\n", " 5.2490e-03, 2.1777e-01, -4.1260e-02, 5.4443e-02, 9.9609e-02,\n", " -1.8652e-01, 3.5938e-01, 1.9727e-01, 8.9111e-03, -1.6602e-01,\n", " 7.3242e-02, 2.3926e-01, -4.3359e-01, -4.3701e-02, 2.1191e-01,\n", " 1.7773e-01, 1.0596e-01, 1.7188e-01, 1.8945e-01, 3.6377e-02,\n", " 1.3867e-01, 9.3994e-03, 1.2988e-01, 1.8359e-01, -1.3672e-01,\n", " 2.4316e-01, 3.1250e-02, -5.5420e-02, 1.4746e-01, -1.4160e-01,\n", " -3.9258e-01, -1.8066e-02, -2.0898e-01, -3.0469e-01, -2.6953e-01,\n", " -4.1016e-01, 3.6316e-03, 5.0781e-02, -2.1191e-01, -3.5889e-02,\n", " 1.8555e-02, 2.5000e-01, 5.5908e-02, 1.1780e-02, 1.3281e-01,\n", " -3.3875e-03, -1.0303e-01, 9.7656e-02, 2.8534e-03, -1.5430e-01,\n", " -2.2461e-01, -3.2422e-01, -2.8198e-02, -5.1758e-02, -1.2256e-01,\n", " -2.4170e-02, 1.4453e-01, -1.2354e-01, 6.5918e-02, 2.2339e-02,\n", " 1.1182e-01, -4.9805e-02, 2.0996e-02, -1.4648e-01, 6.8848e-02,\n", " -2.7832e-02, 1.3574e-01, -1.5820e-01, 1.9727e-01, -7.9956e-03,\n", " 5.8594e-02, -2.3145e-01, 4.9805e-01, 1.9897e-02, 8.2520e-02,\n", " -3.0151e-02, -2.8320e-01, -1.6797e-01, -1.2402e-01, 6.3965e-02,\n", " 2.4902e-01, 1.5234e-01, 5.4688e-02, 2.0020e-01, 1.5918e-01,\n", " 3.0078e-01, 2.2559e-01, -3.5645e-02, -2.6758e-01, 2.8320e-01,\n", " 3.3203e-01, 1.8799e-02, 4.2236e-02, -1.4160e-01, -8.5449e-02,\n", " 4.1992e-01, -1.1475e-01, 4.6143e-02, 8.3984e-02, 1.4453e-01,\n", " -7.3730e-02, 3.7891e-01, -1.8555e-01, 3.0151e-02, 1.7090e-01,\n", " -5.3223e-02, 1.2793e-01, -2.4414e-01, 2.1680e-01, -8.8501e-03,\n", " 5.8594e-02, 1.8945e-01, 2.2754e-01, -1.6699e-01, -3.7354e-02,\n", " -3.3447e-02, -3.6523e-01, 1.2891e-01, 5.6458e-04, 4.1016e-01,\n", " -2.3242e-01, 1.1816e-01, -6.7871e-02, -2.2656e-01, 2.6562e-01,\n", " 
1.4941e-01, 6.5918e-02, 1.3965e-01, -1.8066e-01, 8.1543e-02],\n", " [-9.7656e-02, 3.1982e-02, 2.5781e-01, -4.1504e-02, 1.0156e-01,\n", " -1.0059e-01, 1.4648e-01, -1.9922e-01, 1.5332e-01, 6.3477e-02,\n", " 8.3984e-02, -3.0078e-01, 6.3477e-02, 2.0898e-01, -2.1191e-01,\n", " 1.8848e-01, -8.3496e-02, 3.2812e-01, 2.7930e-01, -1.4062e-01,\n", " -1.6895e-01, 2.0410e-01, 4.9072e-02, -6.9885e-03, 9.4238e-02,\n", " 9.8419e-04, 3.1250e-02, 2.4805e-01, 3.3594e-01, 2.6367e-01,\n", " 5.6885e-02, 3.0469e-01, 1.2158e-01, -1.9727e-01, 1.7212e-02,\n", " 9.9609e-02, 2.2754e-01, -1.2061e-01, 1.2354e-01, 3.7891e-01,\n", " 2.3682e-02, -1.8652e-01, 6.2988e-02, 1.5234e-01, 3.7354e-02,\n", " -1.6992e-01, 1.0645e-01, -4.9805e-02, -6.2012e-02, 1.6895e-01,\n", " 4.4189e-02, 2.7832e-02, -1.1084e-01, 4.4922e-02, 2.7832e-02,\n", " -4.4531e-01, 3.4912e-02, -6.2256e-02, -3.9307e-02, -2.0117e-01,\n", " -3.0469e-01, -1.0059e-01, -1.6406e-01, 1.5234e-01, 1.1035e-01,\n", " -1.5332e-01, -7.1289e-02, 7.9590e-02, 1.8750e-01, 6.8848e-02,\n", " 2.4414e-01, -6.5613e-04, -1.9141e-01, 3.4912e-02, 1.9775e-02,\n", " -5.4199e-02, 3.3203e-02, -2.0801e-01, 9.8633e-02, -1.9043e-01,\n", " -6.0791e-02, -2.0703e-01, -2.1851e-02, 8.7891e-02, 2.0898e-01,\n", " -2.3633e-01, -9.1797e-02, 2.2656e-01, -3.9307e-02, 9.1309e-02,\n", " 1.1353e-02, -1.5527e-01, 6.7871e-02, -4.9072e-02, 2.5177e-03,\n", " -4.6631e-02, 9.1797e-02, 1.0596e-01, 2.1094e-01, 4.2480e-02,\n", " -4.9561e-02, 1.7676e-01, 3.3203e-01, -4.2236e-02, 2.0312e-01,\n", " 2.9883e-01, 1.2109e-01, -4.8584e-02, -1.4160e-01, -2.5195e-01,\n", " -2.2070e-01, 2.2363e-01, 2.2217e-02, 1.0938e-01, 3.1445e-01,\n", " 3.7109e-01, -4.8340e-02, 2.7734e-01, 1.2756e-02, -9.6191e-02,\n", " -2.0312e-01, -1.5527e-01, 1.1035e-01, -6.5430e-02, -2.7539e-01,\n", " -2.3438e-01, -3.7891e-01, -8.3008e-02, -1.2500e-01, -1.4062e-01,\n", " -9.2773e-03, -4.2188e-01, -1.1719e-01, -2.4121e-01, 3.4424e-02,\n", " 5.3406e-03, -2.9883e-01, 3.4570e-01, -6.5430e-02, 5.5420e-02,\n", " -5.1758e-02, 
-1.1279e-01, 2.0117e-01, 2.9785e-02, 1.0547e-01,\n", " 2.7539e-01, 1.0205e-01, -1.6699e-01, -7.0801e-02, -2.9688e-01,\n", " 1.5039e-01, 2.9492e-01, -1.9727e-01, 4.8584e-02, -4.2773e-01,\n", " 1.5564e-02, 1.2061e-01, -1.7090e-01, 6.6895e-02, -1.5039e-01,\n", " -5.2734e-02, 6.4941e-02, 3.7842e-02, 1.7456e-02, 1.4453e-01,\n", " -1.4746e-01, 1.0742e-01, -1.9629e-01, -5.2002e-02, -6.2988e-02,\n", " -1.7188e-01, 2.2583e-02, -5.2795e-03, -1.2256e-01, -6.2500e-02,\n", " 8.8501e-04, 1.8433e-02, 1.1279e-01, -1.1133e-01, 1.0073e-05,\n", " 2.4121e-01, 6.6406e-02, -2.4536e-02, 1.8921e-02, -1.8311e-02,\n", " 3.0664e-01, 1.8311e-02, -3.8574e-02, 1.8164e-01, 2.2736e-03,\n", " 1.4551e-01, 6.7383e-02, -1.6846e-02, 1.8750e-01, -1.2109e-01,\n", " 2.1729e-02, 1.2988e-01, 1.5527e-01, -7.8125e-02, 5.0537e-02,\n", " -7.6660e-02, -1.1816e-01, 4.5166e-02, -3.7109e-02, 2.4805e-01,\n", " -3.1055e-01, 1.9824e-01, -2.3730e-01, -2.1851e-02, -9.7656e-02,\n", " -1.0803e-02, 1.6992e-01, 1.1377e-01, -5.1758e-02, -1.9434e-01,\n", " 9.4727e-02, 1.2891e-01, 2.0508e-02, -2.7734e-01, -8.6914e-02,\n", " 1.2634e-02, -4.5703e-01, -1.3379e-01, -8.1543e-02, 2.8906e-01,\n", " 1.8945e-01, 1.6211e-01, -1.6479e-02, 3.7537e-03, -1.3867e-01,\n", " -1.0498e-01, -1.2012e-01, -1.1353e-02, 5.1514e-02, 4.6875e-02,\n", " -8.9355e-02, 2.5635e-02, -2.7734e-01, 1.0547e-01, -1.0303e-01,\n", " 2.0703e-01, -1.6797e-01, 1.9922e-01, 1.6724e-02, -1.8066e-01,\n", " 1.0645e-01, -3.3984e-01, -5.8105e-02, -3.3594e-01, 4.3945e-02,\n", " 8.5938e-02, 4.4434e-02, 2.7466e-03, 1.7090e-01, 7.4219e-02,\n", " -3.0640e-02, -7.8613e-02, 5.0354e-03, -1.1670e-01, 3.1836e-01,\n", " 2.6367e-01, -9.4727e-02, 1.2158e-01, 1.5234e-01, -2.3340e-01,\n", " 3.8477e-01, -1.1328e-01, 4.1504e-02, -2.2070e-01, 1.0559e-02,\n", " -3.4180e-02, 7.5684e-02, -2.3633e-01, -6.9824e-02, 3.0396e-02,\n", " -2.0020e-01, 1.7480e-01, -1.5723e-01, -3.9551e-02, -2.6953e-01,\n", " -1.1182e-01, 1.0498e-01, 1.1475e-01, -2.4796e-04, -1.4258e-01,\n", " -1.6309e-01, -2.6367e-01, 
 1.4453e-01, 1.6309e-01, 2.1973e-02,\n", " -2.0605e-01, -3.1738e-02, -1.5625e-01, -1.0938e-01, 4.5703e-01,\n", " 2.7148e-01, 6.1279e-02, 2.0142e-03, -1.2158e-01, 1.5820e-01]])" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# from_pretrained is a classmethod: call it on nn.Embedding directly (weights are frozen by default)\n", "embs = nn.Embedding.from_pretrained(torch.from_numpy(word_emb.vectors))\n", "# map each word to its row index in the gensim vocabulary\n", "idx = torch.tensor([word_emb.key_to_index[word] for word in [\"soccer\", \"tennis\", \"football\"]], dtype=torch.long)\n", "embs(idx)" ] }, { "cell_type": "markdown", "id": "7966f74b-f1a0-4def-9519-4b833f68eda5", "metadata": { "slideshow": { "slide_type": "slide" }, "tags": [] }, "source": [ "### It's your turn! Go ahead with *Task 2*.\n", "\n", "" ] }, { "cell_type": "markdown", "id": "b47ccc12-03e7-4eb6-95cf-b20f445b18fe", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "#### Task 2.1\n", "\n", "- Plot the 3D PCA representation of the sentence embeddings, using a different color for each label (ham and spam)\n", " - Are the two classes neatly separated?" ] }, { "cell_type": "markdown", "id": "3a1117ab-7c06-475a-a341-1d389b80f30d", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "#### Task 2.2\n", "\n", "1. 
Implement a new classifier to solve this task.\n", " - You can use:\n", " - a ```scikit-learn``` estimator, for example: [KNeighborsClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html), [SVC](https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html#sklearn.svm.SVC), [DecisionTreeClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html#sklearn.tree.DecisionTreeClassifier)\n", " - **only** if you are confident in your PyTorch skills, you can implement either:\n", " - a feed-forward network that processes the sentence embeddings\n", " - an RNN of your choice (LSTM, GRU, BiLSTM) that processes one word per time step\n", " - **Requirements**:\n", " - use the same train/test split as above (```sent_emb_train``` / ```sent_emb_test```)\n", "2. Evaluate your model on the test set using ```classification_report```" ] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 5 }