{"cells":[{"cell_type":"raw","metadata":{"id":"sL1vUhyl6mX-"},"source":["# Natural Language Processing Tutorial 4 - Summarization\n","Francesco Cazzaro, Universitat Politècnica de Catalunya"]},{"cell_type":"markdown","metadata":{"id":"FfgUlUHYfjeN"},"source":["##Introduction\n","\n","In this notebook, we will explore how to train and test a transformer-based model for automatic summarization using the powerful Hugging Face libraries.\n","\n","Text summarization is a challenging task in the field of Natural Language Processing, aiming to condense lengthy pieces of text into shorter summaries while preserving the most important information. It finds numerous applications in areas such as news summarization, document summarization, and information retrieval."]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"vBdAEm6jKdc-","outputId":"4f83ca5e-2b7d-4679-a78b-1a70c981d1ae"},"outputs":[{"name":"stdout","output_type":"stream","text":["Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n","Collecting datasets\n"," Downloading datasets-2.12.0-py3-none-any.whl (474 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m474.6/474.6 kB\u001b[0m \u001b[31m5.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.22.4)\n","Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (9.0.0)\n","Collecting dill<0.3.7,>=0.3.0 (from datasets)\n"," Downloading dill-0.3.6-py3-none-any.whl (110 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m110.5/110.5 kB\u001b[0m \u001b[31m8.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (1.5.3)\n","Requirement 
already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.27.1)\n","Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.65.0)\n","Collecting xxhash (from datasets)\n"," Downloading xxhash-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.5/212.5 kB\u001b[0m \u001b[31m16.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting multiprocess (from datasets)\n"," Downloading multiprocess-0.70.14-py310-none-any.whl (134 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.3/134.3 kB\u001b[0m \u001b[31m12.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (2023.4.0)\n","Collecting aiohttp (from datasets)\n"," Downloading aiohttp-3.8.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.0 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m18.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting huggingface-hub<1.0.0,>=0.11.0 (from datasets)\n"," Downloading huggingface_hub-0.14.1-py3-none-any.whl (224 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m224.5/224.5 kB\u001b[0m \u001b[31m12.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (23.1)\n","Collecting responses<0.19 (from datasets)\n"," Downloading responses-0.18.0-py3-none-any.whl (38 kB)\n","Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0)\n","Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from 
aiohttp->datasets) (23.1.0)\n","Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.0.12)\n","Collecting multidict<7.0,>=4.5 (from aiohttp->datasets)\n"," Downloading multidict-6.0.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (114 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m114.5/114.5 kB\u001b[0m \u001b[31m2.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting async-timeout<5.0,>=4.0.0a3 (from aiohttp->datasets)\n"," Downloading async_timeout-4.0.2-py3-none-any.whl (5.8 kB)\n","Collecting yarl<2.0,>=1.0 (from aiohttp->datasets)\n"," Downloading yarl-1.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (268 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m268.8/268.8 kB\u001b[0m \u001b[31m17.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting frozenlist>=1.1.1 (from aiohttp->datasets)\n"," Downloading frozenlist-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (149 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m149.6/149.6 kB\u001b[0m \u001b[31m10.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting aiosignal>=1.1.2 (from aiohttp->datasets)\n"," Downloading aiosignal-1.3.1-py3-none-any.whl (7.6 kB)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.11.0->datasets) (3.12.0)\n","Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.11.0->datasets) (4.5.0)\n","Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (1.26.15)\n","Requirement already satisfied: certifi>=2017.4.17 in 
/usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2022.12.7)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (3.4)\n","Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n","Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2022.7.1)\n","Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n","Installing collected packages: xxhash, multidict, frozenlist, dill, async-timeout, yarl, responses, multiprocess, huggingface-hub, aiosignal, aiohttp, datasets\n","Successfully installed aiohttp-3.8.4 aiosignal-1.3.1 async-timeout-4.0.2 datasets-2.12.0 dill-0.3.6 frozenlist-1.3.3 huggingface-hub-0.14.1 multidict-6.0.4 multiprocess-0.70.14 responses-0.18.0 xxhash-3.2.0 yarl-1.9.2\n","Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n","Collecting transformers==4.28.0\n"," Downloading transformers-4.28.0-py3-none-any.whl (7.0 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.0/7.0 MB\u001b[0m \u001b[31m24.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers==4.28.0) (3.12.0)\n","Requirement already satisfied: huggingface-hub<1.0,>=0.11.0 in /usr/local/lib/python3.10/dist-packages (from transformers==4.28.0) (0.14.1)\n","Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.28.0) (1.22.4)\n","Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers==4.28.0) (23.1)\n","Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages 
(from transformers==4.28.0) (6.0)\n","Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.28.0) (2022.10.31)\n","Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers==4.28.0) (2.27.1)\n","Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers==4.28.0)\n"," Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m122.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers==4.28.0) (4.65.0)\n","Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.11.0->transformers==4.28.0) (2023.4.0)\n","Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.11.0->transformers==4.28.0) (4.5.0)\n","Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.28.0) (1.26.15)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.28.0) (2022.12.7)\n","Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.28.0) (2.0.12)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.28.0) (3.4)\n","Installing collected packages: tokenizers, transformers\n","Successfully installed tokenizers-0.13.3 transformers-4.28.0\n","Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n","Collecting py7zr\n"," Downloading py7zr-0.20.5-py3-none-any.whl (66 
kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m66.4/66.4 kB\u001b[0m \u001b[31m3.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting texttable (from py7zr)\n"," Downloading texttable-1.6.7-py2.py3-none-any.whl (10 kB)\n","Collecting pycryptodomex>=3.6.6 (from py7zr)\n"," Downloading pycryptodomex-3.17-cp35-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.1 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.1/2.1 MB\u001b[0m \u001b[31m39.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting pyzstd>=0.14.4 (from py7zr)\n"," Downloading pyzstd-0.15.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (399 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m399.3/399.3 kB\u001b[0m \u001b[31m44.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting pyppmd<1.1.0,>=0.18.1 (from py7zr)\n"," Downloading pyppmd-1.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (138 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m138.8/138.8 kB\u001b[0m \u001b[31m18.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting pybcj>=0.6.0 (from py7zr)\n"," Downloading pybcj-1.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (49 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m49.8/49.8 kB\u001b[0m \u001b[31m292.6 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting multivolumefile>=0.2.3 (from py7zr)\n"," Downloading multivolumefile-0.2.3-py3-none-any.whl (17 kB)\n","Collecting brotli>=1.0.9 (from py7zr)\n"," Downloading Brotli-1.0.9-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (2.7 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.7/2.7 MB\u001b[0m \u001b[31m97.1 MB/s\u001b[0m eta 
\u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting inflate64>=0.3.1 (from py7zr)\n"," Downloading inflate64-0.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (93 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m93.1/93.1 kB\u001b[0m \u001b[31m10.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from py7zr) (5.9.5)\n","Installing collected packages: texttable, brotli, pyzstd, pyppmd, pycryptodomex, pybcj, multivolumefile, inflate64, py7zr\n","Successfully installed brotli-1.0.9 inflate64-0.3.1 multivolumefile-0.2.3 py7zr-0.20.5 pybcj-1.0.1 pycryptodomex-3.17 pyppmd-1.0.0 pyzstd-0.15.7 texttable-1.6.7\n","Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n","Collecting evaluate\n"," Downloading evaluate-0.4.0-py3-none-any.whl (81 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m81.4/81.4 kB\u001b[0m \u001b[31m3.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: datasets>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from evaluate) (2.12.0)\n","Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from evaluate) (1.22.4)\n","Requirement already satisfied: dill in /usr/local/lib/python3.10/dist-packages (from evaluate) (0.3.6)\n","Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from evaluate) (1.5.3)\n","Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from evaluate) (2.27.1)\n","Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from evaluate) (4.65.0)\n","Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from evaluate) (3.2.0)\n","Requirement already satisfied: multiprocess in 
/usr/local/lib/python3.10/dist-packages (from evaluate) (0.70.14)\n","Requirement already satisfied: fsspec[http]>=2021.05.0 in /usr/local/lib/python3.10/dist-packages (from evaluate) (2023.4.0)\n","Requirement already satisfied: huggingface-hub>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from evaluate) (0.14.1)\n","Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from evaluate) (23.1)\n","Requirement already satisfied: responses<0.19 in /usr/local/lib/python3.10/dist-packages (from evaluate) (0.18.0)\n","Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.0.0->evaluate) (9.0.0)\n","Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets>=2.0.0->evaluate) (3.8.4)\n","Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.0.0->evaluate) (6.0)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.7.0->evaluate) (3.12.0)\n","Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.7.0->evaluate) (4.5.0)\n","Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->evaluate) (1.26.15)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->evaluate) (2022.12.7)\n","Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->evaluate) (2.0.12)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->evaluate) (3.4)\n","Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->evaluate) (2.8.2)\n","Requirement already satisfied: pytz>=2020.1 in 
/usr/local/lib/python3.10/dist-packages (from pandas->evaluate) (2022.7.1)\n","Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (23.1.0)\n","Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (6.0.4)\n","Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (4.0.2)\n","Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.9.2)\n","Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.3.3)\n","Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.3.1)\n","Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->evaluate) (1.16.0)\n","Installing collected packages: evaluate\n","Successfully installed evaluate-0.4.0\n","Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n","Collecting rouge_score\n"," Downloading rouge_score-0.1.2.tar.gz (17 kB)\n"," Preparing metadata (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n","Requirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (from rouge_score) (1.4.0)\n","Requirement already satisfied: nltk in /usr/local/lib/python3.10/dist-packages (from rouge_score) (3.8.1)\n","Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from rouge_score) (1.22.4)\n","Requirement already satisfied: six>=1.14.0 in /usr/local/lib/python3.10/dist-packages (from rouge_score) (1.16.0)\n","Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from nltk->rouge_score) (8.1.3)\n","Requirement already satisfied: joblib in /usr/local/lib/python3.10/dist-packages (from nltk->rouge_score) (1.2.0)\n","Requirement already satisfied: regex>=2021.8.3 in /usr/local/lib/python3.10/dist-packages (from nltk->rouge_score) (2022.10.31)\n","Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from nltk->rouge_score) (4.65.0)\n","Building wheels for collected packages: rouge_score\n"," Building wheel for rouge_score (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n"," Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=9fb6bd99909f4ad4ad2dc35ced84aaf5798feae63fb8ca26bddeb7c1fc8e22cd\n"," Stored in directory: /root/.cache/pip/wheels/5f/dd/89/461065a73be61a532ff8599a28e9beef17985c9e9c31e541b4\n","Successfully built rouge_score\n","Installing collected packages: rouge_score\n","Successfully installed rouge_score-0.1.2\n","Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n","Collecting accelerate\n"," Downloading accelerate-0.19.0-py3-none-any.whl (219 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m219.1/219.1 kB\u001b[0m \u001b[31m6.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from accelerate) (1.22.4)\n","Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (23.1)\n","Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)\n","Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from accelerate) (6.0)\n","Requirement already satisfied: torch>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (2.0.0+cu118)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->accelerate) (3.12.0)\n","Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->accelerate) (4.5.0)\n","Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->accelerate) (1.11.1)\n","Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->accelerate) (3.1)\n","Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from 
torch>=1.6.0->accelerate) (3.1.2)\n","Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->accelerate) (2.0.0)\n","Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.6.0->accelerate) (3.25.2)\n","Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.6.0->accelerate) (16.0.3)\n","Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.6.0->accelerate) (2.1.2)\n","Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.6.0->accelerate) (1.3.0)\n","Installing collected packages: accelerate\n","Successfully installed accelerate-0.19.0\n"]}],"source":["# Versions are pinned to the ones recorded in this cell's saved output so that a fresh run is reproducible.\n","# `%pip` (instead of `!pip`) installs into the environment of the running kernel.\n","%pip install datasets==2.12.0\n","%pip install transformers==4.28.0\n","%pip install py7zr==0.20.5\n","%pip install evaluate==0.4.0\n","%pip install rouge_score==0.1.2\n","%pip install accelerate==0.19.0"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"fXpE3XkFMagl"},"outputs":[],"source":["from datasets import load_dataset\n","from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq\n","import evaluate\n","import numpy as np"]},{"cell_type":"markdown","metadata":{"id":"RKgsD2xlahjV"},"source":["## Model\n","\n","We use the T5 model which is an encoder-decoder model pre-trained on a multi-task mixture of unsupervised and supervised tasks and for which each task is converted into a text-to-text format. In particular, summarization was also included in the pre-training.\n","\n","It's highly recommended to check out the official page [here](https://huggingface.co/docs/transformers/model_doc/t5). 
In general the documentation pages are a great resource that contain detailed explanations and tons of useful information on how to use the models."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"i64bnHhAajVU"},"outputs":[],"source":["MODEL_NAME = \"t5-small\"\n","\n","tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n","model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)"]},{"cell_type":"markdown","metadata":{"id":"Q3tR3evrfvyt"},"source":["Let's take a look at the model definition."]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"BP7wLxzQfxcL","outputId":"eb720c6d-d692-47ae-fe8c-87f97be034cf"},"outputs":[{"data":{"text/plain":["T5ForConditionalGeneration(\n"," (shared): Embedding(32128, 512)\n"," (encoder): T5Stack(\n"," (embed_tokens): Embedding(32128, 512)\n"," (block): ModuleList(\n"," (0): T5Block(\n"," (layer): ModuleList(\n"," (0): T5LayerSelfAttention(\n"," (SelfAttention): T5Attention(\n"," (q): Linear(in_features=512, out_features=512, bias=False)\n"," (k): Linear(in_features=512, out_features=512, bias=False)\n"," (v): Linear(in_features=512, out_features=512, bias=False)\n"," (o): Linear(in_features=512, out_features=512, bias=False)\n"," (relative_attention_bias): Embedding(32, 8)\n"," )\n"," (layer_norm): T5LayerNorm()\n"," (dropout): Dropout(p=0.1, inplace=False)\n"," )\n"," (1): T5LayerFF(\n"," (DenseReluDense): T5DenseActDense(\n"," (wi): Linear(in_features=512, out_features=2048, bias=False)\n"," (wo): Linear(in_features=2048, out_features=512, bias=False)\n"," (dropout): Dropout(p=0.1, inplace=False)\n"," (act): ReLU()\n"," )\n"," (layer_norm): T5LayerNorm()\n"," (dropout): Dropout(p=0.1, inplace=False)\n"," )\n"," )\n"," )\n"," (1-5): 5 x T5Block(\n"," (layer): ModuleList(\n"," (0): T5LayerSelfAttention(\n"," (SelfAttention): T5Attention(\n"," (q): Linear(in_features=512, out_features=512, bias=False)\n"," (k): Linear(in_features=512, out_features=512, 
bias=False)\n"," (v): Linear(in_features=512, out_features=512, bias=False)\n"," (o): Linear(in_features=512, out_features=512, bias=False)\n"," )\n"," (layer_norm): T5LayerNorm()\n"," (dropout): Dropout(p=0.1, inplace=False)\n"," )\n"," (1): T5LayerFF(\n"," (DenseReluDense): T5DenseActDense(\n"," (wi): Linear(in_features=512, out_features=2048, bias=False)\n"," (wo): Linear(in_features=2048, out_features=512, bias=False)\n"," (dropout): Dropout(p=0.1, inplace=False)\n"," (act): ReLU()\n"," )\n"," (layer_norm): T5LayerNorm()\n"," (dropout): Dropout(p=0.1, inplace=False)\n"," )\n"," )\n"," )\n"," )\n"," (final_layer_norm): T5LayerNorm()\n"," (dropout): Dropout(p=0.1, inplace=False)\n"," )\n"," (decoder): T5Stack(\n"," (embed_tokens): Embedding(32128, 512)\n"," (block): ModuleList(\n"," (0): T5Block(\n"," (layer): ModuleList(\n"," (0): T5LayerSelfAttention(\n"," (SelfAttention): T5Attention(\n"," (q): Linear(in_features=512, out_features=512, bias=False)\n"," (k): Linear(in_features=512, out_features=512, bias=False)\n"," (v): Linear(in_features=512, out_features=512, bias=False)\n"," (o): Linear(in_features=512, out_features=512, bias=False)\n"," (relative_attention_bias): Embedding(32, 8)\n"," )\n"," (layer_norm): T5LayerNorm()\n"," (dropout): Dropout(p=0.1, inplace=False)\n"," )\n"," (1): T5LayerCrossAttention(\n"," (EncDecAttention): T5Attention(\n"," (q): Linear(in_features=512, out_features=512, bias=False)\n"," (k): Linear(in_features=512, out_features=512, bias=False)\n"," (v): Linear(in_features=512, out_features=512, bias=False)\n"," (o): Linear(in_features=512, out_features=512, bias=False)\n"," )\n"," (layer_norm): T5LayerNorm()\n"," (dropout): Dropout(p=0.1, inplace=False)\n"," )\n"," (2): T5LayerFF(\n"," (DenseReluDense): T5DenseActDense(\n"," (wi): Linear(in_features=512, out_features=2048, bias=False)\n"," (wo): Linear(in_features=2048, out_features=512, bias=False)\n"," (dropout): Dropout(p=0.1, inplace=False)\n"," (act): ReLU()\n"," )\n"," 
(layer_norm): T5LayerNorm()\n"," (dropout): Dropout(p=0.1, inplace=False)\n"," )\n"," )\n"," )\n"," (1-5): 5 x T5Block(\n"," (layer): ModuleList(\n"," (0): T5LayerSelfAttention(\n"," (SelfAttention): T5Attention(\n"," (q): Linear(in_features=512, out_features=512, bias=False)\n"," (k): Linear(in_features=512, out_features=512, bias=False)\n"," (v): Linear(in_features=512, out_features=512, bias=False)\n"," (o): Linear(in_features=512, out_features=512, bias=False)\n"," )\n"," (layer_norm): T5LayerNorm()\n"," (dropout): Dropout(p=0.1, inplace=False)\n"," )\n"," (1): T5LayerCrossAttention(\n"," (EncDecAttention): T5Attention(\n"," (q): Linear(in_features=512, out_features=512, bias=False)\n"," (k): Linear(in_features=512, out_features=512, bias=False)\n"," (v): Linear(in_features=512, out_features=512, bias=False)\n"," (o): Linear(in_features=512, out_features=512, bias=False)\n"," )\n"," (layer_norm): T5LayerNorm()\n"," (dropout): Dropout(p=0.1, inplace=False)\n"," )\n"," (2): T5LayerFF(\n"," (DenseReluDense): T5DenseActDense(\n"," (wi): Linear(in_features=512, out_features=2048, bias=False)\n"," (wo): Linear(in_features=2048, out_features=512, bias=False)\n"," (dropout): Dropout(p=0.1, inplace=False)\n"," (act): ReLU()\n"," )\n"," (layer_norm): T5LayerNorm()\n"," (dropout): Dropout(p=0.1, inplace=False)\n"," )\n"," )\n"," )\n"," )\n"," (final_layer_norm): T5LayerNorm()\n"," (dropout): Dropout(p=0.1, inplace=False)\n"," )\n"," (lm_head): Linear(in_features=512, out_features=32128, bias=False)\n",")"]},"execution_count":4,"metadata":{},"output_type":"execute_result"}],"source":["model"]},{"cell_type":"markdown","metadata":{"id":"Jthj-_IoKkST"},"source":["## Dataset\n","\n","For this notebook we use the [samsum](https://huggingface.co/datasets/samsum) dataset. 
It contains 16k messenger-like conversations with annotated summaries.\n","\n","\n","\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"sOQpqLVmQZzJ"},"outputs":[],"source":["# Fetch each predefined SAMSum split from the Hugging Face Hub.\n","train_data, val_data, test_data = (\n","    load_dataset(\"samsum\", split=split_name)\n","    for split_name in (\"train\", \"validation\", \"test\")\n",")"]},{"cell_type":"markdown","metadata":{"id":"5k8HqncMYOyG"},"source":["Let's take a look at an example."]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"avJ9ZJKjYYQP","outputId":"7d8c5afe-d5d2-4671-c179-c204462ad767"},"outputs":[{"name":"stdout","output_type":"stream","text":["{'id': '13818513', 'dialogue': \"Amanda: I baked cookies. Do you want some?\\r\\nJerry: Sure!\\r\\nAmanda: I'll bring you tomorrow :-)\", 'summary': 'Amanda baked cookies and will bring Jerry some tomorrow.'}\n"]}],"source":["print(train_data[0])"]},{"cell_type":"markdown","metadata":{"id":"4iSsniOzdN_A"},"source":["\n","Next we perform the necessary data preprocessing steps.\n","\n","First, as mentioned in the model documentation, we prepend the prefix 'summarize' to each input sample. This prefix acts as a prompt for the model, indicating that the task is text summarization. 
Then we use the tokenizer to get the input ids for both the text and the target summaries.\n","\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"ZwAkc4wJdZi-"},"outputs":[],"source":["PREFIX = 'summarize: '\n","\n","def preprocess_function(samples):\n","    \"\"\"Tokenize a batch of dialogues and their reference summaries.\n","\n","    Args:\n","        samples: batch with 'dialogue' and 'summary' lists of strings,\n","            as passed by `Dataset.map(..., batched=True)`.\n","\n","    Returns:\n","        The tokenized, PREFIX-prepended dialogues (truncated to 512 tokens)\n","        with the summary token ids (truncated to 128 tokens) under 'labels'.\n","    \"\"\"\n","    prompts = [PREFIX + text for text in samples['dialogue']]\n","    model_inputs = tokenizer(prompts, max_length=512, truncation=True)\n","\n","    # `text_target` tokenizes the summaries as targets; it replaces the\n","    # deprecated `tokenizer.as_target_tokenizer()` context manager.\n","    labels = tokenizer(text_target=samples['summary'], max_length=128, truncation=True)\n","    model_inputs['labels'] = labels['input_ids']\n","\n","    return model_inputs"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"PUF70VVVdOwh"},"outputs":[],"source":["train_data = train_data.map(preprocess_function, batched=True)\n","val_data = val_data.map(preprocess_function, batched=True)\n","test_data = test_data.map(preprocess_function, batched=True)"]},{"cell_type":"markdown","metadata":{"id":"K8It4ZnIkADr"},"source":["Then, we define our data collator that will automatically do dynamic padding for us."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"JZGqXVzDkCZc"},"outputs":[],"source":["# Pass the model instance (not the checkpoint name): the collator uses it to\n","# build `decoder_input_ids` from the labels when the model supports it.\n","data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)"]},{"cell_type":"markdown","metadata":{"id":"4DJcmniokzjv"},"source":["## Evaluation\n","\n","As evaluation metric we use ROUGE (Recall-Oriented Understudy for Gisting Evaluation). It measures the overlap between the generated summary and one or more reference summaries. The key idea behind ROUGE is to capture the recall of important information in the generated summary by comparing it with the reference summaries. 
[Here](https://medium.com/nlplanet/two-minutes-nlp-learn-the-rouge-metric-by-examples-f179cc285499) is a link with a brief explanation."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"eJ86jbqwk6IV"},"outputs":[],"source":["rouge = evaluate.load(\"rouge\")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"8Aw6KDWplJL_"},"outputs":[],"source":["def compute_metrics(eval_pred):\n","    \"\"\"Compute ROUGE scores and the mean generated length for an evaluation run.\n","\n","    Args:\n","        eval_pred: pair `(predictions, labels)` of token-id arrays.\n","\n","    Returns:\n","        dict with the scores returned by `rouge.compute` plus 'gen_len'\n","        (mean number of non-padding predicted tokens), rounded to 4 decimals.\n","    \"\"\"\n","    predictions, labels = eval_pred\n","    if isinstance(predictions, tuple):\n","        # Some models return extra outputs; the generated ids come first.\n","        predictions = predictions[0]\n","\n","    # -100 marks ignored/padded positions and cannot be decoded, so map it\n","    # back to the pad token in both predictions and labels.\n","    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)\n","    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n","\n","    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)\n","    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n","\n","    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)\n","\n","    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]\n","    result['gen_len'] = np.mean(prediction_lens)\n","\n","    return {k: round(v, 4) for k, v in result.items()}"]},{"cell_type":"markdown","metadata":{"id":"hOtp8PYUly0R"},"source":["## Training\n","\n","We are now ready to train our model. Let's set up the training parameters and call `trainer.train()`."]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":433},"id":"8rY5PAowl0f9","outputId":"3d95df82-38bf-4267-9ace-1d9b6ac36bd5","scrolled":true},"outputs":[{"name":"stderr","output_type":"stream","text":["/usr/local/lib/python3.10/dist-packages/transformers/optimization.py:391: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n"," warnings.warn(\n","You're using a T5TokenizerFast tokenizer. 
Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"]},{"data":{"text/html":["\n","
\n"," \n"," \n"," [3688/3688 1:00:18, Epoch 8/8]\n","
\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
EpochTraining LossValidation LossRouge1Rouge2RougelRougelsumGen Len
1No log1.7301380.4217000.1937000.3541000.35400016.803200
21.9818001.6749810.4409000.2100000.3688000.36850016.677300
31.7810001.6661990.4496000.2222000.3817000.38160016.469400
41.6751001.6544120.4511000.2236000.3798000.37980016.757900
51.6103001.6467680.4562000.2256000.3849000.38480016.636900
61.5438001.6470680.4591000.2279000.3856000.38560016.651600
71.5063001.6457390.4619000.2284000.3879000.38810016.749400
81.4763001.6483650.4609000.2275000.3877000.38810016.813000

"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/plain":["TrainOutput(global_step=3688, training_loss=1.6438676002485892, metrics={'train_runtime': 3622.4301, 'train_samples_per_second': 32.535, 'train_steps_per_second': 1.018, 'total_flos': 1.3623539571621888e+16, 'train_loss': 1.6438676002485892, 'epoch': 8.0})"]},"execution_count":12,"metadata":{},"output_type":"execute_result"}],"source":["training_args = Seq2SeqTrainingArguments(\n"," output_dir = 'models/',\n"," evaluation_strategy='epoch',\n"," learning_rate=3e-4,\n"," per_device_train_batch_size=32,\n"," per_device_eval_batch_size=32,\n"," weight_decay=0.01,\n"," num_train_epochs=8,\n"," predict_with_generate=True,\n"," metric_for_best_model='rouge1',\n"," load_best_model_at_end=True,\n"," save_strategy='epoch'\n",")\n","\n","trainer = Seq2SeqTrainer(\n"," model=model,\n"," args=training_args,\n"," train_dataset=train_data,\n"," eval_dataset=val_data,\n"," tokenizer=tokenizer,\n"," data_collator=data_collator,\n"," compute_metrics=compute_metrics,\n",")on.py:391: FutureWarning: This implementation of AdamW is deprecated and will be removed i\n","\n","trainer.train()"]},{"cell_type":"markdown","metadata":{"id":"r1tV69adqFPD"},"source":["Once the training is done let's compute the predictions on the test set and evaluate the performance"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":54},"id":"pJ0achClCEN2","outputId":"17e6a365-1e78-491f-a05b-1a53b7c74185","scrolled":true},"outputs":[{"data":{"text/html":[],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["{'test_loss': 1.656678318977356, 'test_rouge1': 0.4409, 'test_rouge2': 0.2039, 'test_rougeL': 0.3678, 'test_rougeLsum': 0.3675, 'test_gen_len': 16.9976, 'test_runtime': 22.8167, 'test_samples_per_second': 35.895, 'test_steps_per_second': 1.14}\n"]}],"source":["preds = 
trainer.predict(test_data)\n","print(preds.metrics)"]},{"cell_type":"markdown","metadata":{"id":"Y6zMibyss9VD"},"source":["Let's also manually inspect one of the predictions."]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"tpGy4nZTCOYR","outputId":"43050bae-c971-4192-d6e8-508ec9470bc9","scrolled":true},"outputs":[{"name":"stdout","output_type":"stream","text":["Jack: Cocktails later?\r\n","May: YES!!!\r\n","May: You read my mind...\r\n","Jack: Possibly a little tightly strung today?\r\n","May: Sigh... without question.\r\n","Jack: Thought so.\r\n","May: A little drink will help!\r\n","Jack: Maybe two!\n","\n","Gold summary:\n","Jack and May will drink cocktails later.\n","\n","Generated summary:\n","Jack and May will have a drink together.\n"]}],"source":["sample_id = 32\n","decoded_p = tokenizer.decode(preds.predictions[sample_id], skip_special_tokens=True)\n","print(test_data[sample_id]['dialogue'])\n","print('\\nGold summary:')\n","print(test_data[sample_id]['summary'])\n","print('\\nGenerated summary:')\n","print(decoded_p)"]}],"metadata":{"accelerator":"GPU","celltoolbar":"Slideshow","colab":{"gpuType":"T4","provenance":[]},"gpuClass":"standard","kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.8.8"}},"nbformat":4,"nbformat_minor":0}