{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "85e84869",
      "metadata": {},
      "source": [
        "# CT LAB: Simulating and Reconstructing Tomography Data\n",
        "\n",
        "Authors: L. Calatroni, A. Sebastiani (MaLGa, Unige)\n",
        "\n",
        "Mini-course \"Computational Imaging & Learning\" - MSc in Data Science, University of Padua, Italy."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "67b1281e-b7de-491a-b3b3-7572eaecc4d7",
      "metadata": {
        "id": "67b1281e-b7de-491a-b3b3-7572eaecc4d7"
      },
      "source": [
        "In the previous lab, we solved the inverse problem using **Model-Based Reconstruction**.\n",
        "We explicitly defined a regularization functional and minimized the resulting cost functional.\n",
        "While mathematically rigorous, such hand-crafted priors often fail to capture the complex, high-frequency textures of real images.\n",
        "\n",
        "In this lab, we transition to **learned regularization functionals**. Instead of manually defining the regularizer, we will learn the prior distribution implicitly from data.\n",
        "\n",
        "We will explore **Plug-and-Play (PnP) priors**: we replace the proximal operator of the regularizer with a pre-trained deep denoising neural network (e.g., DnCNN, used below, or DRUNet). This hybrid approach combines the known physics of the acquisition with learned image statistics."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "30a64f93-690e-4df9-b2a6-b805923c9eee",
      "metadata": {
        "id": "30a64f93-690e-4df9-b2a6-b805923c9eee"
      },
      "outputs": [],
      "source": [
        "# Install deepinv (and ptwt, in case you need to use wavelets)\n",
        "!pip install ptwt\n",
        "!pip install natsort\n",
        "!pip install deepinv"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "id": "af6063fb-5d4a-42f4-99fc-5ee61ccd8960",
      "metadata": {
        "id": "af6063fb-5d4a-42f4-99fc-5ee61ccd8960"
      },
      "outputs": [],
      "source": [
        "import deepinv as dinv\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "from torch.utils.data import DataLoader\n",
        "import torch\n",
        "from pathlib import Path\n",
        "from torchvision import transforms, datasets\n",
        "\n",
        "from deepinv.optim.prior import Prior, PnP\n",
        "from deepinv.optim.optimizers import optim_builder\n",
        "from deepinv.training import test\n",
        "from deepinv.models import DnCNN\n",
        "\n",
        "# Set the global PyTorch random seed to ensure reproducibility of the example.\n",
        "torch.manual_seed(42)\n",
        "\n",
        "# Specify the device (to use a GPU on Colab, first change the runtime type to T4 GPU)\n",
        "device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else \"cpu\"\n",
        "\n",
        "# Use parallel data loading when a GPU is available, to speed up training and evaluation.\n",
        "num_workers = 5 if torch.cuda.is_available() else 0"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "id": "wQXxzNnP1E3G",
      "metadata": {
        "id": "wQXxzNnP1E3G"
      },
      "outputs": [],
      "source": [
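        "class GaussianNoiseCT(dinv.physics.GaussianNoise):\n",
        "    # Gaussian noise whose standard deviation is relative to the peak absolute value\n",
        "    # of each clean measurement: y = x + sigma * max|x| * eps, with eps ~ N(0, I)\n",
        "    def __init__(\n",
        "        self,\n",
        "        sigma: float | torch.Tensor = 0.1,\n",
        "        rng: torch.Generator | None = None,\n",
        "    ):\n",
        "        super().__init__(sigma=sigma, rng=rng)\n",
        "\n",
        "    def forward(self, x, sigma=None, seed=None, **kwargs):\n",
        "        self.update_parameters(sigma=sigma, **kwargs)\n",
        "        self.to(x.device)\n",
        "        # Peak absolute value of each sample in the batch, reshaped to broadcast over x\n",
        "        peak = x.abs().flatten(start_dim=1).amax(dim=1)\n",
        "        peak = peak[(...,) + (None,) * (x.dim() - 1)]\n",
        "        return (\n",
        "            x\n",
        "            + peak\n",
        "            * self.randn_like(x, seed=seed)\n",
        "            * self.sigma[(...,) + (None,) * (x.dim() - 1)]\n",
        "        )"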
      ]
    },
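    {
      "cell_type": "markdown",
      "id": "sketch-relative-noise-md",
      "metadata": {},
      "source": [
        "The class above implements a *relative* Gaussian noise model: for a single measurement, the noise is scaled by the peak absolute value of the clean sinogram, $\\mathbf{y} = Au + \\sigma \\max|Au|\\,\\varepsilon$ with $\\varepsilon \\sim \\mathcal{N}(0, I)$, so that `sigma` acts as a noise level relative to the signal range. The cell below is a minimal NumPy sketch of the same model, independent of deepinv; the helper `add_relative_noise` and the random stand-in sinogram are hypothetical and only serve as an illustration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "sketch-relative-noise-code",
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def add_relative_noise(y_clean, sigma, rng):\n",
        "    # Hypothetical helper: y = y_clean + sigma * max|y_clean| * eps, with eps ~ N(0, I)\n",
        "    peak = np.max(np.abs(y_clean))\n",
        "    return y_clean + sigma * peak * rng.standard_normal(y_clean.shape)\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "y_clean = rng.uniform(0.0, 10.0, size=(64, 64))  # random stand-in for a clean sinogram\n",
        "sigma = 0.1\n",
        "y_noisy = add_relative_noise(y_clean, sigma, rng)\n",
        "\n",
        "# Empirical noise standard deviation, measured relative to the peak of the clean signal\n",
        "relative_std = np.std(y_noisy - y_clean) / np.max(np.abs(y_clean))\n",
        "print(f'relative noise level: {relative_std:.3f} (target sigma = {sigma})')"
      ]
    },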
    {
      "cell_type": "markdown",
      "id": "20ad4a6a-d13d-4ece-8217-400210dc1b34",
      "metadata": {
        "id": "20ad4a6a-d13d-4ece-8217-400210dc1b34"
      },
      "source": [
        "# **Create the Physics: Radon Transform - Computed Tomography**\n",
        "\n",
        "$$ \\begin{aligned} \\mathbf{y}(\\boldsymbol{\\theta},\\rho) &= \\int_{L_{\\boldsymbol{\\theta},\\rho}}~u~d\\sigma\n",
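        "\\end{aligned} \\qquad \\boldsymbol{\\theta} \\in \\mathbb{R}^2,~ \\| \\boldsymbol{\\theta}\\|=1, \\quad \\rho \\in \\mathbb{R}_+$$\n",
        "\n",
        "($\\mathbf{y}$ is usually referred to as the *sinogram* of $u$, and $L_{\\boldsymbol{\\theta},\\rho} =  \\left\\{ \\mathbf{x}: \\mathbf{x}\\cdot \\boldsymbol{\\theta} = \\rho \\right\\}$ is the line with normal $\\boldsymbol{\\theta}$ at distance $\\rho$ from the origin.) Below, the discretized Radon transform, a linear operator, is denoted by $\\mathcal{R}$ or $A$."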
      ]
    },
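    {
      "cell_type": "markdown",
      "id": "sketch-radon-md",
      "metadata": {},
      "source": [
        "To make the definition concrete, the cell below sketches a crude *pixel-driven* discretization of the Radon transform in NumPy: each pixel deposits its value into the detector bin nearest to $\\rho = \\mathbf{x}\\cdot\\boldsymbol{\\theta}$. This is an illustrative assumption, not the way deepinv's `Tomography` operator is implemented, and the helper name `radon_pixel_driven` is hypothetical. Each projection preserves the total mass of the image, and at $0^\\circ$ (vertical lines) the projection reduces to the column sums."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "sketch-radon-code",
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def radon_pixel_driven(u, angles_deg):\n",
        "    # Hypothetical helper: pixel-driven (nearest-bin) approximation of the line integrals\n",
        "    n = u.shape[0]\n",
        "    c = (n - 1) / 2\n",
        "    # Detector wide enough for diagonal views; n_det - n is even so bins align at 0 and 90 degrees\n",
        "    n_det = n + 2 * int(np.ceil((np.sqrt(2) - 1) * n / 2))\n",
        "    offset = (n_det - 1) / 2\n",
        "    rows, cols = np.indices(u.shape)\n",
        "    x1, x2 = cols - c, c - rows  # pixel-centre coordinates (x2 pointing up)\n",
        "    sino = np.zeros((len(angles_deg), n_det))\n",
        "    for k, a in enumerate(np.deg2rad(angles_deg)):\n",
        "        rho = x1 * np.cos(a) + x2 * np.sin(a)  # rho = x . theta for every pixel\n",
        "        bins = np.rint(rho + offset).astype(int)\n",
        "        np.add.at(sino[k], bins.ravel(), u.ravel())  # accumulate pixel values per bin\n",
        "    return sino\n",
        "\n",
        "# Usage: a small rectangle on a 28 x 28 grid (the MNIST size), seen from three angles\n",
        "u = np.zeros((28, 28))\n",
        "u[10:18, 5:9] = 1.0\n",
        "sino = radon_pixel_driven(u, [0.0, 45.0, 90.0])\n",
        "print(sino.shape)        # (3, 40): one row per angle, one column per detector bin\n",
        "print(sino.sum(axis=1))  # every projection sums to u.sum() = 32"
      ]
    },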
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "717b9890-eeaf-49c9-8eee-b273d4a3e4cf",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "717b9890-eeaf-49c9-8eee-b273d4a3e4cf",
        "outputId": "3275d448-41ae-4a21-dbd2-5b6bce129260"
      },
      "outputs": [],
      "source": [
        "\n",
        "noise_level_img = 0.1  # Gaussian noise standard deviation, relative to the peak value of the clean sinogram (see GaussianNoiseCT)\n",
        "\n",
        "angles = # TO DO: define a small number of angles to be considered\n",
        "\n",
        "physics = dinv.physics.Tomography(\n",
        "    img_width=28,\n",
        "    angles=angles,\n",
        "    circle=False,\n",
        "    device=device,\n",
        "    noise_model=GaussianNoiseCT(sigma=noise_level_img)\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "6d8e8723-ab81-4e88-b5f9-940c3e2ac078",
      "metadata": {
        "id": "6d8e8723-ab81-4e88-b5f9-940c3e2ac078"
      },
      "source": [
        "# **Create a dataset**\n",
        "\n",
        "Next, we create a supervised dataset for the tomography problem by constructing\n",
        "$$ \\{(y_i,u_i)\\}_{i=1}^N, \\qquad u_i \\text{: an image from the dataset,}\\quad y_i = \\mathcal{R}u_i +\\epsilon_i $$\n",
        "\n",
        "Hint: use a rather small sample size $N$ to keep training times reasonable."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "1700a761-985c-4667-afac-a879a5c9feb9",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "1700a761-985c-4667-afac-a879a5c9feb9",
        "outputId": "64d3ed58-e280-4bb3-c833-6be9b595613d"
      },
      "outputs": [],
      "source": [
        "# Import the MNIST dataset\n",
        "transform = transforms.Compose([transforms.ToTensor()])\n",
        "\n",
        "Train_dataset = datasets.MNIST(root=\"../datasets/\", train=True, transform=transform, download=True)\n",
        "Test_dataset = datasets.MNIST(root=\"../datasets/\", train=False, transform=transform, download=True)\n",
        "\n",
        "# Create a supervised dataset of simulated measurements\n",
        "\n",
        "# Specify the (maximum) size of the train and test sets\n",
        "n_train_max = (250 if torch.cuda.is_available() else 50)  # number of images used for training\n",
        "n_test_max = (50 if torch.cuda.is_available() else 10)  # number of images used for testing\n",
        "\n",
        "# Set the path to save the datasets\n",
        "BASE_DIR = Path(\".\")\n",
        "measurement_dir = BASE_DIR / \"dataset\"\n",
        "\n",
        "deepinv_datasets_path = dinv.datasets.generate_dataset(\n",
        "    train_dataset=Train_dataset,\n",
        "    test_dataset=Test_dataset,\n",
        "    physics=physics,\n",
        "    device=device,\n",
        "    save_dir=measurement_dir,\n",
        "    train_datapoints=n_train_max,\n",
        "    test_datapoints=n_test_max,\n",
        "    num_workers=num_workers,\n",
        "    dataset_filename=\"tomo\"\n",
        ")\n",
        "\n",
        "train_dataset = dinv.datasets.HDF5Dataset(path=deepinv_datasets_path, train=True)\n",
        "test_dataset = dinv.datasets.HDF5Dataset(path=deepinv_datasets_path, train=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "a903ea17-5b71-40d5-aa25-46542b85dcc4",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 413
        },
        "id": "a903ea17-5b71-40d5-aa25-46542b85dcc4",
        "outputId": "3cfc3a9b-b334-4244-bdd8-a4d524a4dfb3"
      },
      "outputs": [],
      "source": [
        "# Visualize a sample from the dataset: ground-truth image and its simulated noisy sinogram\n",
        "select_image = 42\n",
        "dinv.utils.plot([train_dataset[select_image][0], train_dataset[select_image][1]], ['GT', 'sinogram'], figsize=(8,4), cbar=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "301a293a-b734-4c2b-bc5b-21d8a4ca7056",
      "metadata": {
        "id": "301a293a-b734-4c2b-bc5b-21d8a4ca7056"
      },
      "source": [
        "# **Variational regularisation techniques**\n",
        "\n",
        "Choose a regularisation functional $\\phi: \\mathbb{R}^{n^2}\\rightarrow \\mathbb{R}$ and a parameter $\\lambda >0 $ and solve\n",
        "$$ u_\\lambda \\in \\arg\\min_{u \\in \\mathbb{R}^{n^2}} \\left\\{ \\frac{1}{2}\\| A u-y\\|^2 + \\lambda \\phi(u)\\right\\}$$"
      ]
    },
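    {
      "cell_type": "markdown",
      "id": "sketch-tikhonov-md",
      "metadata": {},
      "source": [
        "For the simplest choice $\\phi(u) = \\frac{1}{2}\\|u\\|^2$ (Tikhonov regularisation), the minimiser is unique and available in closed form, $u_\\lambda = (A^\\top A + \\lambda I)^{-1}A^\\top y$, obtained by setting the gradient $A^\\top(Au-y) + \\lambda u$ to zero. The NumPy sketch below checks this on a small problem; the random matrix `A` is a hypothetical stand-in for the discretised Radon transform, with fewer measurements than unknowns as in sparse-view CT."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "sketch-tikhonov-code",
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "m, n = 30, 50  # fewer measurements than unknowns\n",
        "A = rng.standard_normal((m, n))  # stand-in for the discretised Radon transform\n",
        "u_true = rng.standard_normal(n)\n",
        "y = A @ u_true + 0.1 * rng.standard_normal(m)\n",
        "\n",
        "lam = 0.05\n",
        "# Minimiser of 1/2 ||A u - y||^2 + lam * 1/2 ||u||^2: solve (A^T A + lam I) u = A^T y\n",
        "u_lam = np.linalg.solve(A.T @ A + lam * np.eye(n), A.T @ y)\n",
        "\n",
        "# First-order optimality: the gradient A^T (A u - y) + lam u vanishes at u_lam\n",
        "grad = A.T @ (A @ u_lam - y) + lam * u_lam\n",
        "print(np.linalg.norm(grad))  # close to zero, up to round-off"
      ]
    },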
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "6ebc83b0-bfc1-4b9d-93d6-df396c61641e",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 439
        },
        "id": "6ebc83b0-bfc1-4b9d-93d6-df396c61641e",
        "outputId": "ae00f16b-61af-4ccb-9764-374e2e230f3b"
      },
      "outputs": [],
      "source": [
        "# Select the data fidelity term (the first part of the functional to be minimized)\n",
        "data_fidelity = dinv.optim.data_fidelity.L2()\n",
        "\n",
        "# Specify the prior phi (see the deepinv.optim.Prior documentation for the available choices)\n",
        "prior =  # TO DO\n",
        "\n",
        "# Specific parameters for restoration with the given prior (Note that these parameters have not been optimized here)\n",
        "params_algo = {\"stepsize\": 1., \"lambda\": 0.05}\n",
        "\n",
        "# Instantiate the algorithm class to solve the IP problem\n",
        "modelVAR = optim_builder(\n",
        "    iteration=\"PGD\", # proximal gradient descent\n",
        "    prior=prior,\n",
        "    g_first=False,\n",
        "    data_fidelity=data_fidelity,\n",
        "    params_algo=params_algo,\n",
        "    early_stop=True,\n",
        "    max_iter=500,\n",
        "    crit_conv=\"cost\",\n",
        "    thres_conv=1e-5,\n",
        "    backtracking=False,\n",
        "    verbose=False,\n",
        ")\n",
        "\n",
        "# To assess its average performance, we apply the model to every element of the test set and average the metrics\n",
        "\n",
        "batch_size = 1\n",
        "var_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)\n",
        "\n",
        "test(\n",
        "    model=modelVAR,\n",
        "    test_dataloader=var_dataloader,\n",
        "    physics=physics,\n",
        "    device=device,\n",
        "    verbose=True,\n",
        "    plot_images = True,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "a6523ae8-fc48-494d-b485-4fb2283c30c6",
      "metadata": {
        "id": "a6523ae8-fc48-494d-b485-4fb2283c30c6"
      },
      "source": [
        "**********************************************************\n",
        "**TASK 2: Explore other regularization choices**\n",
        "\n",
        "Explore alternative choices for $\\phi$. Have a look at https://deepinv.github.io/deepinv/api/stubs/deepinv.optim.Prior.html#deepinv.optim.Prior\n",
        "\n",
        "\n",
        "**********************************************************"
      ]
    },
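    {
      "cell_type": "markdown",
      "id": "sketch-tv-md",
      "metadata": {},
      "source": [
        "As a reference point for this task, the cell below sketches the discrete total variation (TV) in NumPy, in an isotropic version (sum of the Euclidean norms of the forward-difference gradients) and an anisotropic version (sum of absolute differences). The boundary handling and the helper names are illustrative assumptions and may differ from the exact variant implemented by deepinv's TV prior. TV vanishes on constant images, equals edge length times jump height on a straight axis-aligned edge, and grows with noise, which is why it favours piecewise-constant reconstructions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "sketch-tv-code",
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def tv_isotropic(u):\n",
        "    # Forward differences; the difference across the last row/column is set to zero\n",
        "    dx = np.zeros_like(u)\n",
        "    dy = np.zeros_like(u)\n",
        "    dx[:, :-1] = u[:, 1:] - u[:, :-1]  # horizontal differences\n",
        "    dy[:-1, :] = u[1:, :] - u[:-1, :]  # vertical differences\n",
        "    return np.sum(np.sqrt(dx**2 + dy**2))\n",
        "\n",
        "def tv_anisotropic(u):\n",
        "    return np.abs(np.diff(u, axis=1)).sum() + np.abs(np.diff(u, axis=0)).sum()\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "flat = np.ones((8, 8))\n",
        "step = np.zeros((8, 8))\n",
        "step[:, 4:] = 1.0  # vertical edge of length 8 and jump 1\n",
        "noisy = step + 0.1 * rng.standard_normal(step.shape)\n",
        "\n",
        "for name, img in [('flat', flat), ('step', step), ('noisy step', noisy)]:\n",
        "    print(f'{name:>10}: isotropic TV = {tv_isotropic(img):.2f}, anisotropic TV = {tv_anisotropic(img):.2f}')"
      ]
    },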
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "_L1PE3UqRkno",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 384
        },
        "id": "_L1PE3UqRkno",
        "outputId": "7d15e736-c2ce-4574-9ec1-0b32e8543c8b"
      },
      "outputs": [],
      "source": [
        "# Select the data fidelity term (the first part of the functional to be minimized)\n",
        "data_fidelity = dinv.optim.data_fidelity.L2()\n",
        "\n",
        "# Try the TV prior (see https://deepinv.github.io/deepinv/auto_examples/optimization/demo_TV_minimisation.html )\n",
        "prior = # TO DO\n",
        "\n",
        "# Specific parameters for restoration with the given prior (Note that these parameters have not been optimized here)\n",
        "params_algo = # TO DO\n",
        "\n",
        "# Instantiate the algorithm class to solve the IP problem\n",
        "modelVAR = optim_builder(\n",
        "    iteration=\"PGD\", # proximal gradient descent\n",
        "    prior=prior,\n",
        "    g_first=False,\n",
        "    data_fidelity=data_fidelity,\n",
        "    params_algo=params_algo,\n",
        "    early_stop=True,\n",
        "    max_iter=500,\n",
        "    crit_conv=\"cost\",\n",
        "    thres_conv=1e-5,\n",
        "    backtracking=False,\n",
        "    verbose=False,\n",
        ")\n",
        "\n",
        "# To assess its average performance, we apply the model to every element of the test set and average the metrics\n",
        "\n",
        "batch_size = 1\n",
        "var_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)\n",
        "\n",
        "test(\n",
        "    model=modelVAR,\n",
        "    test_dataloader=var_dataloader,\n",
        "    physics=physics,\n",
        "    device=device,\n",
        "    verbose=True,\n",
        "    plot_images = True,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "85b9d67e-119c-461e-b9ae-08b331548208",
      "metadata": {
        "id": "85b9d67e-119c-461e-b9ae-08b331548208"
      },
      "source": [
        "**********************************************************\n",
        "**TASK (Optional): Parameter Tuning**\n",
        "\n",
        "Find the best choice for the parameter $\\lambda$ by means of a supervised strategy.\n",
        "\n",
        "Define a pool of possible parameters $\\{10^{-7},10^{-6.5},\\ldots,10^{-1},10^{-0.5},1\\}$ and, for each of them, evaluate the average reconstruction quality (e.g., PSNR) on the training set. Then, pick the best one and use it on the test set.\n",
        "\n",
        "\n",
        "\n",
        "**********************************************************"
      ]
    },
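    {
      "cell_type": "markdown",
      "id": "sketch-lambda-grid-md",
      "metadata": {},
      "source": [
        "A minimal NumPy sketch of this supervised strategy. As a hypothetical stand-in for the deepinv model and the MNIST pairs, it uses the closed-form Tikhonov reconstruction $u_\\lambda = (A^\\top A + \\lambda I)^{-1}A^\\top y$ on synthetic data; the grid `10.0 ** np.arange(-7, 0.25, 0.5)` contains exactly the exponents $-7, -6.5, \\ldots, -0.5, 0$. In the notebook you would instead rebuild `modelVAR` with `params_algo[\"lambda\"] = lam` and evaluate it on `train_dataset`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "sketch-lambda-grid-code",
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "m, n, N = 30, 50, 20\n",
        "A = rng.standard_normal((m, n))  # stand-in forward operator\n",
        "U_train = rng.standard_normal((N, n))  # hypothetical training images, one per row\n",
        "Y_train = U_train @ A.T + 0.5 * rng.standard_normal((N, m))  # noisy measurements\n",
        "\n",
        "def psnr(u_hat, u, peak):\n",
        "    mse = np.mean((u_hat - u) ** 2)\n",
        "    return 10 * np.log10(peak**2 / mse)\n",
        "\n",
        "lambdas = 10.0 ** np.arange(-7, 0.25, 0.5)  # {1e-7, 1e-6.5, ..., 1e-0.5, 1}\n",
        "peak = np.abs(U_train).max()\n",
        "mean_psnr = []\n",
        "for lam in lambdas:\n",
        "    # Tikhonov reconstruction of every training measurement (rows of Y_train)\n",
        "    U_hat = np.linalg.solve(A.T @ A + lam * np.eye(n), A.T @ Y_train.T).T\n",
        "    mean_psnr.append(np.mean([psnr(a, b, peak) for a, b in zip(U_hat, U_train)]))\n",
        "\n",
        "best_lam = lambdas[int(np.argmax(mean_psnr))]\n",
        "print(f'best lambda on the training set: {best_lam:.1e}')\n",
        "# best_lam is then used once on the test set"
      ]
    },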
    {
      "cell_type": "markdown",
      "id": "1b36dfe5-6d24-4756-babd-c2a18c8a410a",
      "metadata": {
        "id": "1b36dfe5-6d24-4756-babd-c2a18c8a410a"
      },
      "source": [
        "## **Plug-and-Play 1: PGD with pre-trained denoiser**\n",
        "\n",
        "Consider the Proximal Gradient Descent (PGD) method associated with the minimization of the functional $\\frac{1}{2}\\| Au-y \\|^2 + \\lambda \\phi(u)$, namely\n",
        "$$\n",
        "\\left\\{\n",
        "\\begin{aligned}\n",
        "z^{(k+1)} &= u^{(k)} - \\tau A^\\top(Au^{(k)}-y) \\\\\n",
        "u^{(k+1)} &= \\operatorname{prox}_{\\tau \\lambda \\phi}(z^{(k+1)})\n",
        "\\end{aligned}\n",
        "\\right.\n",
        "$$\n",
        "and replace the proximal operator of $\\tau \\lambda \\phi$ by a neural network $D_{\\theta,\\sigma}$, obtaining\n",
        "$$\n",
        "\\left\\{\n",
        "\\begin{aligned}\n",
        "z^{(k+1)} &= u^{(k)} - \\tau A^\\top(Au^{(k)}-y) \\\\\n",
        "u^{(k+1)} &= D_{\\theta,\\sigma}(z^{(k+1)})\n",
        "\\end{aligned}\n",
        "\\right.\n",
        "$$\n",
        "The network $D_{\\theta,\\sigma}$ (depending on some parameters $\\theta$) plays the role of a denoiser: it is trained to remove Gaussian noise, i.e., to approximate the MMSE denoiser $\\tilde{u} \\mapsto \\mathbb{E}[u \\mid \\tilde{u}]$, where $u$ follows the prior distribution and $\\tilde{u} = u + \\sigma\\varepsilon$ is corrupted by Gaussian noise with standard deviation $\\sigma$.\n",
        "\n",
        "Let us first consider a simple case in which $D_{\\theta,\\sigma}$ is a CNN (here, DnCNN) that has been pre-trained on natural images."
      ]
    },
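    {
      "cell_type": "markdown",
      "id": "sketch-pnp-pgd-md",
      "metadata": {},
      "source": [
        "The PnP-PGD iteration only requires a gradient step on the data-fidelity term and a call to a denoiser. The NumPy sketch below writes this loop for a generic `denoiser` callable; the helper names and the random matrix `A` are hypothetical stand-ins, and deepinv's `optim_builder` wraps the same idea with many more options. Two sanity checks connect it back to classical PGD: with the identity as 'denoiser' it reduces to gradient descent on $\\frac{1}{2}\\|Au-y\\|^2$, and with soft-thresholding, which is exactly $\\operatorname{prox}_{\\tau\\lambda\\|\\cdot\\|_1}$, it coincides with ISTA."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "sketch-pnp-pgd-code",
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def pnp_pgd(A, y, denoiser, tau, n_iter=2000):\n",
        "    # Hypothetical helper: u^{k+1} = D(u^k - tau * A^T (A u^k - y))\n",
        "    u = np.zeros(A.shape[1])\n",
        "    for _ in range(n_iter):\n",
        "        z = u - tau * A.T @ (A @ u - y)  # gradient step on 1/2 ||A u - y||^2\n",
        "        u = denoiser(z)  # the denoiser replaces prox_{tau * lambda * phi}\n",
        "    return u\n",
        "\n",
        "def soft_threshold(z, t):\n",
        "    # Proximal operator of t * ||.||_1\n",
        "    return np.sign(z) * np.maximum(np.abs(z) - t, 0.0)\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "A = rng.standard_normal((40, 20))  # stand-in forward operator (full column rank)\n",
        "y = A @ rng.standard_normal(20)\n",
        "tau = 1.0 / np.linalg.norm(A, 2) ** 2  # step size 1 / ||A||^2\n",
        "\n",
        "# Identity 'denoiser': plain gradient descent, converging to the least-squares solution\n",
        "u_ls = pnp_pgd(A, y, denoiser=lambda z: z, tau=tau)\n",
        "# Soft-thresholding 'denoiser': ISTA for 1/2 ||A u - y||^2 + lam ||u||_1\n",
        "lam = 0.5\n",
        "u_ista = pnp_pgd(A, y, denoiser=lambda z: soft_threshold(z, tau * lam), tau=tau)\n",
        "print(np.linalg.norm(u_ls - np.linalg.lstsq(A, y, rcond=None)[0]))  # close to zero"
      ]
    },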
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "3cbbcd4f-d7d4-467c-a8ed-6dfe2af1d6ff",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "3cbbcd4f-d7d4-467c-a8ed-6dfe2af1d6ff",
        "outputId": "ce9dfc6b-69f5-48f6-8f53-bd620a66e470"
      },
      "outputs": [],
      "source": [
        "from deepinv.models import DnCNN\n",
        "\n",
        "sigma_PnP =  # TO BE TUNED\n",
        "params_algo = {\"stepsize\": 0.8, \"g_param\": sigma_PnP}\n",
        "max_iter = 200\n",
        "early_stop = True\n",
        "\n",
        "denoiser = DnCNN(\n",
        "    in_channels=1, # for greyscale images\n",
        "    out_channels=1,\n",
        "    pretrained=\"download\",  # try also \"download_lipschitz\": it comes with convergence guarantees, at the cost of some expressivity\n",
        "    device=device,\n",
        ")\n",
        "prior = PnP(denoiser=denoiser)\n",
        "\n",
        "data_fidelity = dinv.optim.data_fidelity.L2()\n",
        "\n",
        "\n",
        "# Instantiate the algorithm class to solve the IP problem\n",
        "modelPnP = optim_builder(\n",
        "    iteration=\"PGD\",\n",
        "    prior=prior,\n",
        "    data_fidelity=data_fidelity,\n",
        "    early_stop=early_stop,\n",
        "    max_iter=max_iter,\n",
        "    verbose=True,\n",
        "    params_algo=params_algo,\n",
        ")\n",
        "modelPnP.eval() # set the model to evaluation mode. We do not require training here\n",
        "\n",
        "\n",
        "# Set the data loader to test the regularizer\n",
        "batch_size = 1\n",
        "var_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)\n",
        "\n",
        "test(\n",
        "    model=modelPnP,\n",
        "    test_dataloader=var_dataloader,\n",
        "    physics=physics,\n",
        "    device=device,\n",
        "    verbose=True,\n",
        "    plot_images = True,\n",
        ")\n",
        "\n",
        "\n",
        "# What is the effect of a different choice of sigma?"
      ]
    },
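    {
      "cell_type": "markdown",
      "id": "sketch-power-iteration-md",
      "metadata": {},
      "source": [
        "The step size $\\tau$ in (PnP-)PGD is usually chosen relative to $\\|A\\|^2$, the Lipschitz constant of the gradient $A^\\top(Au-y)$: in the convex variational setting, $0 < \\tau < 2/\\|A\\|^2$ ensures convergence of PGD, while for PnP with a learned denoiser similar guarantees require extra assumptions on $D_{\\theta,\\sigma}$ (hence the Lipschitz-constrained weights mentioned above). When $A$ is only available through matrix-vector products, as for the Radon transform, $\\|A\\|^2$ can be estimated by power iteration on $A^\\top A$. Below is a NumPy sketch with a hypothetical helper name and a random stand-in matrix; deepinv's linear physics objects offer a `compute_norm` method for the same purpose (check the documentation of your installed version)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "sketch-power-iteration-code",
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def operator_norm_sq(A_mv, AT_mv, dim, n_iter=1000, seed=0):\n",
        "    # Hypothetical helper: power iteration on A^T A, returning an estimate of ||A||^2\n",
        "    rng = np.random.default_rng(seed)\n",
        "    x = rng.standard_normal(dim)\n",
        "    x /= np.linalg.norm(x)\n",
        "    norm_sq = 0.0\n",
        "    for _ in range(n_iter):\n",
        "        x = AT_mv(A_mv(x))  # apply A^T A to the current unit vector\n",
        "        norm_sq = np.linalg.norm(x)  # tends to the largest eigenvalue of A^T A\n",
        "        x /= norm_sq\n",
        "    return norm_sq\n",
        "\n",
        "rng = np.random.default_rng(1)\n",
        "A = rng.standard_normal((60, 40))  # stand-in operator, accessed only through products\n",
        "L = operator_norm_sq(lambda v: A @ v, lambda w: A.T @ w, dim=40)\n",
        "tau = 1.0 / L  # a conservative step size for the gradient step on 1/2 ||A u - y||^2\n",
        "print(L, np.linalg.norm(A, 2) ** 2)  # estimate vs exact squared spectral norm"
      ]
    },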
    {
      "cell_type": "markdown",
      "id": "576e7952-954a-4709-8940-c9528983d2b3",
      "metadata": {
        "id": "576e7952-954a-4709-8940-c9528983d2b3"
      },
      "source": [
        "# **Plug-and-Play 2: let's also train the denoiser - your turn!**\n",
        "\n",
        "We now want to use, within PGD, a network specifically trained to denoise images from our dataset.\n",
        "To do so,\n",
        "1.   we create another dataset $\\{(\\tilde{u}_i,u_i)\\}_{i=1}^{N_{den}}$ such that $\\tilde{u}_i = u_i + \\tilde{\\epsilon}_i$, where $\\tilde{\\epsilon}_i \\sim \\mathcal{N}(0,\\sigma^2 I)$;\n",
        "2.   we define a denoiser $D_\\theta$ as a CNN as above and train it, namely we choose its parameters $\\theta$ so as to minimize $$L_{den}(\\theta) = \\frac{1}{N_{den}} \\sum_{i=1}^{N_{den}} \\| D_\\theta(\\tilde{u}_i) - u_i\\|^2;$$\n",
        "3.  we use the trained denoiser $D_{\\theta,\\sigma}$ in place of the prox of a regularization functional in any PnP scheme, e.g. PnP-PGD.\n",
        "\n",
        "Follow the steps below to implement this strategy."
      ]
    },
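    {
      "cell_type": "markdown",
      "id": "sketch-linear-denoiser-md",
      "metadata": {},
      "source": [
        "To make the training objective $L_{den}$ concrete, the NumPy sketch below replaces the CNN by a deliberately simple hypothetical model, a linear map $D_W(\\tilde{u}) = W\\tilde{u}$, trained on synthetic smooth 1D signals that stand in for MNIST. For a linear model, minimising $L_{den}$ is an ordinary least-squares problem with a closed-form solution; the CNN in the cells below is instead trained iteratively (with Adam) on the same kind of loss. The identity map, i.e. no denoising at all, has a training loss close to $d\\,\\sigma^2$, the expected squared norm of the noise, and the learned $W$ can only do better on the training pairs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "sketch-linear-denoiser-code",
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "d, N_den, sigma = 16, 2000, 0.3\n",
        "\n",
        "# Hypothetical clean signals u_i: 3-tap moving sums of white noise (correlated neighbours)\n",
        "white = rng.standard_normal((N_den, d + 2))\n",
        "U = white[:, :-2] + white[:, 1:-1] + white[:, 2:]  # rows u_i, shape (N_den, d)\n",
        "U_noisy = U + sigma * rng.standard_normal(U.shape)  # rows u~_i = u_i + eps_i\n",
        "\n",
        "def L_den(W):\n",
        "    # (1 / N_den) * sum_i ||W u~_i - u_i||^2, with samples stored as rows\n",
        "    return np.mean(np.sum((U_noisy @ W.T - U) ** 2, axis=1))\n",
        "\n",
        "# Least-squares fit: minimise ||U_noisy W^T - U||_F^2 over W\n",
        "W_T, *_ = np.linalg.lstsq(U_noisy, U, rcond=None)\n",
        "W = W_T.T\n",
        "\n",
        "print(f'identity: {L_den(np.eye(d)):.3f}, learned W: {L_den(W):.3f}')"
      ]
    },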
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "67bb755b-7dd0-44d5-9373-70fefc139807",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "67bb755b-7dd0-44d5-9373-70fefc139807",
        "outputId": "919bf661-7acd-43ea-eaaa-860728e9b19b"
      },
      "outputs": [],
      "source": [
        "# Part 1: dataset\n",
        "\n",
        "denoiser_train = DnCNN(   # start from the pre-trained weights and fine-tune them on our data (try also a UNet)\n",
        "    in_channels=1,\n",
        "    out_channels=1,\n",
        "    pretrained=\"download\",\n",
        "    device=device,\n",
        ")\n",
        "\n",
        "# Create a supervised dataset of simulated measurements\n",
        "\n",
        "# Specify the (maximum) size of the train and test sets\n",
        "n_train_max_PnP = (250 if torch.cuda.is_available() else 50)  # number of images used for training\n",
        "n_test_max_PnP = (50 if torch.cuda.is_available() else 10)  # number of images used for testing\n",
        "\n",
        "# Set the path to save the datasets\n",
        "BASE_DIR = Path(\".\")\n",
        "measurement_dir = BASE_DIR / \"dataset\"\n",
        "\n",
        "# Define the physics: COMPLETE!\n",
        "# Check https://deepinv.github.io/deepinv/api/stubs/deepinv.physics.Denoising.html\n",
        "\n",
        "sigma_PnP = 0.05  # standard deviation of the Gaussian noise used to train the denoiser\n",
        "noise_model_PnP = # TODO\n",
        "physics_PnP = # TODO\n",
        "\n",
        "# Generate noisy dataset\n",
        "deepinv_datasets_path_PnP = dinv.datasets.generate_dataset(\n",
        "    train_dataset=Train_dataset,\n",
        "    test_dataset=Test_dataset,\n",
        "    physics=physics_PnP,\n",
        "    device=device,\n",
        "    save_dir=measurement_dir,\n",
        "    train_datapoints=n_train_max_PnP,\n",
        "    test_datapoints=n_test_max_PnP,\n",
        "    num_workers=num_workers,\n",
        "    dataset_filename=\"denoise\",\n",
        ")\n",
        "\n",
        "train_dataset_PnP = dinv.datasets.HDF5Dataset(path=deepinv_datasets_path_PnP, train=True)\n",
        "test_dataset_PnP = dinv.datasets.HDF5Dataset(path=deepinv_datasets_path_PnP, train=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "e30d9fed-8366-4a79-ae68-03a5ced73761",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "e30d9fed-8366-4a79-ae68-03a5ced73761",
        "outputId": "6a5aa047-9503-4511-a22c-ec1804f7f649"
      },
      "outputs": [],
      "source": [
        "# Part 2: training\n",
        "\n",
        "data_fidelity = dinv.optim.data_fidelity.L2()\n",
        "\n",
        "# Set the dataloader for the training\n",
        "batch_size=10\n",
        "train_dataloader_PnP = DataLoader(train_dataset_PnP, batch_size=batch_size, num_workers=num_workers, shuffle=True)\n",
        "test_dataloader_PnP = DataLoader(test_dataset_PnP, batch_size=batch_size, num_workers=num_workers, shuffle=False)\n",
        "\n",
        "# Set the training algorithm\n",
        "learning_rate = 1e-3\n",
        "epochs = 20\n",
        "optimizer = torch.optim.Adam(denoiser_train.parameters(), lr=learning_rate)\n",
        "losses = [dinv.loss.SupLoss(metric=dinv.loss.metric.MSE())]\n",
        "\n",
        "trainer_den = dinv.Trainer(\n",
        "    model=denoiser_train,\n",
        "    physics=physics_PnP,\n",
        "    train_dataloader=train_dataloader_PnP,\n",
        "    eval_dataloader=test_dataloader_PnP,\n",
        "    epochs=epochs,\n",
        "    losses=losses,\n",
        "    optimizer=optimizer,\n",
        "    device=device,\n",
        "    verbose=True,\n",
        "    show_progress_bar=False,\n",
        ")\n",
        "\n",
        "# Training\n",
        "modelDenoiser = trainer_den.train()\n",
        "test(\n",
        "    model=denoiser_train,\n",
        "    test_dataloader=test_dataloader_PnP,\n",
        "    physics=physics_PnP,\n",
        "    device=device,\n",
        "    verbose=True,\n",
        "    plot_images = True,\n",
        ")\n",
        "\n",
        "# Training is over: now let us use the trained model\n",
        "modelDenoiser.eval()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "7eaa3ccf-b7cb-4dc3-a534-aa237de428cb",
      "metadata": {
        "id": "7eaa3ccf-b7cb-4dc3-a534-aa237de428cb"
      },
      "source": [
        "**********************************************************\n",
        "**TASK: Use your brand-new denoiser!**\n",
        "\n",
        "Use the trained denoiser within a simple PnP-PGD scheme.\n",
        "\n",
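        "Hint: since the learned prior is strong, a small regularization parameter suffices. In PnP-PGD the denoiser stands in for $\\operatorname{prox}_{\\tau \\lambda \\phi}$, so for a fixed denoiser the product $\\tau\\lambda$ is fixed: a small $\\lambda$ therefore corresponds to a large stepsize $\\tau$.\n",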
        "\n",
        "\n",
        "**********************************************************"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "4d911289-257a-47d2-8eea-c12b22b59400",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 439
        },
        "id": "4d911289-257a-47d2-8eea-c12b22b59400",
        "outputId": "9535b457-9f48-4008-d59c-e5020cee77d8"
      },
      "outputs": [],
      "source": [
        "# Part 3: PnP-PGD with the trained denoiser\n",
        "\n",
        "params_algo = # TO DO\n",
        "max_iter = 200\n",
        "early_stop = True\n",
        "\n",
        "prior = PnP(denoiser=modelDenoiser)\n",
        "\n",
        "data_fidelity = # TODO\n",
        "\n",
        "# Instantiate the algorithm class to solve the IP problem: COMPLETE\n",
        "modelPnP_new = # TODO\n",
        "modelPnP_new.eval() # set the model to evaluation mode. We do not require training here\n",
        "\n",
        "# Set the data loader to test the regularizer: COMPLETE\n",
        "# TODO\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "base",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
